mistralrs_core/vision_models/idefics3/vision.rs

use candle_core::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{Conv2d, Conv2dConfig, Embedding, LayerNorm, Linear, Module};
use mistralrs_quant::ShardedVarBuilder;
use std::ops::Mul;

use crate::{
    layers::{self, conv2d, embedding, layer_norm, Activation, CausalMasker, MatMul},
    utils::unvarbuilder::UnVarBuilder,
};

use super::config::{Idefics3Config, Idefics3VisionConfig};

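/// Modality projection MLP: a single bias-free linear layer mapping the
/// pixel-shuffled vision hidden size (`vision_config.hidden_size * scale_factor^2`)
/// to the text model's hidden size.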
pub(crate) struct Idefics3SimpleMLP {
    pub(crate) proj: Linear,
}

impl Idefics3SimpleMLP {
    pub fn new(cfg: &Idefics3Config, vb: ShardedVarBuilder) -> Result<Self> {
        let in_dim = cfg.vision_config.hidden_size * cfg.scale_factor.pow(2);
        let out_dim = cfg.text_config.hidden_size;
        Ok(Self {
            proj: layers::linear_no_bias(in_dim, out_dim, vb.pp("proj"))?,
        })
    }

    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        x.apply(&self.proj)
    }
}

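/// Connects the vision tower to the language model: pixel-shuffles the image
/// hidden states, then projects them with [`Idefics3SimpleMLP`].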
pub struct Idefics3Connector {
    scale_factor: usize,
    pub(crate) modality_projection: Idefics3SimpleMLP,
}

impl Idefics3Connector {
    pub fn new(cfg: &Idefics3Config, vb: ShardedVarBuilder) -> Result<Self> {
        Ok(Self {
            scale_factor: cfg.scale_factor,
            modality_projection: Idefics3SimpleMLP::new(cfg, vb.pp("modality_projection"))?,
        })
    }

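    /// Space-to-depth shuffle: trades spatial resolution for channel depth,
    /// mapping `(bs, seq, d)` to `(bs, seq / s^2, d * s^2)` for
    /// `s = scale_factor`. Assumes `seq` is a perfect square (a square grid
    /// of patches) whose side is divisible by `s`.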
    pub fn pixel_shuffle(&self, x: &Tensor, scale_factor: usize) -> Result<Tensor> {
        let (bs, seq, embed_dim) = x.dims3()?;
        let height = (seq as f32).sqrt() as usize;
        let width = height;
        let mut x = x.reshape((bs, height, width, embed_dim))?;
        x = x.reshape((bs, height, width / scale_factor, embed_dim * scale_factor))?;
        x = x.permute((0, 2, 1, 3))?;
        x = x.reshape((
            bs,
            width / scale_factor,
            height / scale_factor,
            embed_dim * scale_factor.pow(2),
        ))?;
        x = x.permute((0, 2, 1, 3))?;
        x.reshape((
            bs,
            (seq as f32 / scale_factor.pow(2) as f32) as usize,
            embed_dim * scale_factor.pow(2),
        ))
    }

    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let image_hidden_states = self.pixel_shuffle(x, self.scale_factor)?;
        self.modality_projection.forward(&image_hidden_states)
    }
}

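/// Patch and position embeddings for the vision tower. Patches are extracted
/// with a strided `Conv2d`, and each valid patch receives a learned position
/// embedding resolved against a fixed `num_patches_per_side` grid, so images
/// of varying (padded) sizes share one embedding table.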
struct VisionEmbeddings {
    patch_size: usize,
    patch_embedding: Conv2d,
    num_patches_per_side: usize,
    position_embedding: Embedding,
}

/// torch.bucketize with right=True
/// Returns a 1d tensor of shape (xs.len(),) on the CPU
fn bucketize_right(xs: &[f32], boundaries: &[f32], device: &Device) -> Result<Tensor> {
    let mut result = Vec::with_capacity(xs.len());

    for &x in xs {
        // torch.bucketize with right=True returns, for each x, the first index
        // at which boundaries[idx] > x (Python's bisect.bisect_right), so a
        // value sitting exactly on a boundary falls into the bucket to its
        // right. For a sorted slice, partition_point returns exactly that
        // index: the number of leading elements satisfying the predicate.
        // NaNs are assumed absent; a NaN would break the partition invariant.
        let idx = boundaries.partition_point(|&val| val <= x);

        result.push(idx as u32);
    }

    Tensor::from_vec(result, (xs.len(),), device)
}

impl VisionEmbeddings {
    fn new(config: &Idefics3VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        let conv_config = Conv2dConfig {
            stride: config.patch_size,
            ..Default::default()
        };
        let patch_embedding = conv2d(
            config.num_channels,
            config.hidden_size,
            config.patch_size,
            conv_config,
            vb.pp("patch_embedding"),
        )?;
        let num_patches_per_side = config.image_size / config.patch_size;
        let num_patches = num_patches_per_side.pow(2);
        Ok(Self {
            patch_size: config.patch_size,
            patch_embedding,
            num_patches_per_side,
            position_embedding: embedding(
                num_patches,
                config.hidden_size,
                vb.pp("position_embedding"),
            )?,
        })
    }

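    /// Computes patch embeddings and adds position embeddings.
    /// `patch_attention_mask` has shape `(bs, h_patches, w_patches)` and marks
    /// the patches that come from real image content; per image, fractional
    /// patch coordinates are bucketized onto the full-resolution grid to pick
    /// position ids, and padded patches keep position id 0.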
    fn forward(&self, pixel_values: &Tensor, patch_attention_mask: &Tensor) -> Result<Tensor> {
        let (bs, _, max_im_h, max_im_w) = pixel_values.dims4()?;

        let patch_embeds = self.patch_embedding.forward(pixel_values)?;

        let embeddings = patch_embeds.flatten(2, D::Minus1)?.transpose(1, 2)?;

        let (max_nb_patches_h, max_nb_patches_w) =
            (max_im_h / self.patch_size, max_im_w / self.patch_size);
        let boundaries = Tensor::arange_step(
            1.0 / self.num_patches_per_side as f32,
            1.0,
            1.0 / self.num_patches_per_side as f32,
            pixel_values.device(),
        )?
        .to_vec1::<f32>()?;
        let position_ids = Tensor::full(
            0u32,
            (bs, max_nb_patches_h * max_nb_patches_w),
            pixel_values.device(),
        )?;

        let mut new_position_ids = Vec::new();
        for (b_idx, p_attn_mask) in patch_attention_mask.chunk(bs, 0)?.iter().enumerate() {
            let p_attn_mask = p_attn_mask.squeeze(0)?;
            let nb_patches_h = p_attn_mask.i((.., 0))?.sum_all()?;
            let nb_patches_w = p_attn_mask.i((0,))?.sum_all()?;

            let fractional_coords_h = Tensor::arange_step(
                0.0,
                1.0 - 1e-6,
                1.0 / nb_patches_h.to_dtype(DType::F32)?.to_scalar::<f32>()?,
                pixel_values.device(),
            )?
            .to_vec1::<f32>()?;
            let fractional_coords_w = Tensor::arange_step(
                0.0,
                1.0 - 1e-6,
                1.0 / nb_patches_w.to_dtype(DType::F32)?.to_scalar::<f32>()?,
                pixel_values.device(),
            )?
            .to_vec1::<f32>()?;

            let bucket_coords_h =
                bucketize_right(&fractional_coords_h, &boundaries, pixel_values.device())?;
            let bucket_coords_w =
                bucketize_right(&fractional_coords_w, &boundaries, pixel_values.device())?;

            let pos_ids = bucket_coords_h
                .unsqueeze(D::Minus1)?
                .mul(self.num_patches_per_side as f64)?
                .broadcast_add(&bucket_coords_w)?
                .flatten_all()?
                .to_vec1::<u32>()?;

            let true_indices = p_attn_mask
                .flatten_all()?
                .to_vec1::<u8>()?
                .iter()
                .enumerate()
                .filter_map(|(i, x)| if *x != 0 { Some(i) } else { None })
                .collect::<Vec<_>>();
            let position_ids_b = position_ids.i(b_idx)?;

            let mut new_position_ids_b = position_ids_b.to_vec1::<u32>()?;
            let new_position_ids_b_len = new_position_ids_b.len();
            for (i, true_idx) in true_indices.into_iter().enumerate() {
                new_position_ids_b[true_idx] = pos_ids[i];
            }

            new_position_ids.push(Tensor::from_vec(
                new_position_ids_b,
                new_position_ids_b_len,
                pixel_values.device(),
            )?);
        }
        let position_ids = Tensor::stack(&new_position_ids, 0)?;
        let position_ids = position_ids.to_device(self.position_embedding.embeddings().device())?;
        embeddings.broadcast_add(&self.position_embedding.forward(&position_ids)?)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("patch_embedding").add(&self.patch_embedding);
        uvb.pp("position_embedding").add(&self.position_embedding);

        uvb.to_safetensors()
    }
}

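/// Multi-head self-attention over the patch sequence. This is full
/// (non-causal) attention; the optional mask only hides padded patches via a
/// one-and-zero mask.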
struct Attention {
    embed_dim: usize,
    num_heads: usize,
    head_dim: usize,
    scale: f64,
    q_proj: Linear,
    k_proj: Linear,
    v_proj: Linear,
    o_proj: Linear,
    neg_inf: Tensor,
}

impl Attention {
    fn new(config: Idefics3VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        let embed_dim = config.hidden_size;
        let num_heads = config.num_attention_heads;
        let head_dim = embed_dim / num_heads;
        let scale = 1.0 / (head_dim as f64).sqrt();

        let q_proj = layers::linear(embed_dim, embed_dim, vb.pp("q_proj"))?;
        let k_proj = layers::linear(embed_dim, embed_dim, vb.pp("k_proj"))?;
        let v_proj = layers::linear(embed_dim, embed_dim, vb.pp("v_proj"))?;
        let o_proj = layers::linear(embed_dim, embed_dim, vb.pp("out_proj"))?;

        Ok(Self {
            embed_dim,
            num_heads,
            head_dim,
            scale,
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            neg_inf: Tensor::new(f32::NEG_INFINITY, vb.device())?.to_dtype(vb.dtype())?,
        })
    }

    fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let mut q = self.q_proj.forward(xs)?;
        let mut k = self.k_proj.forward(xs)?;
        let mut v = self.v_proj.forward(xs)?;

        q = q
            .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;
        k = k
            .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;
        v = v
            .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;

        let attn_weights =
            (MatMul.matmul(&q.contiguous()?, &k.transpose(2, 3)?.contiguous()?)? * self.scale)?;

        let mut attn_weights = CausalMasker.apply_mask_one_and_zero(
            &attention_mask.map(|x| x.to_dtype(DType::U8).unwrap()),
            attn_weights,
            &self.neg_inf,
        )?;
        attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
        let attn_output = MatMul.matmul(&attn_weights, &v.contiguous()?)?;

        attn_output
            .transpose(1, 2)?
            .reshape((b_sz, q_len, self.embed_dim))?
            .apply(&self.o_proj)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("q_proj").add(&self.q_proj);
        uvb.pp("k_proj").add(&self.k_proj);
        uvb.pp("v_proj").add(&self.v_proj);
        uvb.pp("out_proj").add(&self.o_proj);

        uvb.to_safetensors()
    }
}

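/// Two-layer feed-forward block: `fc2(act(fc1(x)))`.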
struct VisionMLP {
    activation: Activation,
    fc1: Linear,
    fc2: Linear,
}

impl VisionMLP {
    fn new(config: Idefics3VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        let fc1 = layers::linear(config.hidden_size, config.intermediate_size, vb.pp("fc1"))?;
        let fc2 = layers::linear(config.intermediate_size, config.hidden_size, vb.pp("fc2"))?;
        Ok(Self {
            activation: config.hidden_act,
            fc1,
            fc2,
        })
    }

    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let mut x = self.fc1.forward(x)?;
        x = self.activation.forward(&x)?;
        self.fc2.forward(&x)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("fc1").add(&self.fc1);
        uvb.pp("fc2").add(&self.fc2);

        uvb.to_safetensors()
    }
}

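/// Pre-norm transformer block: `x + attn(ln1(x))`, then `x + mlp(ln2(x))`.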
struct EncoderLayer {
    mlp: VisionMLP,
    attn: Attention,
    layer_norm_1: LayerNorm,
    layer_norm_2: LayerNorm,
}

impl EncoderLayer {
    fn new(config: Idefics3VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        let mlp = VisionMLP::new(config.clone(), vb.pp("mlp"))?;
        let attn = Attention::new(config.clone(), vb.pp("self_attn"))?;
        let layer_norm_1 = layer_norm(
            config.hidden_size,
            config.layer_norm_eps,
            vb.pp("layer_norm1"),
        )?;
        let layer_norm_2 = layer_norm(
            config.hidden_size,
            config.layer_norm_eps,
            vb.pp("layer_norm2"),
        )?;
        Ok(Self {
            mlp,
            attn,
            layer_norm_1,
            layer_norm_2,
        })
    }

    fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
        let residual = xs.clone();

        let hidden_states = self.layer_norm_1.forward(xs)?;
        let hidden_states = self.attn.forward(&hidden_states, attention_mask)?;
        let hidden_states = (hidden_states + residual)?;

        let residual = &hidden_states;
        let hidden_states = self.layer_norm_2.forward(&hidden_states)?;
        let hidden_states = self.mlp.forward(&hidden_states)?;
        hidden_states + residual
    }
}

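/// A stack of `num_hidden_layers` encoder layers applied sequentially.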
struct Encoder {
    layers: Vec<EncoderLayer>,
}

impl Encoder {
    fn new(config: &Idefics3VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        let mut layers = Vec::new();
        let vb_l = vb.pp("layers");
        for i in 0..config.num_hidden_layers {
            layers.push(EncoderLayer::new(config.clone(), vb_l.pp(i))?);
        }
        Ok(Self { layers })
    }

    fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
        let mut hidden_states = xs.clone();
        for layer in &self.layers {
            hidden_states = layer.forward(&hidden_states, attention_mask)?;
        }
        Ok(hidden_states)
    }
}

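/// The full vision tower: patch/position embeddings, the encoder stack, and a
/// final layer norm.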
pub struct Idefics3VisionTransformer {
    embeddings: VisionEmbeddings,
    encoder: Encoder,
    post_layernorm: LayerNorm,
    patch_size: usize,
}

impl Idefics3VisionTransformer {
    pub fn new(config: &Idefics3VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        let embeddings = VisionEmbeddings::new(config, vb.pp("embeddings"))?;
        let post_layernorm = layer_norm(
            config.hidden_size,
            config.layer_norm_eps,
            vb.pp("post_layernorm"),
        )?;
        let encoder = Encoder::new(config, vb.pp("encoder"))?;
        Ok(Self {
            embeddings,
            encoder,
            post_layernorm,
            patch_size: config.patch_size,
        })
    }

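    /// Runs the vision tower on `pixel_values` of shape `(bs, c, h, w)`.
    /// When `attention_mask` is `None`, every patch is treated as valid and
    /// no mask is materialized for the encoder.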
    pub fn forward(
        &self,
        pixel_values: &Tensor,
        attention_mask: Option<&Tensor>,
    ) -> Result<Tensor> {
        let bs = pixel_values.dim(0)?;
        let patch_attention_mask = if let Some(attn_mask) = attention_mask {
            attn_mask.clone()
        } else {
            Tensor::ones(
                (
                    bs,
                    pixel_values.dim(2)? / self.patch_size,
                    pixel_values.dim(3)? / self.patch_size,
                ),
                DType::U8,
                pixel_values.device(),
            )?
        };

        let hidden_states = self
            .embeddings
            .forward(pixel_values, &patch_attention_mask)?;

        let attention_mask = if attention_mask.is_none() {
            None
        } else {
            let mask = patch_attention_mask
                .reshape((patch_attention_mask.dim(0)?, ()))?
                .to_dtype(hidden_states.dtype())?;
            Some(CausalMasker.expand_mask(&mask, hidden_states.dtype(), None)?)
        };
        let hidden_states = self
            .encoder
            .forward(&hidden_states, attention_mask.as_ref())?;
        hidden_states.apply(&self.post_layernorm)
    }

    pub fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("post_layernorm").add(&self.post_layernorm);
        uvb.pp("embeddings")
            .extend(self.embeddings.residual_tensors());

        let uvb_enc = uvb.pp("encoder");
        for (i, layer) in self.encoder.layers.iter().enumerate() {
            let uvb_l = uvb_enc.pp("layers").pp(i);

            uvb_l.pp("layer_norm1").add(&layer.layer_norm_1);
            uvb_l.pp("layer_norm2").add(&layer.layer_norm_2);
            uvb_l.pp("mlp").extend(layer.mlp.residual_tensors());
            uvb_l.pp("self_attn").extend(layer.attn.residual_tensors());
        }

        uvb.to_safetensors()
    }
}
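
#[cfg(test)]
mod tests {
    use super::*;

    // Minimal sanity check of the right-bisection semantics, mirroring
    // torch.bucketize(..., right=True): each input maps to the first index
    // whose boundary is strictly greater than it, so a value sitting exactly
    // on a boundary falls into the bucket to its right. The boundary and
    // input values here are illustrative.
    #[test]
    fn bucketize_right_uses_right_bisection() -> Result<()> {
        let device = Device::Cpu;
        let boundaries = [0.25f32, 0.5, 0.75];
        let xs = [0.0f32, 0.25, 0.3, 0.75, 0.9];
        let buckets = bucketize_right(&xs, &boundaries, &device)?.to_vec1::<u32>()?;
        assert_eq!(buckets, vec![0, 1, 1, 3, 3]);
        Ok(())
    }
}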