mistralrs_core/vision_models/idefics3/vision.rs

use candle_core::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{Conv2d, Conv2dConfig, Embedding, LayerNorm, Linear, Module};
use mistralrs_quant::{QuantMethod, ShardedVarBuilder};
use std::{ops::Mul, sync::Arc};

use crate::{
    attention::SdpaParams,
    layers::{self, conv2d, embedding, layer_norm, Activation, CausalMasker, Sdpa},
    utils::unvarbuilder::UnVarBuilder,
};

use super::config::{Idefics3Config, Idefics3VisionConfig};

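/// Modality projection MLP: a single bias-free linear layer mapping the pixel-shuffled
/// vision features (vision hidden size * scale_factor^2) into the text model's hidden size.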
pub(crate) struct Idefics3SimpleMLP {
    pub(crate) proj: Linear,
}

impl Idefics3SimpleMLP {
    pub fn new(cfg: &Idefics3Config, vb: ShardedVarBuilder) -> Result<Self> {
        let in_dim = cfg.vision_config.hidden_size * cfg.scale_factor.pow(2);
        let out_dim = cfg.text_config.hidden_size;
        Ok(Self {
            proj: layers::linear_no_bias(in_dim, out_dim, vb.pp("proj"))?,
        })
    }

    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        x.apply(&self.proj)
    }
}

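/// Connects the vision tower to the language model: pixel-shuffles the patch features
/// to shrink the sequence by `scale_factor^2`, then projects them to the text hidden
/// size with `Idefics3SimpleMLP`.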
pub struct Idefics3Connector {
    scale_factor: usize,
    pub(crate) modality_projection: Idefics3SimpleMLP,
}

impl Idefics3Connector {
    pub fn new(cfg: &Idefics3Config, vb: ShardedVarBuilder) -> Result<Self> {
        Ok(Self {
            scale_factor: cfg.scale_factor,
            modality_projection: Idefics3SimpleMLP::new(cfg, vb.pp("modality_projection"))?,
        })
    }

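    /// Space-to-depth shuffle over the (assumed square) grid of patch features: folds each
    /// `scale_factor x scale_factor` neighbourhood into the channel dimension, mapping
    /// `(bs, seq, embed_dim)` to `(bs, seq / scale_factor^2, embed_dim * scale_factor^2)`.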
    pub fn pixel_shuffle(&self, x: &Tensor, scale_factor: usize) -> Result<Tensor> {
        let (bs, seq, embed_dim) = x.dims3()?;
        let height = (seq as f32).sqrt() as usize;
        let width = height;
        let mut x = x.reshape((bs, height, width, embed_dim))?;
        x = x.reshape((bs, height, width / scale_factor, embed_dim * scale_factor))?;
        x = x.permute((0, 2, 1, 3))?;
        x = x.reshape((
            bs,
            width / scale_factor,
            height / scale_factor,
            embed_dim * scale_factor.pow(2),
        ))?;
        x = x.permute((0, 2, 1, 3))?;
        x.reshape((
            bs,
            (seq as f32 / scale_factor.pow(2) as f32) as usize,
            embed_dim * scale_factor.pow(2),
        ))
    }

    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let image_hidden_states = self.pixel_shuffle(x, self.scale_factor)?;
        self.modality_projection.forward(&image_hidden_states)
    }
}

struct VisionEmbeddings {
    patch_size: usize,
    patch_embedding: Conv2d,
    num_patches_per_side: usize,
    position_embedding: Embedding,
}

/// Equivalent to `torch.bucketize` with `right=True`.
/// Returns a 1-d `u32` tensor of shape `(xs.len(),)` on the given device.
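/// e.g. `bucketize_right(&[0.1, 0.3, 0.6, 0.8], &[0.25, 0.5, 0.75], dev)` -> `[0, 1, 2, 3]`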
fn bucketize_right(xs: &[f32], boundaries: &[f32], device: &Device) -> Result<Tensor> {
    let mut result = Vec::with_capacity(xs.len());

    for &x in xs {
        // partition_point returns the first index at which the predicate is false,
        // i.e. the number of boundaries that are <= x. This matches torch.bucketize
        // with right=True: a value equal to a boundary falls into the bucket to its
        // right. Boundaries are assumed to be sorted ascending and NaN-free.
        let idx = boundaries.partition_point(|&val| val <= x);
        result.push(idx as u32);
    }

    Tensor::from_vec(result, (xs.len(),), device)
}

impl VisionEmbeddings {
    fn new(config: &Idefics3VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        let conv_config = Conv2dConfig {
            stride: config.patch_size,
            ..Default::default()
        };
        let patch_embedding = conv2d(
            config.num_channels,
            config.hidden_size,
            config.patch_size,
            conv_config,
            vb.pp("patch_embedding"),
        )?;
        let num_patches_per_side = config.image_size / config.patch_size;
        let num_patches = num_patches_per_side.pow(2);
        Ok(Self {
            patch_size: config.patch_size,
            patch_embedding,
            num_patches_per_side,
            position_embedding: embedding(
                num_patches,
                config.hidden_size,
                vb.pp("position_embedding"),
                &None,
            )?,
        })
    }

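    /// Patch-embeds the pixel values with a strided convolution, then adds learned position
    /// embeddings. For images smaller than the maximum size, the valid patches' fractional
    /// coordinates are bucketed onto the full `num_patches_per_side` grid (via
    /// `bucketize_right`), so their position ids match those of a full-resolution image;
    /// padded patches keep position id 0.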
    fn forward(&self, pixel_values: &Tensor, patch_attention_mask: &Tensor) -> Result<Tensor> {
        let (bs, _, max_im_h, max_im_w) = pixel_values.dims4()?;

        let patch_embeds = self.patch_embedding.forward(pixel_values)?;

        let embeddings = patch_embeds.flatten(2, D::Minus1)?.transpose(1, 2)?;

        let (max_nb_patches_h, max_nb_patches_w) =
            (max_im_h / self.patch_size, max_im_w / self.patch_size);
        let boundaries = Tensor::arange_step(
            1.0 / self.num_patches_per_side as f32,
            1.0,
            1.0 / self.num_patches_per_side as f32,
            pixel_values.device(),
        )?
        .to_vec1::<f32>()?;
        let position_ids = Tensor::full(
            0u32,
            (bs, max_nb_patches_h * max_nb_patches_w),
            pixel_values.device(),
        )?;

        let mut new_position_ids = Vec::new();
        for (b_idx, p_attn_mask) in patch_attention_mask.chunk(bs, 0)?.iter().enumerate() {
            let p_attn_mask = p_attn_mask.squeeze(0)?;
            let nb_patches_h = p_attn_mask.i((.., 0))?.sum_all()?;
            let nb_patches_w = p_attn_mask.i((0,))?.sum_all()?;

            let fractional_coords_h = Tensor::arange_step(
                0.0,
                1.0 - 1e-6,
                1.0 / nb_patches_h.to_dtype(DType::F32)?.to_scalar::<f32>()?,
                pixel_values.device(),
            )?
            .to_vec1::<f32>()?;
            let fractional_coords_w = Tensor::arange_step(
                0.0,
                1.0 - 1e-6,
                1.0 / nb_patches_w.to_dtype(DType::F32)?.to_scalar::<f32>()?,
                pixel_values.device(),
            )?
            .to_vec1::<f32>()?;

            let bucket_coords_h =
                bucketize_right(&fractional_coords_h, &boundaries, pixel_values.device())?;
            let bucket_coords_w =
                bucketize_right(&fractional_coords_w, &boundaries, pixel_values.device())?;

            let pos_ids = bucket_coords_h
                .unsqueeze(D::Minus1)?
                .mul(self.num_patches_per_side as f64)?
                .broadcast_add(&bucket_coords_w)?
                .flatten_all()?
                .to_vec1::<u32>()?;

            let true_indices = p_attn_mask
                .flatten_all()?
                .to_vec1::<u8>()?
                .iter()
                .enumerate()
                .filter_map(|(i, x)| if *x != 0 { Some(i) } else { None })
                .collect::<Vec<_>>();
            let position_ids_b = position_ids.i(b_idx)?;

            let mut new_position_ids_b = position_ids_b.to_vec1::<u32>()?;
            let new_position_ids_b_len = new_position_ids_b.len();
            for (i, true_idx) in true_indices.into_iter().enumerate() {
                new_position_ids_b[true_idx] = pos_ids[i];
            }

            new_position_ids.push(Tensor::from_vec(
                new_position_ids_b,
                new_position_ids_b_len,
                pixel_values.device(),
            )?);
        }
        let position_ids = Tensor::stack(&new_position_ids, 0)?;
        let position_ids = position_ids.to_device(self.position_embedding.embeddings().device())?;
        embeddings.broadcast_add(&self.position_embedding.forward(&position_ids)?)
    }

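    /// Gathers this module's weights under their original parameter names for
    /// re-serialization via `UnVarBuilder`.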
    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("patch_embedding").add(&self.patch_embedding);
        uvb.pp("position_embedding").add(&self.position_embedding);

        uvb.to_safetensors()
    }
}

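/// Standard multi-head self-attention over the patch sequence. The vision encoder attends
/// bidirectionally: `n_kv_groups` is 1 and no causal or sliding-window mask is applied,
/// only the optional padding mask passed in by the caller.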
struct Attention {
    embed_dim: usize,
    num_heads: usize,
    head_dim: usize,
    q_proj: Arc<dyn QuantMethod>,
    k_proj: Arc<dyn QuantMethod>,
    v_proj: Arc<dyn QuantMethod>,
    o_proj: Arc<dyn QuantMethod>,
    sdpa_params: SdpaParams,
}

impl Attention {
    fn new(config: Idefics3VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        let embed_dim = config.hidden_size;
        let num_heads = config.num_attention_heads;
        let head_dim = embed_dim / num_heads;
        let scale = 1.0 / (head_dim as f32).sqrt();

        let q_proj = mistralrs_quant::linear(embed_dim, embed_dim, &None, vb.pp("q_proj"))?;
        let k_proj = mistralrs_quant::linear(embed_dim, embed_dim, &None, vb.pp("k_proj"))?;
        let v_proj = mistralrs_quant::linear(embed_dim, embed_dim, &None, vb.pp("v_proj"))?;
        let o_proj = mistralrs_quant::linear(embed_dim, embed_dim, &None, vb.pp("out_proj"))?;

        Ok(Self {
            embed_dim,
            num_heads,
            head_dim,
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            sdpa_params: SdpaParams {
                n_kv_groups: 1,
                softcap: None,
                softmax_scale: scale,
                sliding_window: None,
            },
        })
    }

    fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let mut q = self.q_proj.forward_autocast(xs)?;
        let mut k = self.k_proj.forward_autocast(xs)?;
        let mut v = self.v_proj.forward_autocast(xs)?;

        q = q
            .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;
        k = k
            .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;
        v = v
            .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;

        let attn_output =
            Sdpa.run_attention(&q, &k, &v, attention_mask, None, &self.sdpa_params)?;

        self.o_proj
            .forward_autocast(&attn_output.transpose(1, 2)?.reshape((
                b_sz,
                q_len,
                self.embed_dim,
            ))?)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("q_proj").add(&self.q_proj);
        uvb.pp("k_proj").add(&self.k_proj);
        uvb.pp("v_proj").add(&self.v_proj);
        uvb.pp("out_proj").add(&self.o_proj);

        uvb.to_safetensors()
    }
}

struct VisionMLP {
    activation: Activation,
    fc1: Arc<dyn QuantMethod>,
    fc2: Arc<dyn QuantMethod>,
}

impl VisionMLP {
    fn new(config: Idefics3VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        let fc1 = mistralrs_quant::linear(
            config.hidden_size,
            config.intermediate_size,
            &None,
            vb.pp("fc1"),
        )?;
        let fc2 = mistralrs_quant::linear(
            config.intermediate_size,
            config.hidden_size,
            &None,
            vb.pp("fc2"),
        )?;
        Ok(Self {
            activation: config.hidden_act,
            fc1,
            fc2,
        })
    }

    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let mut x = self.fc1.forward_autocast(x)?;
        x = self.activation.forward(&x)?;
        self.fc2.forward_autocast(&x)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("fc1").add(&self.fc1);
        uvb.pp("fc2").add(&self.fc2);

        uvb.to_safetensors()
    }
}

struct EncoderLayer {
    mlp: VisionMLP,
    attn: Attention,
    layer_norm_1: LayerNorm,
    layer_norm_2: LayerNorm,
}

impl EncoderLayer {
    fn new(config: Idefics3VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        let mlp = VisionMLP::new(config.clone(), vb.pp("mlp"))?;
        let attn = Attention::new(config.clone(), vb.pp("self_attn"))?;
        let layer_norm_1 = layer_norm(
            config.hidden_size,
            config.layer_norm_eps,
            vb.pp("layer_norm1"),
        )?;
        let layer_norm_2 = layer_norm(
            config.hidden_size,
            config.layer_norm_eps,
            vb.pp("layer_norm2"),
        )?;
        Ok(Self {
            mlp,
            attn,
            layer_norm_1,
            layer_norm_2,
        })
    }

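    /// Pre-norm transformer block: LayerNorm -> self-attention -> residual add, then
    /// LayerNorm -> MLP -> residual add.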
    fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
        let residual = xs.clone();

        let mut hidden_states = self.layer_norm_1.forward(xs)?;
        hidden_states = self.attn.forward(&hidden_states, attention_mask)?;
        hidden_states = (hidden_states + residual)?;

        let residual = hidden_states.clone();
        hidden_states = self.layer_norm_2.forward(&hidden_states)?;
        hidden_states = self.mlp.forward(&hidden_states)?;
        hidden_states + residual
    }
}

struct Encoder {
    layers: Vec<EncoderLayer>,
}

impl Encoder {
    fn new(config: &Idefics3VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        let mut layers = Vec::new();
        let vb_l = vb.pp("layers");
        for i in 0..config.num_hidden_layers {
            layers.push(EncoderLayer::new(config.clone(), vb_l.pp(i))?);
        }
        Ok(Self { layers })
    }

    fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
        let mut hidden_states = xs.clone();
        for layer in &self.layers {
            hidden_states = layer.forward(&hidden_states, attention_mask)?;
        }
        Ok(hidden_states)
    }
}

pub struct Idefics3VisionTransformer {
    embeddings: VisionEmbeddings,
    encoder: Encoder,
    post_layernorm: LayerNorm,
    patch_size: usize,
    neg_inf: Tensor,
}

impl Idefics3VisionTransformer {
    pub fn new(config: &Idefics3VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        let embeddings = VisionEmbeddings::new(config, vb.pp("embeddings"))?;
        let post_layernorm = layer_norm(
            config.hidden_size,
            config.layer_norm_eps,
            vb.pp("post_layernorm"),
        )?;
        let encoder = Encoder::new(config, vb.pp("encoder"))?;
        Ok(Self {
            embeddings,
            encoder,
            post_layernorm,
            patch_size: config.patch_size,
            neg_inf: Tensor::new(f32::NEG_INFINITY, vb.device())?.to_dtype(vb.dtype())?,
        })
    }

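    /// Runs the vision tower: patch + position embeddings, the encoder stack, and a final
    /// LayerNorm. When `attention_mask` is provided it is treated as a per-patch padding
    /// mask, expanded to `(bs, num_heads, seq, seq)`, and turned into an additive bias
    /// (zero for attended patches, `-inf` for padded ones) before reaching attention;
    /// when it is `None`, every patch is treated as valid and no mask is used.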
    pub fn forward(
        &self,
        pixel_values: &Tensor,
        attention_mask: Option<&Tensor>,
    ) -> Result<Tensor> {
        let bs = pixel_values.dim(0)?;
        let patch_attention_mask = if let Some(attn_mask) = attention_mask {
            attn_mask.clone()
        } else {
            Tensor::ones(
                (
                    bs,
                    pixel_values.dim(2)? / self.patch_size,
                    pixel_values.dim(3)? / self.patch_size,
                ),
                DType::U8,
                pixel_values.device(),
            )?
        };

        let hidden_states = self
            .embeddings
            .forward(pixel_values, &patch_attention_mask)?;

        let attention_mask = if attention_mask.is_none() {
            None
        } else {
            let mask = patch_attention_mask
                .reshape((patch_attention_mask.dim(0)?, ()))?
                .to_dtype(hidden_states.dtype())?;
            Some(CausalMasker.expand_mask(&mask, hidden_states.dtype(), None)?)
        };

        let attention_mask = match &attention_mask {
            Some(mask) => {
                let (b, _h, q, k) = mask.dims4()?;
                let dims = [b, self.encoder.layers[0].attn.num_heads, q, k];
                let mask = mask.broadcast_as(&dims)?;
                let mask = mask.to_dtype(DType::U8)?.where_cond(
                    &self.neg_inf.to_device(mask.device())?.broadcast_as(&dims)?,
                    &Tensor::zeros(&dims, self.neg_inf.dtype(), self.neg_inf.device())?,
                )?;

                Some(mask)
            }
            None => None,
        };
        let hidden_states = self
            .encoder
            .forward(&hidden_states, attention_mask.as_ref())?;
        hidden_states.apply(&self.post_layernorm)
    }

    pub fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("post_layernorm").add(&self.post_layernorm);
        uvb.pp("embeddings")
            .extend(self.embeddings.residual_tensors());

        let uvb_enc = uvb.pp("encoder");
        for (i, layer) in self.encoder.layers.iter().enumerate() {
            let uvb_l = uvb_enc.pp("layers").pp(i);

            uvb_l.pp("layer_norm1").add(&layer.layer_norm_1);
            uvb_l.pp("layer_norm2").add(&layer.layer_norm_2);
            uvb_l.pp("mlp").extend(layer.mlp.residual_tensors());
            uvb_l.pp("self_attn").extend(layer.attn.residual_tensors());
        }

        uvb.to_safetensors()
    }
}