use candle_core::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{Conv2d, Conv2dConfig, Embedding, LayerNorm, Linear, Module};
use mistralrs_quant::ShardedVarBuilder;
use std::ops::Mul;

use crate::{
    layers::{self, conv2d, embedding, layer_norm, Activation, CausalMasker, MatMul},
    utils::unvarbuilder::UnVarBuilder,
};

use super::config::{Idefics3Config, Idefics3VisionConfig};

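/// Projection from the pixel-shuffled vision features into the text model's
/// embedding space: `vision_hidden * scale_factor^2 -> text_hidden`.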
pub(crate) struct Idefics3SimpleMLP {
    pub(crate) proj: Linear,
}

impl Idefics3SimpleMLP {
    pub fn new(cfg: &Idefics3Config, vb: ShardedVarBuilder) -> Result<Self> {
        let in_dim = cfg.vision_config.hidden_size * cfg.scale_factor.pow(2);
        let out_dim = cfg.text_config.hidden_size;
        Ok(Self {
            proj: layers::linear_no_bias(in_dim, out_dim, vb.pp("proj"))?,
        })
    }

    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        x.apply(&self.proj)
    }
}

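/// Connector between the vision tower and the language model: pixel-shuffles
/// patch features to trade sequence length for channel depth, then projects the
/// result into the text hidden size.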
pub struct Idefics3Connector {
    scale_factor: usize,
    pub(crate) modality_projection: Idefics3SimpleMLP,
}

impl Idefics3Connector {
    pub fn new(cfg: &Idefics3Config, vb: ShardedVarBuilder) -> Result<Self> {
        Ok(Self {
            scale_factor: cfg.scale_factor,
            modality_projection: Idefics3SimpleMLP::new(cfg, vb.pp("modality_projection"))?,
        })
    }

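    /// Space-to-depth shuffle: `(bs, seq, d) -> (bs, seq / s^2, d * s^2)`, assuming
    /// the sequence forms a square `sqrt(seq) x sqrt(seq)` patch grid. For example
    /// (illustrative numbers only), with `s = 2` an input of shape `(1, 1024, 1152)`
    /// becomes `(1, 256, 4608)`.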
    pub fn pixel_shuffle(&self, x: &Tensor, scale_factor: usize) -> Result<Tensor> {
        let (bs, seq, embed_dim) = x.dims3()?;
        let height = (seq as f32).sqrt() as usize;
        let width = height;
        // (bs, seq, d) -> (bs, h, w, d)
        let mut x = x.reshape((bs, height, width, embed_dim))?;
        // Fold groups of `scale_factor` adjacent columns into channels: (bs, h, w/s, d*s)
        x = x.reshape((bs, height, width / scale_factor, embed_dim * scale_factor))?;
        x = x.permute((0, 2, 1, 3))?;
        // Fold groups of `scale_factor` adjacent rows as well: (bs, w/s, h/s, d*s^2)
        x = x.reshape((
            bs,
            width / scale_factor,
            height / scale_factor,
            embed_dim * scale_factor.pow(2),
        ))?;
        x = x.permute((0, 2, 1, 3))?;
        // Flatten the grid back into a sequence: (bs, seq/s^2, d*s^2)
        x.reshape((
            bs,
            (seq as f32 / scale_factor.pow(2) as f32) as usize,
            embed_dim * scale_factor.pow(2),
        ))
    }

    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let image_hidden_states = self.pixel_shuffle(x, self.scale_factor)?;
        self.modality_projection.forward(&image_hidden_states)
    }
}

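/// Patch + position embeddings for the vision tower. Images are split into
/// `patch_size x patch_size` patches by a strided convolution; position ids are
/// computed per image so padded patches (per the patch attention mask) keep
/// position id 0.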
struct VisionEmbeddings {
    patch_size: usize,
    patch_embedding: Conv2d,
    num_patches_per_side: usize,
    position_embedding: Embedding,
}

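/// Analogue of `torch.bucketize(xs, boundaries, right=True)`: for each `x`,
/// returns the index at which it would be inserted into the sorted `boundaries`
/// (values exactly equal to a boundary map to that boundary's own index).
/// Assumes `boundaries` is sorted and contains no NaNs.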
fn bucketize_right(xs: &[f32], boundaries: &[f32], device: &Device) -> Result<Tensor> {
    use std::cmp::Ordering;

    let mut result = Vec::with_capacity(xs.len());

    for &x in xs {
        // binary_search_by returns Ok(i) if boundaries[i] == x and Err(i) with the
        // insertion point otherwise; either way, `i` is the bucket index we want.
        let idx = match boundaries.binary_search_by(|&val| {
            // f32 is not Ord; fall back to Less only for NaN (assumed absent).
            val.partial_cmp(&x).unwrap_or(Ordering::Less)
        }) {
            Ok(i) => i,
            Err(i) => i,
        };

        result.push(idx as u32);
    }

    Tensor::from_vec(result, (xs.len(),), device)
}

impl VisionEmbeddings {
    fn new(config: &Idefics3VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        let conv_config = Conv2dConfig {
            stride: config.patch_size,
            ..Default::default()
        };
        let patch_embedding = conv2d(
            config.num_channels,
            config.hidden_size,
            config.patch_size,
            conv_config,
            vb.pp("patch_embedding"),
        )?;
        let num_patches_per_side = config.image_size / config.patch_size;
        let num_patches = num_patches_per_side.pow(2);
        Ok(Self {
            patch_size: config.patch_size,
            patch_embedding,
            num_patches_per_side,
            position_embedding: embedding(
                num_patches,
                config.hidden_size,
                vb.pp("position_embedding"),
            )?,
        })
    }

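    /// Embed `pixel_values` of shape `(bs, c, h, w)` into patch embeddings of shape
    /// `(bs, num_patches, hidden)`, then add position embeddings. Because images in
    /// a batch may differ in true size, each image's position ids are bucketized
    /// from its fractional patch coordinates so they align with the
    /// full-resolution position grid.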
    fn forward(&self, pixel_values: &Tensor, patch_attention_mask: &Tensor) -> Result<Tensor> {
        let (bs, _, max_im_h, max_im_w) = pixel_values.dims4()?;

        // (bs, c, h, w) -> (bs, hidden, h/p, w/p)
        let patch_embeds = self.patch_embedding.forward(pixel_values)?;

        // -> (bs, (h/p)*(w/p), hidden)
        let embeddings = patch_embeds.flatten(2, D::Minus1)?.transpose(1, 2)?;

        let (max_nb_patches_h, max_nb_patches_w) =
            (max_im_h / self.patch_size, max_im_w / self.patch_size);
        // Bucket boundaries at 1/n, 2/n, ..., (n-1)/n for n patches per side.
        let boundaries = Tensor::arange_step(
            1.0 / self.num_patches_per_side as f32,
            1.0,
            1.0 / self.num_patches_per_side as f32,
            pixel_values.device(),
        )?
        .to_vec1::<f32>()?;
        // Padded patches keep position id 0; valid patches are overwritten below.
        let position_ids = Tensor::full(
            0u32,
            (bs, max_nb_patches_h * max_nb_patches_w),
            pixel_values.device(),
        )?;

        let mut new_position_ids = Vec::new();
        for (b_idx, p_attn_mask) in patch_attention_mask.chunk(bs, 0)?.iter().enumerate() {
            let p_attn_mask = p_attn_mask.squeeze(0)?;
            // Number of valid patch rows/columns for this image.
            let nb_patches_h = p_attn_mask.i((.., 0))?.sum_all()?;
            let nb_patches_w = p_attn_mask.i((0,))?.sum_all()?;

            // Fractional coordinates of each valid patch in [0, 1).
            let fractional_coords_h = Tensor::arange_step(
                0.0,
                1.0 - 1e-6,
                1.0 / nb_patches_h.to_dtype(DType::F32)?.to_scalar::<f32>()?,
                pixel_values.device(),
            )?
            .to_vec1::<f32>()?;
            let fractional_coords_w = Tensor::arange_step(
                0.0,
                1.0 - 1e-6,
                1.0 / nb_patches_w.to_dtype(DType::F32)?.to_scalar::<f32>()?,
                pixel_values.device(),
            )?
            .to_vec1::<f32>()?;

            // Map fractional coordinates onto the full-resolution position grid.
            let bucket_coords_h =
                bucketize_right(&fractional_coords_h, &boundaries, pixel_values.device())?;
            let bucket_coords_w =
                bucketize_right(&fractional_coords_w, &boundaries, pixel_values.device())?;

            // 2D grid position -> flat position id: row * n + col.
            let pos_ids = bucket_coords_h
                .unsqueeze(D::Minus1)?
                .mul(self.num_patches_per_side as f64)?
                .broadcast_add(&bucket_coords_w)?
                .flatten_all()?
                .to_vec1::<u32>()?;

            // Scatter the computed ids into the slots of unmasked (valid) patches.
            let true_indices = p_attn_mask
                .flatten_all()?
                .to_vec1::<u8>()?
                .iter()
                .enumerate()
                .filter_map(|(i, x)| if *x != 0 { Some(i) } else { None })
                .collect::<Vec<_>>();
            let position_ids_b = position_ids.i(b_idx)?;

            let mut new_position_ids_b = position_ids_b.to_vec1::<u32>()?;
            let new_position_ids_b_len = new_position_ids_b.len();
            for (i, true_idx) in true_indices.into_iter().enumerate() {
                new_position_ids_b[true_idx] = pos_ids[i];
            }

            new_position_ids.push(Tensor::from_vec(
                new_position_ids_b,
                new_position_ids_b_len,
                pixel_values.device(),
            )?);
        }
        let position_ids = Tensor::stack(&new_position_ids, 0)?;
        let position_ids = position_ids.to_device(self.position_embedding.embeddings().device())?;
        embeddings.broadcast_add(&self.position_embedding.forward(&position_ids)?)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("patch_embedding").add(&self.patch_embedding);
        uvb.pp("position_embedding").add(&self.position_embedding);

        uvb.to_safetensors()
    }
}

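/// Multi-head self-attention for the vision encoder (eager implementation:
/// explicit `softmax(QK^T / sqrt(d)) V`).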
struct Attention {
    embed_dim: usize,
    num_heads: usize,
    head_dim: usize,
    scale: f64,
    q_proj: Linear,
    k_proj: Linear,
    v_proj: Linear,
    o_proj: Linear,
    neg_inf: Tensor,
}

impl Attention {
    fn new(config: Idefics3VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        let embed_dim = config.hidden_size;
        let num_heads = config.num_attention_heads;
        let head_dim = embed_dim / num_heads;
        let scale = 1.0 / (head_dim as f64).sqrt();

        let q_proj = layers::linear(embed_dim, embed_dim, vb.pp("q_proj"))?;
        let k_proj = layers::linear(embed_dim, embed_dim, vb.pp("k_proj"))?;
        let v_proj = layers::linear(embed_dim, embed_dim, vb.pp("v_proj"))?;
        let o_proj = layers::linear(embed_dim, embed_dim, vb.pp("out_proj"))?;

        Ok(Self {
            embed_dim,
            num_heads,
            head_dim,
            scale,
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            neg_inf: Tensor::new(f32::NEG_INFINITY, vb.device())?.to_dtype(vb.dtype())?,
        })
    }

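    /// Attention over `xs` of shape `(bs, seq, embed_dim)`. If an
    /// `attention_mask` is given, positions where it is zero are filled with
    /// `-inf` before the softmax so they contribute nothing.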
    fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let mut q = self.q_proj.forward(xs)?;
        let mut k = self.k_proj.forward(xs)?;
        let mut v = self.v_proj.forward(xs)?;

        // (bs, seq, embed) -> (bs, heads, seq, head_dim)
        q = q
            .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;
        k = k
            .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;
        v = v
            .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;

        let attn_weights =
            (MatMul.matmul(&q.contiguous()?, &k.transpose(2, 3)?.contiguous()?)? * self.scale)?;

        // Fill masked (zero) positions with -inf so they vanish after the softmax.
        let mut attn_weights = CausalMasker.apply_mask_one_and_zero(
            &attention_mask.map(|x| x.to_dtype(DType::U8).unwrap()),
            attn_weights,
            &self.neg_inf,
        )?;
        attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
        let attn_output = MatMul.matmul(&attn_weights, &v.contiguous()?)?;

        // (bs, heads, seq, head_dim) -> (bs, seq, embed)
        attn_output
            .transpose(1, 2)?
            .reshape((b_sz, q_len, self.embed_dim))?
            .apply(&self.o_proj)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("q_proj").add(&self.q_proj);
        uvb.pp("k_proj").add(&self.k_proj);
        uvb.pp("v_proj").add(&self.v_proj);
        uvb.pp("out_proj").add(&self.o_proj);

        uvb.to_safetensors()
    }
}

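/// Two-layer feed-forward block: `fc1 -> activation -> fc2`.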
struct VisionMLP {
    activation: Activation,
    fc1: Linear,
    fc2: Linear,
}

impl VisionMLP {
    fn new(config: Idefics3VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        let fc1 = layers::linear(config.hidden_size, config.intermediate_size, vb.pp("fc1"))?;
        let fc2 = layers::linear(config.intermediate_size, config.hidden_size, vb.pp("fc2"))?;
        Ok(Self {
            activation: config.hidden_act,
            fc1,
            fc2,
        })
    }

    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let mut x = self.fc1.forward(x)?;
        x = self.activation.forward(&x)?;
        self.fc2.forward(&x)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("fc1").add(&self.fc1);
        uvb.pp("fc2").add(&self.fc2);

        uvb.to_safetensors()
    }
}

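/// Pre-norm transformer block: `x + attn(ln1(x))`, then `x + mlp(ln2(x))`.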
struct EncoderLayer {
    mlp: VisionMLP,
    attn: Attention,
    layer_norm_1: LayerNorm,
    layer_norm_2: LayerNorm,
}

impl EncoderLayer {
    fn new(config: Idefics3VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        let mlp = VisionMLP::new(config.clone(), vb.pp("mlp"))?;
        let attn = Attention::new(config.clone(), vb.pp("self_attn"))?;
        let layer_norm_1 = layer_norm(
            config.hidden_size,
            config.layer_norm_eps,
            vb.pp("layer_norm1"),
        )?;
        let layer_norm_2 = layer_norm(
            config.hidden_size,
            config.layer_norm_eps,
            vb.pp("layer_norm2"),
        )?;
        Ok(Self {
            mlp,
            attn,
            layer_norm_1,
            layer_norm_2,
        })
    }

    fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
        let residual = xs.clone();

        let hidden_states = self.layer_norm_1.forward(xs)?;
        let hidden_states = self.attn.forward(&hidden_states, attention_mask)?;
        let hidden_states = (hidden_states + residual)?;

        let residual = &hidden_states;
        let hidden_states = self.layer_norm_2.forward(&hidden_states)?;
        let hidden_states = self.mlp.forward(&hidden_states)?;
        hidden_states + residual
    }
}

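/// Stack of `num_hidden_layers` encoder layers applied in sequence.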
struct Encoder {
    layers: Vec<EncoderLayer>,
}

impl Encoder {
    fn new(config: &Idefics3VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        let mut layers = Vec::new();
        let vb_l = vb.pp("layers");
        for i in 0..config.num_hidden_layers {
            layers.push(EncoderLayer::new(config.clone(), vb_l.pp(i))?);
        }
        Ok(Self { layers })
    }

    fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
        let mut hidden_states = xs.clone();
        for layer in &self.layers {
            hidden_states = layer.forward(&hidden_states, attention_mask)?;
        }
        Ok(hidden_states)
    }
}

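/// The full Idefics3 vision tower: patch/position embeddings, the transformer
/// encoder, and a final layer norm.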
pub struct Idefics3VisionTransformer {
    embeddings: VisionEmbeddings,
    encoder: Encoder,
    post_layernorm: LayerNorm,
    patch_size: usize,
}

impl Idefics3VisionTransformer {
    pub fn new(config: &Idefics3VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        let embeddings = VisionEmbeddings::new(config, vb.pp("embeddings"))?;
        let post_layernorm = layer_norm(
            config.hidden_size,
            config.layer_norm_eps,
            vb.pp("post_layernorm"),
        )?;
        let encoder = Encoder::new(config, vb.pp("encoder"))?;
        Ok(Self {
            embeddings,
            encoder,
            post_layernorm,
            patch_size: config.patch_size,
        })
    }

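    /// Run the vision tower on `pixel_values` of shape `(bs, c, h, w)`. When no
    /// `attention_mask` is given, every patch is treated as valid and no mask is
    /// materialized for the encoder.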
    pub fn forward(
        &self,
        pixel_values: &Tensor,
        attention_mask: Option<&Tensor>,
    ) -> Result<Tensor> {
        let bs = pixel_values.dim(0)?;
        let patch_attention_mask = if let Some(attn_mask) = attention_mask {
            attn_mask.clone()
        } else {
            // No mask provided: every patch is valid.
            Tensor::ones(
                (
                    bs,
                    pixel_values.dim(2)? / self.patch_size,
                    pixel_values.dim(3)? / self.patch_size,
                ),
                DType::U8,
                pixel_values.device(),
            )?
        };

        let hidden_states = self
            .embeddings
            .forward(pixel_values, &patch_attention_mask)?;

        // Only expand an attention mask for the encoder if the caller supplied one.
        let attention_mask = if attention_mask.is_none() {
            None
        } else {
            let mask = patch_attention_mask
                .reshape((patch_attention_mask.dim(0)?, ()))?
                .to_dtype(hidden_states.dtype())?;
            Some(CausalMasker.expand_mask(&mask, hidden_states.dtype(), None)?)
        };
        let hidden_states = self
            .encoder
            .forward(&hidden_states, attention_mask.as_ref())?;
        hidden_states.apply(&self.post_layernorm)
    }

    pub fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("post_layernorm").add(&self.post_layernorm);
        uvb.pp("embeddings")
            .extend(self.embeddings.residual_tensors());

        let uvb_enc = uvb.pp("encoder");
        for (i, layer) in self.encoder.layers.iter().enumerate() {
            let uvb_l = uvb_enc.pp("layers").pp(i);

            uvb_l.pp("layer_norm1").add(&layer.layer_norm_1);
            uvb_l.pp("layer_norm2").add(&layer.layer_norm_2);
            uvb_l.pp("mlp").extend(layer.mlp.residual_tensors());
            uvb_l.pp("self_attn").extend(layer.attn.residual_tensors());
        }

        uvb.to_safetensors()
    }
}