// mistralrs_core/vision_models/idefics3/vision.rs

use candle_core::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{Conv2d, Conv2dConfig, Embedding, LayerNorm, Linear, Module};
use mistralrs_quant::ShardedVarBuilder;
use std::ops::Mul;

use crate::{
    layers::{self, conv2d, embedding, layer_norm, Activation, CausalMasker, MatMul},
    utils::unvarbuilder::UnVarBuilder,
};

use super::config::{Idefics3Config, Idefics3VisionConfig};

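/// Single-linear projection from the pixel-shuffled vision features (each
/// token carries `vision_hidden_size * scale_factor^2` features) down to the
/// text model's hidden size.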
pub(crate) struct Idefics3SimpleMLP {
    pub(crate) proj: Linear,
}

impl Idefics3SimpleMLP {
    pub fn new(cfg: &Idefics3Config, vb: ShardedVarBuilder) -> Result<Self> {
        let in_dim = cfg.vision_config.hidden_size * cfg.scale_factor.pow(2);
        let out_dim = cfg.text_config.hidden_size;
        Ok(Self {
            proj: layers::linear_no_bias(in_dim, out_dim, vb.pp("proj"))?,
        })
    }

    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        x.apply(&self.proj)
    }
}

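/// Bridges the vision tower and the language model: pixel-shuffles the image
/// hidden states to shorten the sequence by `scale_factor^2`, then projects
/// the result into the text embedding space.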
pub struct Idefics3Connector {
    scale_factor: usize,
    pub(crate) modality_projection: Idefics3SimpleMLP,
}

impl Idefics3Connector {
    pub fn new(cfg: &Idefics3Config, vb: ShardedVarBuilder) -> Result<Self> {
        Ok(Self {
            scale_factor: cfg.scale_factor,
            modality_projection: Idefics3SimpleMLP::new(cfg, vb.pp("modality_projection"))?,
        })
    }

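    /// Space-to-depth shuffle: folds each `scale_factor x scale_factor`
    /// block of patches into the channel dimension, so the sequence length
    /// shrinks by `scale_factor^2` while the embedding dim grows by the same
    /// factor. E.g. with `scale_factor = 2`, a `(bs, 1024, 1152)` input
    /// becomes `(bs, 256, 4608)`. Assumes `seq` is a perfect square.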
    pub fn pixel_shuffle(&self, x: &Tensor, scale_factor: usize) -> Result<Tensor> {
        let (bs, seq, embed_dim) = x.dims3()?;
        let height = (seq as f32).sqrt() as usize;
        let width = height;
        let mut x = x.reshape((bs, height, width, embed_dim))?;
        x = x.reshape((bs, height, width / scale_factor, embed_dim * scale_factor))?;
        x = x.permute((0, 2, 1, 3))?;
        x = x.reshape((
            bs,
            width / scale_factor,
            height / scale_factor,
            embed_dim * scale_factor.pow(2),
        ))?;
        x = x.permute((0, 2, 1, 3))?;
        x.reshape((
            bs,
            (seq as f32 / scale_factor.pow(2) as f32) as usize,
            embed_dim * scale_factor.pow(2),
        ))
    }

    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let image_hidden_states = self.pixel_shuffle(x, self.scale_factor)?;
        self.modality_projection.forward(&image_hidden_states)
    }
}

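/// Patch and position embeddings for the vision tower. Patches come from a
/// strided Conv2d; position ids are recomputed per image from fractional
/// patch coordinates, so variable-resolution inputs share a single learned
/// table of `num_patches_per_side^2` position embeddings.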
struct VisionEmbeddings {
    patch_size: usize,
    patch_embedding: Conv2d,
    num_patches_per_side: usize,
    position_embedding: Embedding,
}

/// torch.bucketize with right=True
/// Returns a 1d tensor of shape (xs.len(),) on the CPU
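/// e.g. with boundaries `[0.25, 0.5, 0.75]`, inputs `[0.0, 0.25, 0.5]` map to
/// buckets `[0, 1, 2]`: a value equal to a boundary falls to its right.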
fn bucketize_right(xs: &[f32], boundaries: &[f32], device: &Device) -> Result<Tensor> {
    let mut result = Vec::with_capacity(xs.len());

    for &x in xs {
        // `partition_point` returns the first index at which the predicate
        // becomes false. With `val <= x` over the sorted boundaries this is
        // the number of boundaries <= x, i.e. the insertion point *after*
        // any boundary equal to x, matching torch.bucketize(right=True)
        // (Python's bisect_right). NaNs are assumed absent.
        let idx = boundaries.partition_point(|&val| val <= x);

        result.push(idx as u32);
    }

    Tensor::from_vec(result, (xs.len(),), device)
}

impl VisionEmbeddings {
    fn new(config: &Idefics3VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        let conv_config = Conv2dConfig {
            stride: config.patch_size,
            ..Default::default()
        };
        let patch_embedding = conv2d(
            config.num_channels,
            config.hidden_size,
            config.patch_size,
            conv_config,
            vb.pp("patch_embedding"),
        )?;
        let num_patches_per_side = config.image_size / config.patch_size;
        let num_patches = num_patches_per_side.pow(2);
        Ok(Self {
            patch_size: config.patch_size,
            patch_embedding,
            num_patches_per_side,
            position_embedding: embedding(
                num_patches,
                config.hidden_size,
                vb.pp("position_embedding"),
                &None,
            )?,
        })
    }

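    /// Embeds the patches and adds position embeddings. Each valid patch's
    /// fractional (row, col) coordinate is bucketized into the
    /// `num_patches_per_side` grid; padded patches keep position id 0.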
    fn forward(&self, pixel_values: &Tensor, patch_attention_mask: &Tensor) -> Result<Tensor> {
        let (bs, _, max_im_h, max_im_w) = pixel_values.dims4()?;

        let patch_embeds = self.patch_embedding.forward(pixel_values)?;

        let embeddings = patch_embeds.flatten(2, D::Minus1)?.transpose(1, 2)?;

        let (max_nb_patches_h, max_nb_patches_w) =
            (max_im_h / self.patch_size, max_im_w / self.patch_size);
        let boundaries = Tensor::arange_step(
            1.0 / self.num_patches_per_side as f32,
            1.0,
            1.0 / self.num_patches_per_side as f32,
            pixel_values.device(),
        )?
        .to_vec1::<f32>()?;
        let position_ids = Tensor::full(
            0u32,
            (bs, max_nb_patches_h * max_nb_patches_w),
            pixel_values.device(),
        )?;

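        // Per image: bucketize the fractional coordinates of the valid
        // (unpadded) patches, combine the row/column buckets into flat
        // position ids, and scatter those ids into the zeroed buffer wherever
        // the patch attention mask is set.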
        let mut new_position_ids = Vec::new();
        for (b_idx, p_attn_mask) in patch_attention_mask.chunk(bs, 0)?.iter().enumerate() {
            let p_attn_mask = p_attn_mask.squeeze(0)?;
            let nb_patches_h = p_attn_mask.i((.., 0))?.sum_all()?;
            let nb_patches_w = p_attn_mask.i((0,))?.sum_all()?;

            let fractional_coords_h = Tensor::arange_step(
                0.0,
                1.0 - 1e-6,
                1.0 / nb_patches_h.to_dtype(DType::F32)?.to_scalar::<f32>()?,
                pixel_values.device(),
            )?
            .to_vec1::<f32>()?;
            let fractional_coords_w = Tensor::arange_step(
                0.0,
                1.0 - 1e-6,
                1.0 / nb_patches_w.to_dtype(DType::F32)?.to_scalar::<f32>()?,
                pixel_values.device(),
            )?
            .to_vec1::<f32>()?;

            let bucket_coords_h =
                bucketize_right(&fractional_coords_h, &boundaries, pixel_values.device())?;
            let bucket_coords_w =
                bucketize_right(&fractional_coords_w, &boundaries, pixel_values.device())?;

            let pos_ids = bucket_coords_h
                .unsqueeze(D::Minus1)?
                .mul(self.num_patches_per_side as f64)?
                .broadcast_add(&bucket_coords_w)?
                .flatten_all()?
                .to_vec1::<u32>()?;

            let true_indices = p_attn_mask
                .flatten_all()?
                .to_vec1::<u8>()?
                .iter()
                .enumerate()
                .filter_map(|(i, x)| if *x != 0 { Some(i) } else { None })
                .collect::<Vec<_>>();
            let position_ids_b = position_ids.i(b_idx)?;

            let mut new_position_ids_b = position_ids_b.to_vec1::<u32>()?;
            let new_position_ids_b_len = new_position_ids_b.len();
            for (i, true_idx) in true_indices.into_iter().enumerate() {
                new_position_ids_b[true_idx] = pos_ids[i];
            }

            new_position_ids.push(Tensor::from_vec(
                new_position_ids_b,
                new_position_ids_b_len,
                pixel_values.device(),
            )?);
        }
        let position_ids = Tensor::stack(&new_position_ids, 0)?;
        let position_ids = position_ids.to_device(self.position_embedding.embeddings().device())?;
        embeddings.broadcast_add(&self.position_embedding.forward(&position_ids)?)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("patch_embedding").add(&self.patch_embedding);
        uvb.pp("position_embedding").add(&self.position_embedding);

        uvb.to_safetensors()
    }
}

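/// Multi-head self-attention over the patch sequence. Attention is
/// bidirectional; the optional mask only hides padded patches.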
struct Attention {
    embed_dim: usize,
    num_heads: usize,
    head_dim: usize,
    scale: f64,
    q_proj: Linear,
    k_proj: Linear,
    v_proj: Linear,
    o_proj: Linear,
    neg_inf: Tensor,
}

impl Attention {
    fn new(config: Idefics3VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        let embed_dim = config.hidden_size;
        let num_heads = config.num_attention_heads;
        let head_dim = embed_dim / num_heads;
        let scale = 1.0 / (head_dim as f64).sqrt();

        let q_proj = layers::linear(embed_dim, embed_dim, vb.pp("q_proj"))?;
        let k_proj = layers::linear(embed_dim, embed_dim, vb.pp("k_proj"))?;
        let v_proj = layers::linear(embed_dim, embed_dim, vb.pp("v_proj"))?;
        let o_proj = layers::linear(embed_dim, embed_dim, vb.pp("out_proj"))?;

        Ok(Self {
            embed_dim,
            num_heads,
            head_dim,
            scale,
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            neg_inf: Tensor::new(f32::NEG_INFINITY, vb.device())?.to_dtype(vb.dtype())?,
        })
    }

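    /// Scaled dot-product attention: project to Q/K/V, split into
    /// `num_heads` heads of `head_dim`, mask padded positions to -inf,
    /// softmax, then merge heads and apply the output projection.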
    fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let mut q = self.q_proj.forward(xs)?;
        let mut k = self.k_proj.forward(xs)?;
        let mut v = self.v_proj.forward(xs)?;

        q = q
            .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;
        k = k
            .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;
        v = v
            .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;

        let attn_weights =
            (MatMul.matmul(&q.contiguous()?, &k.transpose(2, 3)?.contiguous()?)? * self.scale)?;

        let mut attn_weights = CausalMasker.apply_mask_one_and_zero(
            &attention_mask.map(|x| x.to_dtype(DType::U8).unwrap()),
            attn_weights,
            &self.neg_inf,
        )?;
        attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
        let attn_output = MatMul.matmul(&attn_weights, &v.contiguous()?)?;

        attn_output
            .transpose(1, 2)?
            .reshape((b_sz, q_len, self.embed_dim))?
            .apply(&self.o_proj)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("q_proj").add(&self.q_proj);
        uvb.pp("k_proj").add(&self.k_proj);
        uvb.pp("v_proj").add(&self.v_proj);
        uvb.pp("out_proj").add(&self.o_proj);

        uvb.to_safetensors()
    }
}

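/// Feed-forward block: `hidden_size -> intermediate_size -> hidden_size`
/// with the configured activation in between.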
struct VisionMLP {
    activation: Activation,
    fc1: Linear,
    fc2: Linear,
}

impl VisionMLP {
    fn new(config: Idefics3VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        let fc1 = layers::linear(config.hidden_size, config.intermediate_size, vb.pp("fc1"))?;
        let fc2 = layers::linear(config.intermediate_size, config.hidden_size, vb.pp("fc2"))?;
        Ok(Self {
            activation: config.hidden_act,
            fc1,
            fc2,
        })
    }

    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let mut x = self.fc1.forward(x)?;
        x = self.activation.forward(&x)?;
        self.fc2.forward(&x)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("fc1").add(&self.fc1);
        uvb.pp("fc2").add(&self.fc2);

        uvb.to_safetensors()
    }
}

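/// Pre-norm transformer block: LayerNorm -> self-attention -> residual add,
/// then LayerNorm -> MLP -> residual add.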
struct EncoderLayer {
    mlp: VisionMLP,
    attn: Attention,
    layer_norm_1: LayerNorm,
    layer_norm_2: LayerNorm,
}

impl EncoderLayer {
    fn new(config: Idefics3VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        let mlp = VisionMLP::new(config.clone(), vb.pp("mlp"))?;
        let attn = Attention::new(config.clone(), vb.pp("self_attn"))?;
        let layer_norm_1 = layer_norm(
            config.hidden_size,
            config.layer_norm_eps,
            vb.pp("layer_norm1"),
        )?;
        let layer_norm_2 = layer_norm(
            config.hidden_size,
            config.layer_norm_eps,
            vb.pp("layer_norm2"),
        )?;
        Ok(Self {
            mlp,
            attn,
            layer_norm_1,
            layer_norm_2,
        })
    }

    fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
        let residual = xs.clone();

        let hidden_states = self.layer_norm_1.forward(xs)?;
        let hidden_states = self.attn.forward(&hidden_states, attention_mask)?;
        let hidden_states = (hidden_states + residual)?;

        let residual = &hidden_states;
        let hidden_states = self.layer_norm_2.forward(&hidden_states)?;
        let hidden_states = self.mlp.forward(&hidden_states)?;
        hidden_states + residual
    }
}

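/// A stack of `num_hidden_layers` encoder layers applied in sequence.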
struct Encoder {
    layers: Vec<EncoderLayer>,
}

impl Encoder {
    fn new(config: &Idefics3VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        let mut layers = Vec::new();
        let vb_l = vb.pp("layers");
        for i in 0..config.num_hidden_layers {
            layers.push(EncoderLayer::new(config.clone(), vb_l.pp(i))?);
        }
        Ok(Self { layers })
    }

    fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
        let mut hidden_states = xs.clone();
        for layer in &self.layers {
            hidden_states = layer.forward(&hidden_states, attention_mask)?;
        }
        Ok(hidden_states)
    }
}

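/// The full Idefics3 vision tower: patch/position embeddings, transformer
/// encoder, and a final LayerNorm.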
pub struct Idefics3VisionTransformer {
    embeddings: VisionEmbeddings,
    encoder: Encoder,
    post_layernorm: LayerNorm,
    patch_size: usize,
}

impl Idefics3VisionTransformer {
    pub fn new(config: &Idefics3VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        let embeddings = VisionEmbeddings::new(config, vb.pp("embeddings"))?;
        let post_layernorm = layer_norm(
            config.hidden_size,
            config.layer_norm_eps,
            vb.pp("post_layernorm"),
        )?;
        let encoder = Encoder::new(config, vb.pp("encoder"))?;
        Ok(Self {
            embeddings,
            encoder,
            post_layernorm,
            patch_size: config.patch_size,
        })
    }

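    /// Runs the vision tower. When no attention mask is given, every patch
    /// is treated as valid and the encoder layers run unmasked.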
    pub fn forward(
        &self,
        pixel_values: &Tensor,
        attention_mask: Option<&Tensor>,
    ) -> Result<Tensor> {
        let bs = pixel_values.dim(0)?;
        let patch_attention_mask = if let Some(attn_mask) = attention_mask {
            attn_mask.clone()
        } else {
            Tensor::ones(
                (
                    bs,
                    pixel_values.dim(2)? / self.patch_size,
                    pixel_values.dim(3)? / self.patch_size,
                ),
                DType::U8,
                pixel_values.device(),
            )?
        };

        let hidden_states = self
            .embeddings
            .forward(pixel_values, &patch_attention_mask)?;

        let attention_mask = if attention_mask.is_none() {
            None
        } else {
            let mask = patch_attention_mask
                .reshape((patch_attention_mask.dim(0)?, ()))?
                .to_dtype(hidden_states.dtype())?;
            Some(CausalMasker.expand_mask(&mask, hidden_states.dtype(), None)?)
        };
        let hidden_states = self
            .encoder
            .forward(&hidden_states, attention_mask.as_ref())?;
        hidden_states.apply(&self.post_layernorm)
    }

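    /// Collects the tower's weights under their original tensor-name
    /// prefixes (e.g. for re-serializing the non-quantized parts of the
    /// model).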
    pub fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("post_layernorm").add(&self.post_layernorm);
        uvb.pp("embeddings")
            .extend(self.embeddings.residual_tensors());

        let uvb_enc = uvb.pp("encoder");
        for (i, layer) in self.encoder.layers.iter().enumerate() {
            let uvb_l = uvb_enc.pp("layers").pp(i);

            uvb_l.pp("layer_norm1").add(&layer.layer_norm_1);
            uvb_l.pp("layer_norm2").add(&layer.layer_norm_2);
            uvb_l.pp("mlp").extend(layer.mlp.residual_tensors());
            uvb_l.pp("self_attn").extend(layer.attn.residual_tensors());
        }

        uvb.to_safetensors()
    }
}
491}