mistralrs_core/vision_models/idefics2/mod.rs

1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3pub(crate) mod idefics2_input_processor;
4
5use candle_core::{DType, Device, IndexOp, Result, Tensor, D};
6use candle_nn::{Conv2d, Conv2dConfig, Embedding, LayerNorm, Module};
7use mistralrs_quant::ShardedVarBuilder;
8use serde::Deserialize;
9use std::{any::Any, ops::Mul};
10
11use crate::{
12    amoe::{AnyMoeBaseModelMixin, MlpLayer},
13    device_map::DeviceMapper,
14    layers::{
15        conv2d, embedding, layer_norm, linear, linear_no_bias, repeat_kv, Activation, CausalMasker,
16        MatMul, QLinear, RmsNorm,
17    },
18    models::mistral::Model as Mistral,
19    paged_attention::{AttentionImplementation, ModelConfigMetadata},
20    pipeline::{
21        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
22        EitherCache, IsqModel, NormalLoadingMetadata, NormalModel, VisionModel,
23    },
24    utils::unvarbuilder::UnVarBuilder,
25    AnyMoeConfig, AnyMoeExpertType,
26};
27
28use crate::models::mistral;
29
30// https://github.com/huggingface/transformers/blob/main/src/transformers/models/idefics2/modeling_idefics2.py
31
// Zero-argument providers for `#[serde(default = "...")]` on the config structs.
// Each function is named after the value it returns.

/// Serde default provider returning `32_000`.
fn default_32000() -> usize {
    32_000
}

/// Serde default provider returning `32_001`.
fn default_32001() -> usize {
    32_001
}

/// Serde default provider returning `4_096`.
fn default_4096() -> usize {
    4_096
}

/// Serde default provider returning `14_336`.
fn default_14336() -> usize {
    14_336
}

/// Serde default provider returning `32`.
fn default_32() -> usize {
    32
}

/// Serde default provider returning `8`.
fn default_8() -> usize {
    8
}
50fn default_act() -> Activation {
51    Activation::Silu
52}
/// Serde default provider returning `131_072`.
fn default_131072() -> usize {
    131_072
}

/// Serde default provider returning `1.0e-6` (normalization epsilon).
fn default_eps() -> f64 {
    1.0e-6
}

/// Serde default provider returning `10_000.0` (RoPE base).
fn default_rope() -> f64 {
    10_000.0
}

/// Serde default provider returning `false`.
fn default_false() -> bool {
    false
}

/// Serde default provider returning `Some(4_096)` (sliding-window length).
fn default_sliding() -> Option<usize> {
    Some(4_096)
}
68fn default_gelu() -> Activation {
69    Activation::GeluPytorchTanh
70}
/// Serde default provider returning `64`.
fn default_64() -> usize {
    64
}

/// Serde default provider returning `3`.
fn default_3() -> usize {
    3
}

/// Serde default provider returning `16`.
fn default_16() -> usize {
    16
}

/// Serde default provider returning `96`.
fn default_96() -> usize {
    96
}

/// Serde default provider returning `4`.
fn default_4() -> usize {
    4
}

/// Serde default provider returning `0.0` (f32).
fn default_0_0() -> f32 {
    0.0_f32
}

/// Serde default provider returning `0.02` (f32).
fn default_0_02() -> f32 {
    0.02_f32
}

/// Serde default provider returning `768`.
fn default_768() -> usize {
    768
}

/// Serde default provider returning `3_072`.
fn default_3072() -> usize {
    3_072
}

/// Serde default provider returning `12`.
fn default_12() -> usize {
    12
}

/// Serde default provider returning `224`.
fn default_224() -> usize {
    224
}
104
/// Hyper-parameters of the perceiver resampler used by the connector.
///
/// Every field falls back to a serde default when absent from the JSON.
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct PerceiverConfig {
    /// Activation of the gated MLP inside each perceiver layer.
    #[serde(default = "default_act")]
    pub hidden_act: Activation,
    /// Number of learned latent vectors, i.e. the resampler's output length.
    #[serde(default = "default_64")]
    pub resampler_n_latents: usize,
    /// Number of stacked perceiver layers.
    #[serde(default = "default_3")]
    pub resampler_depth: usize,
    /// Number of query heads in perceiver attention.
    #[serde(default = "default_16")]
    pub resampler_n_heads: usize,
    /// Per-head dimension in perceiver attention.
    #[serde(default = "default_96")]
    pub resampler_head_dim: usize,
    /// Number of key/value heads; query heads are grouped over these
    /// (`resampler_n_heads / num_key_value_heads` queries per KV head).
    #[serde(default = "default_4")]
    pub num_key_value_heads: usize,
    /// Attention dropout probability; not read by the perceiver code shown here.
    #[serde(default = "default_0_0")]
    pub attention_dropout: f32,
}
122
123#[derive(Debug, Clone, PartialEq, Deserialize)]
124pub struct VisionConfig {
125    #[serde(default = "default_768")]
126    pub hidden_size: usize,
127    #[serde(default = "default_3072")]
128    pub intermediate_size: usize,
129    #[serde(default = "default_12")]
130    pub num_hidden_layers: usize,
131    #[serde(default = "default_12")]
132    pub num_attention_heads: usize,
133    #[serde(default = "default_3")]
134    pub num_channels: usize,
135    #[serde(default = "default_224")]
136    pub image_size: usize,
137    #[serde(default = "default_32")]
138    pub patch_size: usize,
139    #[serde(default = "default_gelu")]
140    pub hidden_act: Activation,
141    #[serde(default = "default_eps")]
142    pub layer_norm_eps: f64,
143    #[serde(default = "default_0_0")]
144    pub attn_dropout: f32,
145    #[serde(default = "default_0_02")]
146    pub initiailizer_range: f32,
147}
148
/// Configuration of the language-model backbone.
///
/// Converted into `mistral::Config` via the `From` impl; every field except
/// `model_type` has a serde default.
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub(crate) struct TextConfig {
    #[serde(default = "default_32000")]
    pub(crate) vocab_size: usize,
    /// Text hidden size; also the width of the perceiver latents and connector output.
    #[serde(default = "default_4096")]
    pub(crate) hidden_size: usize,
    /// Text MLP width; also used as the hidden width of the connector's modality projection.
    #[serde(default = "default_14336")]
    pub(crate) intermediate_size: usize,
    #[serde(default = "default_32")]
    pub(crate) num_hidden_layers: usize,
    #[serde(default = "default_32")]
    pub(crate) num_attention_heads: usize,
    #[serde(default = "default_8")]
    pub(crate) num_key_value_heads: usize,
    #[serde(default = "default_act")]
    pub(crate) hidden_act: Activation,
    #[serde(default = "default_131072")]
    pub(crate) max_position_embeddings: usize,
    /// RMSNorm epsilon; also reused by the perceiver resampler's norms.
    #[serde(default = "default_eps")]
    pub(crate) rms_norm_eps: f64,
    #[serde(default = "default_rope")]
    pub(crate) rope_theta: f64,
    #[serde(default = "default_sliding")]
    pub(crate) sliding_window: Option<usize>,

    #[serde(default = "default_false")]
    pub(crate) use_flash_attn: bool,
    // Required key (no default). Expected to be "mistral" for now; it is not
    // forwarded by the `From` conversion. NOTE(review): no validation of this
    // value is visible in this section — confirm whether it is checked elsewhere.
    model_type: String, // Must be mistral for now
}
178
179impl From<TextConfig> for mistral::Config {
180    fn from(val: TextConfig) -> Self {
181        mistral::Config {
182            vocab_size: val.vocab_size,
183            hidden_act: val.hidden_act,
184            hidden_size: val.hidden_size,
185            intermediate_size: val.intermediate_size,
186            num_hidden_layers: val.num_hidden_layers,
187            num_attention_heads: val.num_attention_heads,
188            num_key_value_heads: val.num_key_value_heads,
189            max_position_embeddings: val.max_position_embeddings,
190            rms_norm_eps: val.rms_norm_eps,
191            rope_theta: val.rope_theta,
192            sliding_window: val.sliding_window,
193            use_flash_attn: val.use_flash_attn,
194            head_dim: None,
195            quantization_config: None,
196            tie_word_embeddings: false,
197        }
198    }
199}
200
/// Top-level Idefics2 configuration bundling the perceiver, vision and text
/// sub-configurations.
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub(crate) struct Config {
    /// Required: perceiver resampler settings.
    pub perceiver_config: PerceiverConfig,
    /// Required: vision encoder settings.
    pub vision_config: VisionConfig,
    /// Required: language-model backbone settings.
    pub(crate) text_config: TextConfig,
    /// Token id presumably marking image placeholder positions in the input;
    /// defaults to 32001. NOTE(review): usage is outside this section — confirm.
    #[serde(default = "default_32001")]
    pub image_token_id: usize,
    /// Whether the LM head shares the embedding weights. Note that
    /// `From<TextConfig> for mistral::Config` does not see this flag and always
    /// sets `tie_word_embeddings: false`.
    #[serde(default = "default_false")]
    pub tie_word_embeddings: bool,
}
211
212// == START VISION MODEL ==
213
/// Patch embeddings plus learned position embeddings for the vision encoder.
struct VisionEmbeddings {
    /// Side length (pixels) of a square patch; also the conv stride.
    patch_size: usize,
    /// Conv2d from image channels to `hidden_size`, built with `stride = patch_size`.
    patch_embedding: Conv2d,
    /// `image_size / patch_size`: side of the learned square position grid.
    num_patches_per_side: usize,
    /// Learned position table with `num_patches_per_side^2` rows.
    position_embedding: Embedding,
}
220
221/// torch.bucketize with right=True
222/// Returns a 1d tensor of shape (xs.len(),) on the CPU
223fn bucketize_right(xs: &[f32], boundaries: &[f32], device: &Device) -> Result<Tensor> {
224    use std::cmp::Ordering;
225
226    let mut result = Vec::with_capacity(xs.len());
227
228    for &x in xs {
229        // binary_search_by returns:
230        //   Ok(i)   if boundaries[i] == x
231        //   Err(i)  if x would be inserted at i
232        //
233        // The returned i is the "insertion point" for x to keep
234        // boundaries sorted. That i is the smallest position
235        // where boundaries[i] >= x (i.e. bisect_left).
236
237        let idx = match boundaries.binary_search_by(|&val| {
238            // Use partial_cmp here; assume no NaNs.
239            // For robust handling of NaNs, you might need a custom comparison.
240            val.partial_cmp(&x).unwrap_or(Ordering::Less)
241        }) {
242            Ok(i) => i,
243            Err(i) => i,
244        };
245
246        result.push(idx as u32);
247    }
248
249    Tensor::from_vec(result, (xs.len(),), device)
250}
251
252impl VisionEmbeddings {
253    fn new(config: &VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
254        let conv_config = Conv2dConfig {
255            stride: config.patch_size,
256            ..Default::default()
257        };
258        let patch_embedding = conv2d(
259            config.num_channels,
260            config.hidden_size,
261            config.patch_size,
262            conv_config,
263            vb.pp("patch_embedding"),
264        )?;
265        let num_patches_per_side = config.image_size / config.patch_size;
266        let num_patches = num_patches_per_side.pow(2);
267        Ok(Self {
268            patch_size: config.patch_size,
269            patch_embedding,
270            num_patches_per_side,
271            position_embedding: embedding(
272                num_patches,
273                config.hidden_size,
274                vb.pp("position_embedding"),
275                &None,
276            )?,
277        })
278    }
279
280    fn forward(&self, pixel_values: &Tensor, patch_attention_mask: &Tensor) -> Result<Tensor> {
281        let (bs, _, max_im_h, max_im_w) = pixel_values.dims4()?;
282
283        let patch_embeds = self.patch_embedding.forward(pixel_values)?;
284
285        let embeddings = patch_embeds.flatten(2, D::Minus1)?.transpose(1, 2)?;
286
287        let (max_nb_patches_h, max_nb_patches_w) =
288            (max_im_h / self.patch_size, max_im_w / self.patch_size);
289        let boundaries = Tensor::arange_step(
290            1.0 / self.num_patches_per_side as f32,
291            1.0,
292            1.0 / self.num_patches_per_side as f32,
293            pixel_values.device(),
294        )?
295        .to_vec1::<f32>()?;
296        let position_ids = Tensor::full(
297            0u32,
298            (bs, max_nb_patches_h * max_nb_patches_w),
299            pixel_values.device(),
300        )?;
301
302        let mut new_position_ids = Vec::new();
303        for (b_idx, p_attn_mask) in patch_attention_mask.chunk(bs, 0)?.iter().enumerate() {
304            let p_attn_mask = p_attn_mask.squeeze(0)?;
305            let nb_patches_h = p_attn_mask.i((.., 0))?.sum_all()?;
306            let nb_patches_w = p_attn_mask.i((0,))?.sum_all()?;
307
308            let fractional_coords_h = Tensor::arange_step(
309                0.0,
310                1.0 - 1e-6,
311                1.0 / nb_patches_h.to_dtype(DType::F32)?.to_scalar::<f32>()?,
312                pixel_values.device(),
313            )?
314            .to_vec1::<f32>()?;
315            let fractional_coords_w = Tensor::arange_step(
316                0.0,
317                1.0 - 1e-6,
318                1.0 / nb_patches_w.to_dtype(DType::F32)?.to_scalar::<f32>()?,
319                pixel_values.device(),
320            )?
321            .to_vec1::<f32>()?;
322
323            let bucket_coords_h =
324                bucketize_right(&fractional_coords_h, &boundaries, pixel_values.device())?;
325            let bucket_coords_w =
326                bucketize_right(&fractional_coords_w, &boundaries, pixel_values.device())?;
327
328            let pos_ids = bucket_coords_h
329                .unsqueeze(D::Minus1)?
330                .mul(self.num_patches_per_side as f64)?
331                .broadcast_add(&bucket_coords_w)?
332                .flatten_all()?
333                .to_vec1::<u32>()?;
334
335            let true_indices = p_attn_mask
336                .flatten_all()?
337                .to_vec1::<u8>()?
338                .iter()
339                .enumerate()
340                .filter_map(|(i, x)| if *x != 0 { Some(i) } else { None })
341                .collect::<Vec<_>>();
342            let position_ids_b = position_ids.i(b_idx)?;
343
344            let mut new_position_ids_b = position_ids_b.to_vec1::<u32>()?;
345            let new_position_ids_b_len = new_position_ids_b.len();
346            for (i, true_idx) in true_indices.into_iter().enumerate() {
347                new_position_ids_b[true_idx] = pos_ids[i];
348            }
349
350            new_position_ids.push(Tensor::from_vec(
351                new_position_ids_b,
352                new_position_ids_b_len,
353                pixel_values.device(),
354            )?);
355        }
356        let position_ids = Tensor::stack(&new_position_ids, 0)?;
357        let position_ids = position_ids.to_device(self.position_embedding.embeddings().device())?;
358        embeddings.broadcast_add(&self.position_embedding.forward(&position_ids)?)
359    }
360
361    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
362        let uvb = UnVarBuilder::new();
363
364        uvb.pp("patch_embedding").add(&self.patch_embedding);
365        uvb.pp("position_embedding").add(&self.position_embedding);
366
367        uvb.to_safetensors()
368    }
369}
370
371struct Attention {
372    embed_dim: usize,
373    num_heads: usize,
374    head_dim: usize,
375    scale: f64,
376    q_proj: QLinear,
377    k_proj: QLinear,
378    v_proj: QLinear,
379    o_proj: QLinear,
380    neg_inf: Tensor,
381}
382
383impl Attention {
384    fn new(config: VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
385        let embed_dim = config.hidden_size;
386        let num_heads = config.num_attention_heads;
387        let head_dim = embed_dim / num_heads;
388        let scale = 1.0 / (head_dim as f64).sqrt();
389
390        let q_proj = linear(embed_dim, embed_dim, vb.pp("q_proj"))?;
391        let k_proj = linear(embed_dim, embed_dim, vb.pp("k_proj"))?;
392        let v_proj = linear(embed_dim, embed_dim, vb.pp("v_proj"))?;
393        let o_proj = linear(embed_dim, embed_dim, vb.pp("out_proj"))?;
394
395        Ok(Self {
396            embed_dim,
397            num_heads,
398            head_dim,
399            scale,
400            q_proj: QLinear::from_linear(q_proj),
401            k_proj: QLinear::from_linear(k_proj),
402            v_proj: QLinear::from_linear(v_proj),
403            o_proj: QLinear::from_linear(o_proj),
404            neg_inf: Tensor::new(f32::NEG_INFINITY, vb.device())?.to_dtype(vb.dtype())?,
405        })
406    }
407
408    fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
409        let (b_sz, q_len, _) = xs.dims3()?;
410
411        let original_dtype = xs.dtype();
412        let mut xs = xs.clone();
413        if self.q_proj.is_quant() {
414            xs = xs.to_dtype(DType::F32)?;
415        }
416        let mut q = self.q_proj.forward(&xs)?;
417        let mut k = self.k_proj.forward(&xs)?;
418        let mut v = self.v_proj.forward(&xs)?;
419        if self.q_proj.is_quant() {
420            q = q.to_dtype(original_dtype)?;
421            k = k.to_dtype(original_dtype)?;
422            v = v.to_dtype(original_dtype)?;
423        }
424
425        let q = q
426            .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
427            .transpose(1, 2)?;
428        let k = k
429            .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
430            .transpose(1, 2)?;
431        let v = v
432            .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
433            .transpose(1, 2)?;
434
435        let attn_weights =
436            (MatMul.matmul(&q.contiguous()?, &k.transpose(2, 3)?.contiguous()?)? * self.scale)?;
437
438        let attn_weights = CausalMasker.apply_mask_one_and_zero(
439            &attention_mask.map(|x| x.to_dtype(DType::U8).unwrap()),
440            attn_weights,
441            &self.neg_inf,
442        )?;
443        let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
444        let mut attn_output = MatMul.matmul(&attn_weights, &v.contiguous()?)?;
445
446        if self.q_proj.is_quant() {
447            attn_output = attn_output.to_dtype(DType::F32)?;
448        }
449        let mut res = attn_output
450            .transpose(1, 2)?
451            .reshape((b_sz, q_len, self.embed_dim))?
452            .apply(&self.o_proj)?;
453        if self.q_proj.is_quant() {
454            res = res.to_dtype(original_dtype)?;
455        }
456        Ok(res)
457    }
458
459    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
460        let uvb = UnVarBuilder::new();
461
462        uvb.pp("q_proj").add(&self.q_proj);
463        uvb.pp("k_proj").add(&self.k_proj);
464        uvb.pp("v_proj").add(&self.v_proj);
465        uvb.pp("out_proj").add(&self.o_proj);
466
467        uvb.to_safetensors()
468    }
469}
470
471struct VisionMLP {
472    activation: Activation,
473    fc1: QLinear,
474    fc2: QLinear,
475}
476
477impl VisionMLP {
478    fn new(config: VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
479        let fc1 = linear(config.hidden_size, config.intermediate_size, vb.pp("fc1"))?;
480        let fc2 = linear(config.intermediate_size, config.hidden_size, vb.pp("fc2"))?;
481        Ok(Self {
482            activation: config.hidden_act,
483            fc1: QLinear::from_linear(fc1),
484            fc2: QLinear::from_linear(fc2),
485        })
486    }
487
488    fn forward(&self, x: &Tensor) -> Result<Tensor> {
489        let mut x = x.clone();
490        let original_dtype = x.dtype();
491        if self.fc1.is_quant() {
492            x = x.to_dtype(DType::F32)?;
493        }
494        let x = self.fc1.forward(&x)?;
495        let x = self.activation.forward(&x)?;
496        let mut res = self.fc2.forward(&x)?;
497        if self.fc1.is_quant() {
498            res = res.to_dtype(original_dtype)?;
499        }
500        Ok(res)
501    }
502
503    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
504        let uvb = UnVarBuilder::new();
505
506        uvb.pp("fc1").add(&self.fc1);
507        uvb.pp("fc2").add(&self.fc2);
508
509        uvb.to_safetensors()
510    }
511}
512
513struct EncoderLayer {
514    mlp: VisionMLP,
515    attn: Attention,
516    layer_norm_1: LayerNorm,
517    layer_norm_2: LayerNorm,
518}
519
520impl EncoderLayer {
521    fn new(config: VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
522        let mlp = VisionMLP::new(config.clone(), vb.pp("mlp"))?;
523        let attn = Attention::new(config.clone(), vb.pp("self_attn"))?;
524        let layer_norm_1 = layer_norm(
525            config.hidden_size,
526            config.layer_norm_eps,
527            vb.pp("layer_norm1"),
528        )?;
529        let layer_norm_2 = layer_norm(
530            config.hidden_size,
531            config.layer_norm_eps,
532            vb.pp("layer_norm2"),
533        )?;
534        Ok(Self {
535            mlp,
536            attn,
537            layer_norm_1,
538            layer_norm_2,
539        })
540    }
541
542    fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
543        let residual = xs.clone();
544
545        let hidden_states = self.layer_norm_1.forward(xs)?;
546        let hidden_states = self.attn.forward(&hidden_states, attention_mask)?;
547        let hidden_states = (hidden_states + residual)?;
548
549        let residual = &hidden_states;
550        let hidden_states = self.layer_norm_2.forward(&hidden_states)?;
551        let hidden_states = self.mlp.forward(&hidden_states)?;
552        hidden_states + residual
553    }
554}
555
556struct Encoder {
557    layers: Vec<EncoderLayer>,
558}
559
560impl Encoder {
561    fn new(config: &VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
562        let mut layers = Vec::new();
563        let vb_l = vb.pp("layers");
564        for i in 0..config.num_hidden_layers {
565            layers.push(EncoderLayer::new(config.clone(), vb_l.pp(i))?);
566        }
567        Ok(Self { layers })
568    }
569
570    fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
571        let mut hidden_states = xs.clone();
572        for layer in &self.layers {
573            hidden_states = layer.forward(&hidden_states, attention_mask)?;
574        }
575        Ok(hidden_states)
576    }
577}
578
579struct VisionTransformer {
580    embeddings: VisionEmbeddings,
581    encoder: Encoder,
582    post_layernorm: LayerNorm,
583    config: VisionConfig,
584}
585
586impl VisionTransformer {
587    fn new(config: &VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
588        let embeddings = VisionEmbeddings::new(config, vb.pp("embeddings"))?;
589        let post_layernorm = layer_norm(
590            config.hidden_size,
591            config.layer_norm_eps,
592            vb.pp("post_layernorm"),
593        )?;
594        let encoder = Encoder::new(config, vb.pp("encoder"))?;
595        Ok(Self {
596            embeddings,
597            encoder,
598            post_layernorm,
599            config: config.clone(),
600        })
601    }
602
603    fn forward(&self, pixel_values: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
604        let bs = pixel_values.dim(0)?;
605        let patch_attention_mask = if let Some(attn_mask) = attention_mask {
606            attn_mask.clone()
607        } else {
608            let patch_size = self.config.patch_size;
609            Tensor::ones(
610                (
611                    bs,
612                    pixel_values.dim(2)? / patch_size,
613                    pixel_values.dim(3)? / patch_size,
614                ),
615                DType::U8,
616                pixel_values.device(),
617            )?
618        };
619
620        let hidden_states = self
621            .embeddings
622            .forward(pixel_values, &patch_attention_mask)?;
623
624        let attention_mask = if attention_mask.is_none() {
625            None
626        } else {
627            let mask = patch_attention_mask
628                .reshape((patch_attention_mask.dim(0)?, ()))?
629                .to_dtype(hidden_states.dtype())?;
630            Some(CausalMasker.expand_mask(&mask, hidden_states.dtype(), None)?)
631        };
632        let hidden_states = self
633            .encoder
634            .forward(&hidden_states, attention_mask.as_ref())?;
635        hidden_states.apply(&self.post_layernorm)
636    }
637
638    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
639        let uvb = UnVarBuilder::new();
640
641        uvb.pp("post_layernorm").add(&self.post_layernorm);
642        uvb.pp("embeddings")
643            .extend(self.embeddings.residual_tensors());
644
645        let uvb_enc = uvb.pp("encoder");
646        for (i, layer) in self.encoder.layers.iter().enumerate() {
647            let uvb_l = uvb_enc.pp("layers").pp(i);
648
649            uvb_l.pp("layer_norm1").add(&layer.layer_norm_1);
650            uvb_l.pp("layer_norm2").add(&layer.layer_norm_2);
651            uvb_l.pp("mlp").extend(layer.mlp.residual_tensors());
652            uvb_l.pp("self_attn").extend(layer.attn.residual_tensors());
653        }
654
655        uvb.to_safetensors()
656    }
657}
658
659// == END VISION MODEL ==
660
661// == START CONNECTOR ==
662struct Mlp {
663    gate_proj: QLinear,
664    up_proj: QLinear,
665    down_proj: QLinear,
666    activation: Activation,
667}
668
669impl Mlp {
670    fn new(
671        hidden_size: usize,
672        intermediate_size: usize,
673        output_size: usize,
674        activation: Activation,
675        vb: ShardedVarBuilder,
676    ) -> Result<Self> {
677        let gate_proj = linear_no_bias(hidden_size, intermediate_size, vb.pp("gate_proj"))?;
678        let up_proj = linear_no_bias(hidden_size, intermediate_size, vb.pp("up_proj"))?;
679        let down_proj = linear_no_bias(intermediate_size, output_size, vb.pp("down_proj"))?;
680        Ok(Self {
681            gate_proj: QLinear::from_linear(gate_proj),
682            up_proj: QLinear::from_linear(up_proj),
683            down_proj: QLinear::from_linear(down_proj),
684            activation,
685        })
686    }
687
688    fn forward(&self, x: &Tensor) -> Result<Tensor> {
689        let mut x = x.clone();
690        let original_dtype = x.dtype();
691        if self.gate_proj.is_quant() {
692            x = x.to_dtype(DType::F32)?;
693        }
694        let mut res = self.down_proj.forward(
695            &(self.activation.forward(&self.gate_proj.forward(&x)?)?
696                * self.up_proj.forward(&x)?)?,
697        )?;
698        if self.gate_proj.is_quant() {
699            res = res.to_dtype(original_dtype)?;
700        }
701        Ok(res)
702    }
703
704    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
705        let uvb = UnVarBuilder::new();
706
707        uvb.pp("gate_proj").add(&self.gate_proj);
708        uvb.pp("up_proj").add(&self.up_proj);
709        uvb.pp("down_proj").add(&self.down_proj);
710
711        uvb.to_safetensors()
712    }
713}
714
715struct PerceiverAttention {
716    num_heads: usize,
717    num_kv_heads: usize,
718    num_kv_groups: usize,
719    head_dim: usize,
720    q_proj: QLinear,
721    k_proj: QLinear,
722    v_proj: QLinear,
723    o_proj: QLinear,
724    neg_inf: Tensor,
725}
726
727impl PerceiverAttention {
728    fn new(config: &Config, vb: ShardedVarBuilder) -> Result<Self> {
729        let hidden_size = config.text_config.hidden_size;
730        let num_heads = config.perceiver_config.resampler_n_heads;
731        let head_dim = config.perceiver_config.resampler_head_dim;
732        let num_key_value_heads = config.perceiver_config.num_key_value_heads;
733        let num_key_value_groups = num_heads / num_key_value_heads;
734
735        let q_proj = linear_no_bias(hidden_size, num_heads * head_dim, vb.pp("q_proj"))?;
736        let k_proj = linear_no_bias(hidden_size, num_key_value_heads * head_dim, vb.pp("k_proj"))?;
737        let v_proj = linear_no_bias(hidden_size, num_key_value_heads * head_dim, vb.pp("v_proj"))?;
738        let o_proj = linear_no_bias(num_heads * head_dim, hidden_size, vb.pp("o_proj"))?;
739
740        Ok(Self {
741            num_heads,
742            head_dim,
743            q_proj: QLinear::from_linear(q_proj),
744            k_proj: QLinear::from_linear(k_proj),
745            v_proj: QLinear::from_linear(v_proj),
746            o_proj: QLinear::from_linear(o_proj),
747            neg_inf: Tensor::new(f32::NEG_INFINITY, vb.device())?.to_dtype(vb.dtype())?,
748            num_kv_heads: num_key_value_heads,
749            num_kv_groups: num_key_value_groups,
750        })
751    }
752
753    fn forward(
754        &self,
755        latents: &Tensor,
756        context: &Tensor,
757        attention_mask: &Tensor,
758    ) -> Result<Tensor> {
759        let (b_sz, q_len, _) = latents.dims3()?;
760        let kv_seq_len = q_len + context.dims()[1];
761
762        let mut hidden_states = Tensor::cat(&[context, latents], D::Minus2)?;
763
764        let original_dtype = latents.dtype();
765        let mut latents = latents.clone();
766        if self.q_proj.is_quant() {
767            latents = latents.to_dtype(DType::F32)?;
768            hidden_states = hidden_states.to_dtype(DType::F32)?;
769        }
770        let mut q = self.q_proj.forward(&latents)?;
771        let mut k = self.k_proj.forward(&hidden_states)?;
772        let mut v = self.v_proj.forward(&hidden_states)?;
773        if self.q_proj.is_quant() {
774            q = q.to_dtype(original_dtype)?;
775            k = k.to_dtype(original_dtype)?;
776            v = v.to_dtype(original_dtype)?;
777        }
778
779        let q = q
780            .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
781            .transpose(1, 2)?;
782        let k = k
783            .reshape((b_sz, kv_seq_len, self.num_kv_heads, self.head_dim))?
784            .transpose(1, 2)?;
785        let v = v
786            .reshape((b_sz, kv_seq_len, self.num_kv_heads, self.head_dim))?
787            .transpose(1, 2)?;
788
789        let k = repeat_kv(k, self.num_kv_groups)?.contiguous()?;
790        let v = repeat_kv(v, self.num_kv_groups)?.contiguous()?;
791
792        let attn_weights = (MatMul.matmul(&q.contiguous()?, &k.transpose(2, 3)?.contiguous()?)?
793            / (self.head_dim as f64).sqrt())?;
794
795        let attn_weights = CausalMasker.apply_mask_one_and_zero(
796            &Some(attention_mask.to_dtype(DType::U8)?),
797            attn_weights,
798            &self.neg_inf,
799        )?;
800        let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
801        let mut attn_output = MatMul.matmul(&attn_weights, &v.contiguous()?)?;
802
803        if self.q_proj.is_quant() {
804            attn_output = attn_output.to_dtype(DType::F32)?;
805        }
806        let mut res = attn_output
807            .transpose(1, 2)?
808            .reshape((b_sz, q_len, self.num_heads * self.head_dim))?
809            .apply(&self.o_proj)?;
810        if self.q_proj.is_quant() {
811            res = res.to_dtype(original_dtype)?;
812        }
813        Ok(res)
814    }
815
816    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
817        let uvb = UnVarBuilder::new();
818
819        uvb.pp("q_proj").add(&self.q_proj);
820        uvb.pp("k_proj").add(&self.k_proj);
821        uvb.pp("v_proj").add(&self.v_proj);
822        uvb.pp("o_proj").add(&self.o_proj);
823
824        uvb.to_safetensors()
825    }
826}
827
828struct PerceiverLayer {
829    input_latents_norm: RmsNorm,
830    input_context_norm: RmsNorm,
831    self_attn: PerceiverAttention,
832    post_attn_norm: RmsNorm,
833    mlp: Mlp,
834}
835
836impl PerceiverLayer {
837    fn new(config: &Config, vb: ShardedVarBuilder) -> Result<Self> {
838        let hidden_size = config.text_config.hidden_size;
839        let mlp_act = config.perceiver_config.hidden_act;
840        let rms_eps = config.text_config.rms_norm_eps;
841
842        Ok(Self {
843            input_latents_norm: RmsNorm::new(hidden_size, rms_eps, vb.pp("input_latents_norm"))?,
844            input_context_norm: RmsNorm::new(hidden_size, rms_eps, vb.pp("input_context_norm"))?,
845            self_attn: PerceiverAttention::new(config, vb.pp("self_attn"))?,
846            post_attn_norm: RmsNorm::new(hidden_size, rms_eps, vb.pp("post_attention_layernorm"))?,
847            mlp: Mlp::new(
848                hidden_size,
849                hidden_size * 4,
850                hidden_size,
851                mlp_act,
852                vb.pp("mlp"),
853            )?,
854        })
855    }
856
857    fn forward(
858        &self,
859        latents: &Tensor,
860        context: &Tensor,
861        attention_mask: &Tensor,
862    ) -> Result<Tensor> {
863        let residual = latents;
864
865        let latents = self.input_latents_norm.forward(latents)?;
866        let context = self.input_context_norm.forward(context)?;
867
868        let latents = self.self_attn.forward(&latents, &context, attention_mask)?;
869        let latents = (residual + latents)?;
870        let residual = &latents;
871
872        let latents = self.post_attn_norm.forward(&latents)?;
873        let latents = self.mlp.forward(&latents)?;
874        residual + latents
875    }
876}
877
878struct PerceiverResampler {
879    latents: Tensor,
880    layers: Vec<PerceiverLayer>,
881    norm: RmsNorm,
882    n_latents: usize,
883}
884
885impl PerceiverResampler {
886    fn new(config: &Config, vb: ShardedVarBuilder) -> Result<Self> {
887        let n_latents = config.perceiver_config.resampler_n_latents;
888        let hidden_size = config.text_config.hidden_size;
889        let depth = config.perceiver_config.resampler_depth;
890
891        let latents = vb.get((n_latents, hidden_size), "latents")?;
892        let mut layers = Vec::new();
893        let vb_l = vb.pp("layers");
894        for i in 0..depth {
895            layers.push(PerceiverLayer::new(config, vb_l.pp(i))?);
896        }
897        let norm = RmsNorm::new(hidden_size, config.text_config.rms_norm_eps, vb.pp("norm"))?;
898        Ok(Self {
899            latents,
900            layers,
901            norm,
902            n_latents,
903        })
904    }
905
906    fn forward(&self, context: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
907        let mut s = vec![context.dim(0)?];
908        s.extend(self.latents.dims());
909        let latents = self.latents.unsqueeze(0)?.expand(s)?;
910
911        let latent_attention_mask = Tensor::ones(
912            (attention_mask.dim(0)?, latents.dim(1)?),
913            attention_mask.dtype(),
914            attention_mask.device(),
915        )?;
916        let attention_mask = Tensor::cat(&[attention_mask, &latent_attention_mask], D::Minus1)?;
917        let attention_mask =
918            CausalMasker.expand_mask(&attention_mask, latents.dtype(), Some(self.n_latents))?;
919
920        let mut compressed_context = latents;
921        for perceiver_layer in &self.layers {
922            compressed_context =
923                perceiver_layer.forward(&compressed_context, context, &attention_mask)?;
924        }
925        self.norm.forward(&compressed_context)
926    }
927
928    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
929        let uvb = UnVarBuilder::new();
930
931        uvb.pp("norm").add(&self.norm);
932        uvb.add_tensor("latents", self.latents.clone());
933
934        for (i, layer) in self.layers.iter().enumerate() {
935            let uvb_l = uvb.pp("layers").pp(i);
936
937            uvb_l
938                .pp("input_latents_norm")
939                .add(&layer.input_latents_norm);
940            uvb_l
941                .pp("input_context_norm")
942                .add(&layer.input_context_norm);
943            uvb_l
944                .pp("post_attention_layernorm")
945                .add(&layer.post_attn_norm);
946            uvb_l.pp("mlp").extend(layer.mlp.residual_tensors());
947            uvb_l
948                .pp("self_attn")
949                .extend(layer.self_attn.residual_tensors());
950        }
951
952        uvb.to_safetensors()
953    }
954}
955
956struct Connector {
957    modality_projection: Mlp,
958    perceiver_resampler: PerceiverResampler,
959}
960
961impl Connector {
962    fn new(config: &Config, vb: ShardedVarBuilder) -> Result<Self> {
963        let modality_projection = Mlp::new(
964            config.vision_config.hidden_size,
965            config.text_config.intermediate_size,
966            config.text_config.hidden_size,
967            config.text_config.hidden_act,
968            vb.pp("modality_projection"),
969        )?;
970        let perceiver_resampler = PerceiverResampler::new(config, vb.pp("perceiver_resampler"))?;
971        Ok(Self {
972            modality_projection,
973            perceiver_resampler,
974        })
975    }
976
977    fn forward(&self, image_hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
978        let image_hidden_states = self.modality_projection.forward(image_hidden_states)?;
979        self.perceiver_resampler
980            .forward(&image_hidden_states, attention_mask)
981    }
982
983    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
984        let uvb = UnVarBuilder::new();
985
986        uvb.pp("modality_projection")
987            .extend(self.modality_projection.residual_tensors());
988        uvb.pp("perceiver_resampler")
989            .extend(self.perceiver_resampler.residual_tensors());
990
991        uvb.to_safetensors()
992    }
993}
994
995// == END CONNECTOR ==
996
997// == START MODEL ==
998
999pub struct Idefics2 {
1000    vision_model: VisionTransformer,
1001    connector: Connector,
1002    text_model: Mistral,
1003    dtype: DType,
1004    config: Config,
1005}
1006
1007impl Idefics2 {
1008    pub fn new(
1009        config: &Config,
1010        vb: ShardedVarBuilder,
1011        is_gptx: bool,
1012        normal_loading_metadata: NormalLoadingMetadata,
1013        attention_mechanism: AttentionImplementation,
1014    ) -> Result<Self> {
1015        let vb_m = vb.pp("model");
1016        let text_model = Mistral::new_inner(
1017            &config.text_config.clone().into(),
1018            vb_m.pp("text_model"),
1019            vb.pp("lm_head"),
1020            is_gptx,
1021            normal_loading_metadata,
1022            attention_mechanism,
1023        )?;
1024        let vision_model = VisionTransformer::new(
1025            &config.vision_config,
1026            vb_m.pp("vision_model")
1027                .set_device(text_model.device().clone()),
1028        )?;
1029        let connector = Connector::new(
1030            config,
1031            vb_m.pp("connector").set_device(text_model.device().clone()),
1032        )?;
1033        Ok(Self {
1034            vision_model,
1035            connector,
1036            text_model,
1037            dtype: vb.dtype(),
1038            config: config.clone(),
1039        })
1040    }
1041
1042    fn inputs_merger(
1043        &self,
1044        input_ids: &Tensor,
1045        input_embeds: &Tensor,
1046        image_hidden_states: &Tensor,
1047    ) -> Result<Tensor> {
1048        // Docs copied from Transformers impl
1049        /*
1050        This method aims at merging the token embeddings with the image hidden states into one single sequence of vectors that are fed to the transformer LM.
1051        The merging happens as follows:
1052        - The text token sequence is: `tok_1 tok_2 tok_3 <fake_token_around_image> <image> <image> ... <image> <fake_token_around_image> tok_4`.
1053        - We get the image hidden states for the image through the vision encoder (and potentially the perceiver), and that hidden state is then projected into the text embedding space.
1054        We thus have a sequence of image hidden states of size (1, image_seq_len, hidden_dim), where 1 is for batch_size of 1 image and hidden_dim is the hidden_dim of the LM transformer.
1055        - The merging happens so that we obtain the following sequence: `vector_tok_1 vector_tok_2 vector_tok_3 vector_fake_tok_around_image {sequence of image_seq_len image hidden states} vector_fake_toke_around_image vector_tok_4`. That sequence is fed to the LM.
1056        - To fit the format of that sequence, `input_ids`, `input_embeds`, `attention_mask` are all 3 adapted to insert the image hidden states.
1057        */
1058        let (_, _, vision_hidden_size) = image_hidden_states.dims3()?;
1059        let bs = input_ids.dim(0)?;
1060        let special_image_token_mask = input_ids.eq(self.config.image_token_id as f64)?;
1061        let mut new_inputs_embeds = input_embeds.clone();
1062        let reshaped_image_hidden_states =
1063            image_hidden_states.reshape((bs, (), vision_hidden_size))?;
1064        assert_eq!(input_embeds.dim(0)?, 1);
1065        assert_eq!(reshaped_image_hidden_states.dim(0)?, 1);
1066        let special_image_token_mask = special_image_token_mask.i(0)?.to_vec1::<u8>()?;
1067        let mut image_hidden_state_i = 0;
1068        for (i, v) in special_image_token_mask.iter().enumerate() {
1069            if *v != 0 {
1070                new_inputs_embeds = new_inputs_embeds.slice_assign(
1071                    &[&.., &i, &..],
1072                    &reshaped_image_hidden_states
1073                        .i((.., image_hidden_state_i, ..))?
1074                        .unsqueeze(1)?,
1075                )?;
1076                image_hidden_state_i += 1;
1077            }
1078        }
1079        Ok(new_inputs_embeds)
1080    }
1081
1082    #[allow(clippy::too_many_arguments)]
1083    fn forward_inner(
1084        &self,
1085        input_ids: &Tensor,
1086        pixel_values: Option<Tensor>,
1087        seqlen_offsets: &[usize],
1088        context_lens: Vec<(usize, usize)>,
1089        pixel_attention_mask: Option<Tensor>,
1090        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
1091        flash_params: &FlashParams,
1092    ) -> Result<Tensor> {
1093        let input_embeds = if let Some(pixel_values) = pixel_values {
1094            // == START VISUAL INPUTS INTEGRATION ==
1095            let (batch_size, num_images, _, _, _) = pixel_values.dims5()?;
1096            let mut s = vec![batch_size * num_images];
1097            s.extend(pixel_values.dims()[2..].to_vec());
1098            let pixel_values = pixel_values.reshape(s)?;
1099
1100            // Remove padding images which are full of 0s
1101            let nb_values_per_image = pixel_values.dims()[1..].iter().product::<usize>();
1102            let real_images_inds = pixel_values
1103                .eq(0.0f64)?
1104                .sum(vec![
1105                    pixel_values.dims().len() - 1,
1106                    pixel_values.dims().len() - 2,
1107                    pixel_values.dims().len() - 3,
1108                ])?
1109                .ne(nb_values_per_image as f64)?;
1110            let mut batches = Vec::new();
1111            for (batch, use_it) in pixel_values
1112                .chunk(pixel_values.dim(0)?, 0)?
1113                .iter()
1114                .zip(real_images_inds.chunk(real_images_inds.dim(0)?, 0)?)
1115            {
1116                let use_it = use_it.squeeze(0)?.to_scalar::<u8>()? != 0;
1117                if use_it {
1118                    batches.push(batch.clone());
1119                }
1120            }
1121            let pixel_values = Tensor::cat(&batches, 0)?;
1122
1123            // Vision attention mask
1124            let pixel_attention_mask = if let Some(pixel_attention_mask) = pixel_attention_mask {
1125                let pixel_attention_mask = pixel_attention_mask.reshape((
1126                    batch_size * num_images,
1127                    pixel_attention_mask.dims()[2],
1128                    pixel_attention_mask.dims()[3],
1129                ))?;
1130                let mut batches = Vec::new();
1131                for (batch, use_it) in pixel_attention_mask
1132                    .chunk(pixel_attention_mask.dim(0)?, 0)?
1133                    .iter()
1134                    .zip(real_images_inds.chunk(real_images_inds.dim(0)?, 0)?)
1135                {
1136                    let use_it = use_it.squeeze(0)?.to_scalar::<u8>()? != 0;
1137                    if use_it {
1138                        batches.push(batch.clone());
1139                    }
1140                }
1141                Tensor::cat(&batches, 0)?
1142            } else {
1143                Tensor::ones(
1144                    (
1145                        pixel_values.dims()[0],
1146                        pixel_values.dims()[2],
1147                        pixel_values.dims()[3],
1148                    ),
1149                    DType::U8,
1150                    pixel_values.device(),
1151                )?
1152            };
1153
1154            let patch_size = self.config.vision_config.patch_size;
1155            let patches_subgrid = pixel_attention_mask.unfold(1, patch_size, patch_size)?;
1156            let patches_subgrid = patches_subgrid.unfold(2, patch_size, patch_size)?;
1157
1158            let patch_attention_mask = patches_subgrid
1159                .sum((D::Minus1, D::Minus2))?
1160                .eq((patch_size * patch_size) as f64)?
1161                .to_dtype(DType::U8)?;
1162
1163            let pixel_values = pixel_values.to_dtype(self.dtype)?;
1164
1165            // Get seq from vision encoder
1166            let image_hidden_states = self
1167                .vision_model
1168                .forward(&pixel_values, Some(&patch_attention_mask))?;
1169
1170            // Modality proj and perceiver resampling
1171            let image_hidden_states = self.connector.forward(
1172                &image_hidden_states,
1173                &patch_attention_mask.reshape((pixel_values.dim(0)?, ()))?,
1174            )?;
1175
1176            if self.text_model.cache.normal().0[0].current_seq_len() == 0 {
1177                self.inputs_merger(
1178                    input_ids,
1179                    &self.text_model.get_input_embeddings(input_ids)?,
1180                    &image_hidden_states,
1181                )?
1182            } else {
1183                candle_core::bail!("Pixel values were specified for a non-prompt.")
1184            }
1185        } else {
1186            self.text_model.get_input_embeddings(input_ids)?
1187        };
1188
1189        self.text_model.forward_embeds(
1190            input_ids,
1191            input_embeds,
1192            seqlen_offsets,
1193            context_lens,
1194            metadata,
1195            flash_params,
1196        )
1197    }
1198}
1199
1200impl IsqModel for Idefics2 {
1201    fn get_layers(
1202        &mut self,
1203    ) -> (
1204        Vec<(
1205            &mut std::sync::Arc<dyn mistralrs_quant::QuantMethod>,
1206            Option<usize>,
1207        )>,
1208        &dyn DeviceMapper,
1209    ) {
1210        self.text_model.get_layers()
1211    }
1212
1213    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
1214        let uvb = UnVarBuilder::new();
1215
1216        let uvb_m = uvb.pp("model");
1217        uvb_m
1218            .pp("text_model")
1219            .extend(self.text_model.residual_tensors());
1220        uvb_m
1221            .pp("vision_model")
1222            .extend(self.vision_model.residual_tensors());
1223        uvb_m
1224            .pp("connector")
1225            .extend(self.connector.residual_tensors());
1226
1227        uvb.to_safetensors()
1228    }
1229}
1230
1231// AnyMoE is forwarded to the base model
1232impl AnyMoeBaseModelMixin for Idefics2 {
1233    fn get_mlps(&self) -> Vec<&dyn MlpLayer> {
1234        self.text_model.get_mlps()
1235    }
1236    fn get_mlps_mut(&mut self) -> Vec<&mut Box<dyn MlpLayer>> {
1237        self.text_model.get_mlps_mut()
1238    }
1239    fn create_anymoe_layers(
1240        &mut self,
1241        additional_vbs: Vec<ShardedVarBuilder>,
1242        config: AnyMoeConfig,
1243        (prefix, mlp): (String, String),
1244        layers: Vec<usize>,
1245        expert_type: AnyMoeExpertType,
1246        gate_vb: Option<ShardedVarBuilder>,
1247    ) -> Result<()> {
1248        self.text_model.create_anymoe_layers(
1249            additional_vbs,
1250            config,
1251            (prefix, mlp),
1252            layers,
1253            expert_type,
1254            gate_vb,
1255        )
1256    }
1257    fn amoe_supported(&self) -> bool {
1258        true
1259    }
1260}
1261
1262impl VisionModel for Idefics2 {
1263    fn forward(
1264        &self,
1265        input_ids: &Tensor,
1266        pixel_values: Option<Tensor>,
1267        seqlen_offsets: &[usize],
1268        context_lens: Vec<(usize, usize)>,
1269        _: Vec<usize>, // Ignore, it is for phi3
1270        model_specific_args: Box<dyn Any>,
1271        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
1272        flash_params: &FlashParams,
1273    ) -> candle_core::Result<Tensor> {
1274        let pixel_attention_mask: Option<Tensor> = *model_specific_args
1275            .downcast()
1276            .expect("Cannot downcast into `Option<Tensor>`");
1277        self.forward_inner(
1278            input_ids,
1279            pixel_values,
1280            seqlen_offsets,
1281            context_lens,
1282            pixel_attention_mask,
1283            metadata,
1284            flash_params,
1285        )
1286    }
1287    fn cache(&self) -> &EitherCache {
1288        self.text_model.cache()
1289    }
1290    fn cache_mut(&mut self) -> &mut EitherCache {
1291        self.text_model.cache_mut()
1292    }
1293    fn device(&self) -> &Device {
1294        self.text_model.device()
1295    }
1296    fn max_seq_len(&self) -> usize {
1297        self.text_model.max_seq_len()
1298    }
1299    fn has_conv2d(&self) -> bool {
1300        true
1301    }
1302    fn config(&self) -> &ModelConfigMetadata {
1303        self.text_model.config()
1304    }
1305    fn default_model_specific_args(&self, _input_ids: &Tensor) -> Box<dyn Any> {
1306        let args: Option<Tensor> = None;
1307        Box::new(args)
1308    }
1309}