mistralrs_core/vision_models/idefics2/mod.rs

1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3pub(crate) mod idefics2_input_processor;
4
5use candle_core::{DType, Device, IndexOp, Result, Tensor, D};
6use candle_nn::{Conv2d, Conv2dConfig, Embedding, LayerNorm, Module};
7use mistralrs_quant::ShardedVarBuilder;
8use serde::Deserialize;
9use std::{any::Any, ops::Mul};
10
11use crate::{
12    amoe::{AnyMoeBaseModelMixin, MlpLayer},
13    device_map::DeviceMapper,
14    layers::{
15        conv2d, embedding, layer_norm, linear, linear_no_bias, repeat_kv, Activation, CausalMasker,
16        MatMul, QLinear, RmsNorm,
17    },
18    models::mistral::Model as Mistral,
19    paged_attention::{AttentionImplementation, ModelConfigMetadata},
20    pipeline::{
21        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
22        EitherCache, IsqModel, NormalLoadingMetadata, NormalModel, VisionModel,
23    },
24    utils::unvarbuilder::UnVarBuilder,
25    AnyMoeConfig, AnyMoeExpertType,
26};
27
28use crate::models::mistral;
29
30// https://github.com/huggingface/transformers/blob/main/src/transformers/models/idefics2/modeling_idefics2.py
31
32fn default_32000() -> usize {
33    32000
34}
35fn default_32001() -> usize {
36    32001
37}
38fn default_4096() -> usize {
39    4096
40}
41fn default_14336() -> usize {
42    14336
43}
44fn default_32() -> usize {
45    32
46}
47fn default_8() -> usize {
48    8
49}
50fn default_act() -> Activation {
51    Activation::Silu
52}
53fn default_131072() -> usize {
54    131072
55}
56fn default_eps() -> f64 {
57    1e-6
58}
59fn default_rope() -> f64 {
60    10000.0
61}
62fn default_false() -> bool {
63    false
64}
65fn default_sliding() -> Option<usize> {
66    Some(4096)
67}
68fn default_gelu() -> Activation {
69    Activation::GeluPytorchTanh
70}
71fn default_64() -> usize {
72    64
73}
74fn default_3() -> usize {
75    3
76}
77fn default_16() -> usize {
78    16
79}
80fn default_96() -> usize {
81    96
82}
83fn default_4() -> usize {
84    4
85}
86fn default_0_0() -> f32 {
87    0.0
88}
89fn default_0_02() -> f32 {
90    0.02
91}
92fn default_768() -> usize {
93    768
94}
95fn default_3072() -> usize {
96    3072
97}
98fn default_12() -> usize {
99    12
100}
101fn default_224() -> usize {
102    224
103}
104
/// Configuration of the perceiver resampler that compresses the vision
/// encoder's patch sequence into a fixed number of latent tokens.
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct PerceiverConfig {
    /// Activation used in the resampler MLPs (default: SiLU).
    #[serde(default = "default_act")]
    pub hidden_act: Activation,
    /// Number of learned latent vectors the resampler outputs (default: 64).
    #[serde(default = "default_64")]
    pub resampler_n_latents: usize,
    /// Number of stacked perceiver layers (default: 3).
    #[serde(default = "default_3")]
    pub resampler_depth: usize,
    /// Attention heads per perceiver layer (default: 16).
    #[serde(default = "default_16")]
    pub resampler_n_heads: usize,
    /// Per-head dimension (default: 96).
    #[serde(default = "default_96")]
    pub resampler_head_dim: usize,
    /// Key/value heads for grouped-query attention (default: 4).
    #[serde(default = "default_4")]
    pub num_key_value_heads: usize,
    /// Attention dropout probability (default: 0.0); not referenced by any
    /// forward pass in this file.
    #[serde(default = "default_0_0")]
    pub attention_dropout: f32,
}
122
123#[derive(Debug, Clone, PartialEq, Deserialize)]
124pub struct VisionConfig {
125    #[serde(default = "default_768")]
126    pub hidden_size: usize,
127    #[serde(default = "default_3072")]
128    pub intermediate_size: usize,
129    #[serde(default = "default_12")]
130    pub num_hidden_layers: usize,
131    #[serde(default = "default_12")]
132    pub num_attention_heads: usize,
133    #[serde(default = "default_3")]
134    pub num_channels: usize,
135    #[serde(default = "default_224")]
136    pub image_size: usize,
137    #[serde(default = "default_32")]
138    pub patch_size: usize,
139    #[serde(default = "default_gelu")]
140    pub hidden_act: Activation,
141    #[serde(default = "default_eps")]
142    pub layer_norm_eps: f64,
143    #[serde(default = "default_0_0")]
144    pub attn_dropout: f32,
145    #[serde(default = "default_0_02")]
146    pub initiailizer_range: f32,
147}
148
/// Configuration of the wrapped Mistral language model; lowered onto
/// `mistral::Config` via the `From` impl in this file.
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub(crate) struct TextConfig {
    /// Vocabulary size (default: 32000).
    #[serde(default = "default_32000")]
    pub(crate) vocab_size: usize,
    /// Hidden width of the text model (default: 4096).
    #[serde(default = "default_4096")]
    pub(crate) hidden_size: usize,
    /// MLP hidden width (default: 14336).
    #[serde(default = "default_14336")]
    pub(crate) intermediate_size: usize,
    /// Number of decoder layers (default: 32).
    #[serde(default = "default_32")]
    pub(crate) num_hidden_layers: usize,
    /// Attention heads per layer (default: 32).
    #[serde(default = "default_32")]
    pub(crate) num_attention_heads: usize,
    /// Key/value heads for grouped-query attention (default: 8).
    #[serde(default = "default_8")]
    pub(crate) num_key_value_heads: usize,
    /// MLP activation (default: SiLU).
    #[serde(default = "default_act")]
    pub(crate) hidden_act: Activation,
    /// Maximum supported sequence length (default: 131072).
    #[serde(default = "default_131072")]
    pub(crate) max_position_embeddings: usize,
    /// Epsilon for the RMSNorm layers (default: 1e-6).
    #[serde(default = "default_eps")]
    pub(crate) rms_norm_eps: f64,
    /// RoPE base frequency (default: 10000.0).
    #[serde(default = "default_rope")]
    pub(crate) rope_theta: f64,
    /// Sliding attention window size (default: Some(4096)).
    #[serde(default = "default_sliding")]
    pub(crate) sliding_window: Option<usize>,

    /// Whether to use flash attention (default: false).
    #[serde(default = "default_false")]
    pub(crate) use_flash_attn: bool,
    model_type: String, // Must be mistral for now
}
178
179impl From<TextConfig> for mistral::Config {
180    fn from(val: TextConfig) -> Self {
181        mistral::Config {
182            vocab_size: val.vocab_size,
183            hidden_act: val.hidden_act,
184            hidden_size: val.hidden_size,
185            intermediate_size: val.intermediate_size,
186            num_hidden_layers: val.num_hidden_layers,
187            num_attention_heads: val.num_attention_heads,
188            num_key_value_heads: val.num_key_value_heads,
189            max_position_embeddings: val.max_position_embeddings,
190            rms_norm_eps: val.rms_norm_eps,
191            rope_theta: val.rope_theta,
192            sliding_window: val.sliding_window,
193            use_flash_attn: val.use_flash_attn,
194            head_dim: None,
195            quantization_config: None,
196            tie_word_embeddings: false,
197        }
198    }
199}
200
/// Top-level Idefics2 configuration tying together the vision encoder,
/// perceiver resampler, and text model configs.
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub(crate) struct Config {
    pub perceiver_config: PerceiverConfig,
    pub vision_config: VisionConfig,
    pub(crate) text_config: TextConfig,
    /// Token id marking image positions in the input ids (default: 32001).
    #[serde(default = "default_32001")]
    pub image_token_id: usize,
    /// Whether input and output embeddings share weights (default: false).
    #[serde(default = "default_false")]
    pub tie_word_embeddings: bool,
}
211
212// == START VISION MODEL ==
213
/// Patch + position embedding front-end of the vision transformer.
struct VisionEmbeddings {
    // Side length (pixels) of one square patch.
    patch_size: usize,
    // Conv2d with stride == patch_size: one output vector per patch.
    patch_embedding: Conv2d,
    // image_size / patch_size; the position grid has this many patches per axis.
    num_patches_per_side: usize,
    // Learned embedding over the flattened patch-position grid.
    position_embedding: Embedding,
}
220
221/// torch.bucketize with right=True
222/// Returns a 1d tensor of shape (xs.len(),) on the CPU
223fn bucketize_right(xs: &[f32], boundaries: &[f32], device: &Device) -> Result<Tensor> {
224    use std::cmp::Ordering;
225
226    let mut result = Vec::with_capacity(xs.len());
227
228    for &x in xs {
229        // binary_search_by returns:
230        //   Ok(i)   if boundaries[i] == x
231        //   Err(i)  if x would be inserted at i
232        //
233        // The returned i is the "insertion point" for x to keep
234        // boundaries sorted. That i is the smallest position
235        // where boundaries[i] >= x (i.e. bisect_left).
236
237        let idx = match boundaries.binary_search_by(|&val| {
238            // Use partial_cmp here; assume no NaNs.
239            // For robust handling of NaNs, you might need a custom comparison.
240            val.partial_cmp(&x).unwrap_or(Ordering::Less)
241        }) {
242            Ok(i) => i,
243            Err(i) => i,
244        };
245
246        result.push(idx as u32);
247    }
248
249    Tensor::from_vec(result, (xs.len(),), device)
250}
251
impl VisionEmbeddings {
    /// Builds the patch and position embedding layers.
    ///
    /// The conv stride equals the patch size, so `patch_embedding` maps each
    /// non-overlapping patch to one `hidden_size` vector; positions are
    /// indexed on a `num_patches_per_side` x `num_patches_per_side` grid.
    fn new(config: &VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        let conv_config = Conv2dConfig {
            stride: config.patch_size,
            ..Default::default()
        };
        let patch_embedding = conv2d(
            config.num_channels,
            config.hidden_size,
            config.patch_size,
            conv_config,
            vb.pp("patch_embedding"),
        )?;
        let num_patches_per_side = config.image_size / config.patch_size;
        let num_patches = num_patches_per_side.pow(2);
        Ok(Self {
            patch_size: config.patch_size,
            patch_embedding,
            num_patches_per_side,
            position_embedding: embedding(
                num_patches,
                config.hidden_size,
                vb.pp("position_embedding"),
            )?,
        })
    }

    /// Embeds `pixel_values` (bs, channels, h, w) into a patch sequence and
    /// adds position embeddings.
    ///
    /// `patch_attention_mask` marks valid patches per image — assumed shape
    /// (bs, h/patch, w/patch); TODO confirm against caller. For each image,
    /// fractional coordinates of its valid patches are bucketized onto the
    /// full `num_patches_per_side` grid so partially-filled (padded) images
    /// reuse the learned positions of the corresponding full-grid cells.
    fn forward(&self, pixel_values: &Tensor, patch_attention_mask: &Tensor) -> Result<Tensor> {
        let (bs, _, max_im_h, max_im_w) = pixel_values.dims4()?;

        let patch_embeds = self.patch_embedding.forward(pixel_values)?;

        // (bs, hidden, h', w') -> (bs, h'*w', hidden)
        let embeddings = patch_embeds.flatten(2, D::Minus1)?.transpose(1, 2)?;

        let (max_nb_patches_h, max_nb_patches_w) =
            (max_im_h / self.patch_size, max_im_w / self.patch_size);
        // Interior bucket boundaries: 1/N, 2/N, ..., (N-1)/N where
        // N = num_patches_per_side.
        let boundaries = Tensor::arange_step(
            1.0 / self.num_patches_per_side as f32,
            1.0,
            1.0 / self.num_patches_per_side as f32,
            pixel_values.device(),
        )?
        .to_vec1::<f32>()?;
        // Position ids start at 0 everywhere; only valid patches are
        // overwritten below.
        let position_ids = Tensor::full(
            0u32,
            (bs, max_nb_patches_h * max_nb_patches_w),
            pixel_values.device(),
        )?;

        let mut new_position_ids = Vec::new();
        for (b_idx, p_attn_mask) in patch_attention_mask.chunk(bs, 0)?.iter().enumerate() {
            let p_attn_mask = p_attn_mask.squeeze(0)?;
            // Count valid patches along each axis from the first column/row
            // of the mask.
            let nb_patches_h = p_attn_mask.i((.., 0))?.sum_all()?;
            let nb_patches_w = p_attn_mask.i((0,))?.sum_all()?;

            // Fractional coordinates in [0, 1) of this image's valid patches.
            let fractional_coords_h = Tensor::arange_step(
                0.0,
                1.0 - 1e-6,
                1.0 / nb_patches_h.to_dtype(DType::F32)?.to_scalar::<f32>()?,
                pixel_values.device(),
            )?
            .to_vec1::<f32>()?;
            let fractional_coords_w = Tensor::arange_step(
                0.0,
                1.0 - 1e-6,
                1.0 / nb_patches_w.to_dtype(DType::F32)?.to_scalar::<f32>()?,
                pixel_values.device(),
            )?
            .to_vec1::<f32>()?;

            // Map fractional coordinates to full-grid row/column indices.
            let bucket_coords_h =
                bucketize_right(&fractional_coords_h, &boundaries, pixel_values.device())?;
            let bucket_coords_w =
                bucketize_right(&fractional_coords_w, &boundaries, pixel_values.device())?;

            // Flattened grid position: row * N + col.
            let pos_ids = bucket_coords_h
                .unsqueeze(D::Minus1)?
                .mul(self.num_patches_per_side as f64)?
                .broadcast_add(&bucket_coords_w)?
                .flatten_all()?
                .to_vec1::<u32>()?;

            // Indices of valid patches in the flattened mask.
            let true_indices = p_attn_mask
                .flatten_all()?
                .to_vec1::<u8>()?
                .iter()
                .enumerate()
                .filter_map(|(i, x)| if *x != 0 { Some(i) } else { None })
                .collect::<Vec<_>>();
            let position_ids_b = position_ids.i(b_idx)?;

            // Scatter the computed position ids into the valid slots;
            // invalid (padded) slots keep position id 0.
            let mut new_position_ids_b = position_ids_b.to_vec1::<u32>()?;
            let new_position_ids_b_len = new_position_ids_b.len();
            for (i, true_idx) in true_indices.into_iter().enumerate() {
                new_position_ids_b[true_idx] = pos_ids[i];
            }

            new_position_ids.push(Tensor::from_vec(
                new_position_ids_b,
                new_position_ids_b_len,
                pixel_values.device(),
            )?);
        }
        let position_ids = Tensor::stack(&new_position_ids, 0)?;
        let position_ids = position_ids.to_device(self.position_embedding.embeddings().device())?;
        embeddings.broadcast_add(&self.position_embedding.forward(&position_ids)?)
    }

    /// Exports non-quantized weights under the checkpoint layout.
    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("patch_embedding").add(&self.patch_embedding);
        uvb.pp("position_embedding").add(&self.position_embedding);

        uvb.to_safetensors()
    }
}
369
/// Multi-head self-attention used inside the vision encoder layers.
struct Attention {
    embed_dim: usize,
    num_heads: usize,
    // embed_dim / num_heads.
    head_dim: usize,
    // 1 / sqrt(head_dim), applied to the q·kᵀ logits.
    scale: f64,
    q_proj: QLinear,
    k_proj: QLinear,
    v_proj: QLinear,
    o_proj: QLinear,
    // Negative infinity in the model dtype, used to mask attention logits.
    neg_inf: Tensor,
}
381
impl Attention {
    /// Loads the four square projections; key/value heads equal query heads
    /// (no grouped-query attention in the vision tower).
    fn new(config: VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        let embed_dim = config.hidden_size;
        let num_heads = config.num_attention_heads;
        let head_dim = embed_dim / num_heads;
        let scale = 1.0 / (head_dim as f64).sqrt();

        let q_proj = linear(embed_dim, embed_dim, vb.pp("q_proj"))?;
        let k_proj = linear(embed_dim, embed_dim, vb.pp("k_proj"))?;
        let v_proj = linear(embed_dim, embed_dim, vb.pp("v_proj"))?;
        // Checkpoint name is `out_proj`; stored locally as `o_proj`.
        let o_proj = linear(embed_dim, embed_dim, vb.pp("out_proj"))?;

        Ok(Self {
            embed_dim,
            num_heads,
            head_dim,
            scale,
            q_proj: QLinear::from_linear(q_proj),
            k_proj: QLinear::from_linear(k_proj),
            v_proj: QLinear::from_linear(v_proj),
            o_proj: QLinear::from_linear(o_proj),
            neg_inf: Tensor::new(f32::NEG_INFINITY, vb.device())?.to_dtype(vb.dtype())?,
        })
    }

    /// Standard scaled-dot-product attention over `xs` (bs, seq, embed_dim).
    ///
    /// When the projections are quantized, inputs are cast to f32 around each
    /// quantized matmul and results cast back to the original dtype.
    fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if self.q_proj.is_quant() {
            xs = xs.to_dtype(DType::F32)?;
        }
        let mut q = self.q_proj.forward(&xs)?;
        let mut k = self.k_proj.forward(&xs)?;
        let mut v = self.v_proj.forward(&xs)?;
        if self.q_proj.is_quant() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

        // (bs, seq, embed) -> (bs, heads, seq, head_dim)
        let q = q
            .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;
        let k = k
            .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;
        let v = v
            .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;

        let attn_weights =
            (MatMul.matmul(&q.contiguous()?, &k.transpose(2, 3)?.contiguous()?)? * self.scale)?;

        // Mask is a one/zero validity mask; masked logits are set to -inf
        // before the softmax.
        let attn_weights = CausalMasker.apply_mask_one_and_zero(
            &attention_mask.map(|x| x.to_dtype(DType::U8).unwrap()),
            attn_weights,
            &self.neg_inf,
        )?;
        let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
        let mut attn_output = MatMul.matmul(&attn_weights, &v.contiguous()?)?;

        if self.q_proj.is_quant() {
            attn_output = attn_output.to_dtype(DType::F32)?;
        }
        let mut res = attn_output
            .transpose(1, 2)?
            .reshape((b_sz, q_len, self.embed_dim))?
            .apply(&self.o_proj)?;
        if self.q_proj.is_quant() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }

    /// Exports non-quantized weights under the checkpoint layout.
    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("q_proj").add(&self.q_proj);
        uvb.pp("k_proj").add(&self.k_proj);
        uvb.pp("v_proj").add(&self.v_proj);
        uvb.pp("out_proj").add(&self.o_proj);

        uvb.to_safetensors()
    }
}
469
/// Two-layer feed-forward block (`fc1` -> activation -> `fc2`) of a vision
/// encoder layer.
struct VisionMLP {
    activation: Activation,
    fc1: QLinear,
    fc2: QLinear,
}
475
476impl VisionMLP {
477    fn new(config: VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
478        let fc1 = linear(config.hidden_size, config.intermediate_size, vb.pp("fc1"))?;
479        let fc2 = linear(config.intermediate_size, config.hidden_size, vb.pp("fc2"))?;
480        Ok(Self {
481            activation: config.hidden_act,
482            fc1: QLinear::from_linear(fc1),
483            fc2: QLinear::from_linear(fc2),
484        })
485    }
486
487    fn forward(&self, x: &Tensor) -> Result<Tensor> {
488        let mut x = x.clone();
489        let original_dtype = x.dtype();
490        if self.fc1.is_quant() {
491            x = x.to_dtype(DType::F32)?;
492        }
493        let x = self.fc1.forward(&x)?;
494        let x = self.activation.forward(&x)?;
495        let mut res = self.fc2.forward(&x)?;
496        if self.fc1.is_quant() {
497            res = res.to_dtype(original_dtype)?;
498        }
499        Ok(res)
500    }
501
502    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
503        let uvb = UnVarBuilder::new();
504
505        uvb.pp("fc1").add(&self.fc1);
506        uvb.pp("fc2").add(&self.fc2);
507
508        uvb.to_safetensors()
509    }
510}
511
/// One pre-norm transformer encoder block (attention + MLP with residuals).
struct EncoderLayer {
    mlp: VisionMLP,
    attn: Attention,
    layer_norm_1: LayerNorm,
    layer_norm_2: LayerNorm,
}
518
519impl EncoderLayer {
520    fn new(config: VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
521        let mlp = VisionMLP::new(config.clone(), vb.pp("mlp"))?;
522        let attn = Attention::new(config.clone(), vb.pp("self_attn"))?;
523        let layer_norm_1 = layer_norm(
524            config.hidden_size,
525            config.layer_norm_eps,
526            vb.pp("layer_norm1"),
527        )?;
528        let layer_norm_2 = layer_norm(
529            config.hidden_size,
530            config.layer_norm_eps,
531            vb.pp("layer_norm2"),
532        )?;
533        Ok(Self {
534            mlp,
535            attn,
536            layer_norm_1,
537            layer_norm_2,
538        })
539    }
540
541    fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
542        let residual = xs.clone();
543
544        let hidden_states = self.layer_norm_1.forward(xs)?;
545        let hidden_states = self.attn.forward(&hidden_states, attention_mask)?;
546        let hidden_states = (hidden_states + residual)?;
547
548        let residual = &hidden_states;
549        let hidden_states = self.layer_norm_2.forward(&hidden_states)?;
550        let hidden_states = self.mlp.forward(&hidden_states)?;
551        hidden_states + residual
552    }
553}
554
/// Sequential stack of vision encoder layers.
struct Encoder {
    layers: Vec<EncoderLayer>,
}
558
559impl Encoder {
560    fn new(config: &VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
561        let mut layers = Vec::new();
562        let vb_l = vb.pp("layers");
563        for i in 0..config.num_hidden_layers {
564            layers.push(EncoderLayer::new(config.clone(), vb_l.pp(i))?);
565        }
566        Ok(Self { layers })
567    }
568
569    fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
570        let mut hidden_states = xs.clone();
571        for layer in &self.layers {
572            hidden_states = layer.forward(&hidden_states, attention_mask)?;
573        }
574        Ok(hidden_states)
575    }
576}
577
/// Full vision tower: embeddings -> encoder stack -> final LayerNorm.
struct VisionTransformer {
    embeddings: VisionEmbeddings,
    encoder: Encoder,
    post_layernorm: LayerNorm,
    config: VisionConfig,
}
584
impl VisionTransformer {
    /// Loads embeddings, the encoder stack, and the final `post_layernorm`.
    fn new(config: &VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        let embeddings = VisionEmbeddings::new(config, vb.pp("embeddings"))?;
        let post_layernorm = layer_norm(
            config.hidden_size,
            config.layer_norm_eps,
            vb.pp("post_layernorm"),
        )?;
        let encoder = Encoder::new(config, vb.pp("encoder"))?;
        Ok(Self {
            embeddings,
            encoder,
            post_layernorm,
            config: config.clone(),
        })
    }

    /// Runs the vision tower on `pixel_values` (bs, c, h, w).
    ///
    /// `attention_mask`, when given, is a per-patch validity mask — assumed
    /// shape (bs, h/patch, w/patch); TODO confirm against caller. When
    /// absent, every patch is treated as valid and no mask is passed to the
    /// encoder.
    fn forward(&self, pixel_values: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
        let bs = pixel_values.dim(0)?;
        let patch_attention_mask = if let Some(attn_mask) = attention_mask {
            attn_mask.clone()
        } else {
            // No caller-supplied mask: mark every patch valid.
            let patch_size = self.config.patch_size;
            Tensor::ones(
                (
                    bs,
                    pixel_values.dim(2)? / patch_size,
                    pixel_values.dim(3)? / patch_size,
                ),
                DType::U8,
                pixel_values.device(),
            )?
        };

        let hidden_states = self
            .embeddings
            .forward(pixel_values, &patch_attention_mask)?;

        // Only build an encoder mask when the caller supplied one; the
        // synthesized all-ones mask would be a no-op.
        let attention_mask = if attention_mask.is_none() {
            None
        } else {
            // Flatten the (h', w') patch grid to one sequence axis, then
            // expand to the broadcastable attention-mask shape.
            let mask = patch_attention_mask
                .reshape((patch_attention_mask.dim(0)?, ()))?
                .to_dtype(hidden_states.dtype())?;
            Some(CausalMasker.expand_mask(&mask, hidden_states.dtype(), None)?)
        };
        let hidden_states = self
            .encoder
            .forward(&hidden_states, attention_mask.as_ref())?;
        hidden_states.apply(&self.post_layernorm)
    }

    /// Exports non-quantized weights under the checkpoint layout.
    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("post_layernorm").add(&self.post_layernorm);
        uvb.pp("embeddings")
            .extend(self.embeddings.residual_tensors());

        let uvb_enc = uvb.pp("encoder");
        for (i, layer) in self.encoder.layers.iter().enumerate() {
            let uvb_l = uvb_enc.pp("layers").pp(i);

            uvb_l.pp("layer_norm1").add(&layer.layer_norm_1);
            uvb_l.pp("layer_norm2").add(&layer.layer_norm_2);
            uvb_l.pp("mlp").extend(layer.mlp.residual_tensors());
            uvb_l.pp("self_attn").extend(layer.attn.residual_tensors());
        }

        uvb.to_safetensors()
    }
}
657
658// == END VISION MODEL ==
659
660// == START CONNECTOR ==
/// Gated MLP (gate/up/down projections, no biases) used by the connector's
/// modality projection and the perceiver layers.
struct Mlp {
    gate_proj: QLinear,
    up_proj: QLinear,
    down_proj: QLinear,
    activation: Activation,
}
667
668impl Mlp {
669    fn new(
670        hidden_size: usize,
671        intermediate_size: usize,
672        output_size: usize,
673        activation: Activation,
674        vb: ShardedVarBuilder,
675    ) -> Result<Self> {
676        let gate_proj = linear_no_bias(hidden_size, intermediate_size, vb.pp("gate_proj"))?;
677        let up_proj = linear_no_bias(hidden_size, intermediate_size, vb.pp("up_proj"))?;
678        let down_proj = linear_no_bias(intermediate_size, output_size, vb.pp("down_proj"))?;
679        Ok(Self {
680            gate_proj: QLinear::from_linear(gate_proj),
681            up_proj: QLinear::from_linear(up_proj),
682            down_proj: QLinear::from_linear(down_proj),
683            activation,
684        })
685    }
686
687    fn forward(&self, x: &Tensor) -> Result<Tensor> {
688        let mut x = x.clone();
689        let original_dtype = x.dtype();
690        if self.gate_proj.is_quant() {
691            x = x.to_dtype(DType::F32)?;
692        }
693        let mut res = self.down_proj.forward(
694            &(self.activation.forward(&self.gate_proj.forward(&x)?)?
695                * self.up_proj.forward(&x)?)?,
696        )?;
697        if self.gate_proj.is_quant() {
698            res = res.to_dtype(original_dtype)?;
699        }
700        Ok(res)
701    }
702
703    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
704        let uvb = UnVarBuilder::new();
705
706        uvb.pp("gate_proj").add(&self.gate_proj);
707        uvb.pp("up_proj").add(&self.up_proj);
708        uvb.pp("down_proj").add(&self.down_proj);
709
710        uvb.to_safetensors()
711    }
712}
713
/// Cross-attention where learned latents are the queries and the
/// concatenation of image context and latents supplies keys/values,
/// with grouped-query attention.
struct PerceiverAttention {
    num_heads: usize,
    num_kv_heads: usize,
    // num_heads / num_kv_heads; each KV head is repeated this many times.
    num_kv_groups: usize,
    head_dim: usize,
    q_proj: QLinear,
    k_proj: QLinear,
    v_proj: QLinear,
    o_proj: QLinear,
    // Negative infinity in the model dtype, used to mask attention logits.
    neg_inf: Tensor,
}
725
impl PerceiverAttention {
    /// Loads the GQA projections; head counts/dims come from the perceiver
    /// config while the embedding width comes from the text config.
    fn new(config: &Config, vb: ShardedVarBuilder) -> Result<Self> {
        let hidden_size = config.text_config.hidden_size;
        let num_heads = config.perceiver_config.resampler_n_heads;
        let head_dim = config.perceiver_config.resampler_head_dim;
        let num_key_value_heads = config.perceiver_config.num_key_value_heads;
        let num_key_value_groups = num_heads / num_key_value_heads;

        let q_proj = linear_no_bias(hidden_size, num_heads * head_dim, vb.pp("q_proj"))?;
        let k_proj = linear_no_bias(hidden_size, num_key_value_heads * head_dim, vb.pp("k_proj"))?;
        let v_proj = linear_no_bias(hidden_size, num_key_value_heads * head_dim, vb.pp("v_proj"))?;
        let o_proj = linear_no_bias(num_heads * head_dim, hidden_size, vb.pp("o_proj"))?;

        Ok(Self {
            num_heads,
            head_dim,
            q_proj: QLinear::from_linear(q_proj),
            k_proj: QLinear::from_linear(k_proj),
            v_proj: QLinear::from_linear(v_proj),
            o_proj: QLinear::from_linear(o_proj),
            neg_inf: Tensor::new(f32::NEG_INFINITY, vb.device())?.to_dtype(vb.dtype())?,
            num_kv_heads: num_key_value_heads,
            num_kv_groups: num_key_value_groups,
        })
    }

    /// Cross-attention: `latents` (bs, n_latents, hidden) are the queries;
    /// keys/values come from the concatenation [context; latents], so the
    /// KV sequence length is `context_len + n_latents`.
    ///
    /// When the projections are quantized, inputs are cast to f32 around the
    /// quantized matmuls and results cast back to the original dtype.
    fn forward(
        &self,
        latents: &Tensor,
        context: &Tensor,
        attention_mask: &Tensor,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = latents.dims3()?;
        let kv_seq_len = q_len + context.dims()[1];

        // Latents attend over the image context *and* themselves.
        let mut hidden_states = Tensor::cat(&[context, latents], D::Minus2)?;

        let original_dtype = latents.dtype();
        let mut latents = latents.clone();
        if self.q_proj.is_quant() {
            latents = latents.to_dtype(DType::F32)?;
            hidden_states = hidden_states.to_dtype(DType::F32)?;
        }
        let mut q = self.q_proj.forward(&latents)?;
        let mut k = self.k_proj.forward(&hidden_states)?;
        let mut v = self.v_proj.forward(&hidden_states)?;
        if self.q_proj.is_quant() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

        // (bs, seq, heads*dim) -> (bs, heads, seq, dim)
        let q = q
            .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;
        let k = k
            .reshape((b_sz, kv_seq_len, self.num_kv_heads, self.head_dim))?
            .transpose(1, 2)?;
        let v = v
            .reshape((b_sz, kv_seq_len, self.num_kv_heads, self.head_dim))?
            .transpose(1, 2)?;

        // Expand the KV heads so each query head has a matching KV head.
        let k = repeat_kv(k, self.num_kv_groups)?.contiguous()?;
        let v = repeat_kv(v, self.num_kv_groups)?.contiguous()?;

        let attn_weights = (MatMul.matmul(&q.contiguous()?, &k.transpose(2, 3)?.contiguous()?)?
            / (self.head_dim as f64).sqrt())?;

        // One/zero validity mask; masked logits become -inf before softmax.
        let attn_weights = CausalMasker.apply_mask_one_and_zero(
            &Some(attention_mask.to_dtype(DType::U8)?),
            attn_weights,
            &self.neg_inf,
        )?;
        let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
        let mut attn_output = MatMul.matmul(&attn_weights, &v.contiguous()?)?;

        if self.q_proj.is_quant() {
            attn_output = attn_output.to_dtype(DType::F32)?;
        }
        let mut res = attn_output
            .transpose(1, 2)?
            .reshape((b_sz, q_len, self.num_heads * self.head_dim))?
            .apply(&self.o_proj)?;
        if self.q_proj.is_quant() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }

    /// Exports non-quantized weights under the checkpoint layout.
    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("q_proj").add(&self.q_proj);
        uvb.pp("k_proj").add(&self.k_proj);
        uvb.pp("v_proj").add(&self.v_proj);
        uvb.pp("o_proj").add(&self.o_proj);

        uvb.to_safetensors()
    }
}
826
/// One perceiver resampler block: pre-norm cross-attention followed by a
/// pre-norm gated MLP, each with a residual connection.
struct PerceiverLayer {
    input_latents_norm: RmsNorm,
    input_context_norm: RmsNorm,
    self_attn: PerceiverAttention,
    post_attn_norm: RmsNorm,
    mlp: Mlp,
}
834
835impl PerceiverLayer {
836    fn new(config: &Config, vb: ShardedVarBuilder) -> Result<Self> {
837        let hidden_size = config.text_config.hidden_size;
838        let mlp_act = config.perceiver_config.hidden_act;
839        let rms_eps = config.text_config.rms_norm_eps;
840
841        Ok(Self {
842            input_latents_norm: RmsNorm::new(hidden_size, rms_eps, vb.pp("input_latents_norm"))?,
843            input_context_norm: RmsNorm::new(hidden_size, rms_eps, vb.pp("input_context_norm"))?,
844            self_attn: PerceiverAttention::new(config, vb.pp("self_attn"))?,
845            post_attn_norm: RmsNorm::new(hidden_size, rms_eps, vb.pp("post_attention_layernorm"))?,
846            mlp: Mlp::new(
847                hidden_size,
848                hidden_size * 4,
849                hidden_size,
850                mlp_act,
851                vb.pp("mlp"),
852            )?,
853        })
854    }
855
856    fn forward(
857        &self,
858        latents: &Tensor,
859        context: &Tensor,
860        attention_mask: &Tensor,
861    ) -> Result<Tensor> {
862        let residual = latents;
863
864        let latents = self.input_latents_norm.forward(latents)?;
865        let context = self.input_context_norm.forward(context)?;
866
867        let latents = self.self_attn.forward(&latents, &context, attention_mask)?;
868        let latents = (residual + latents)?;
869        let residual = &latents;
870
871        let latents = self.post_attn_norm.forward(&latents)?;
872        let latents = self.mlp.forward(&latents)?;
873        residual + latents
874    }
875}
876
/// Compresses a variable-length image feature sequence into `n_latents`
/// tokens via stacked perceiver layers.
struct PerceiverResampler {
    // Learned latent queries, shape (n_latents, hidden_size).
    latents: Tensor,
    layers: Vec<PerceiverLayer>,
    norm: RmsNorm,
    n_latents: usize,
}
883
impl PerceiverResampler {
    /// Loads the learned latents, `resampler_depth` perceiver layers, and the
    /// final RMSNorm.
    fn new(config: &Config, vb: ShardedVarBuilder) -> Result<Self> {
        let n_latents = config.perceiver_config.resampler_n_latents;
        let hidden_size = config.text_config.hidden_size;
        let depth = config.perceiver_config.resampler_depth;

        let latents = vb.get((n_latents, hidden_size), "latents")?;
        let mut layers = Vec::new();
        let vb_l = vb.pp("layers");
        for i in 0..depth {
            layers.push(PerceiverLayer::new(config, vb_l.pp(i))?);
        }
        let norm = RmsNorm::new(hidden_size, config.text_config.rms_norm_eps, vb.pp("norm"))?;
        Ok(Self {
            latents,
            layers,
            norm,
            n_latents,
        })
    }

    /// Resamples `context` (bs, seq, hidden) down to (bs, n_latents, hidden).
    ///
    /// `attention_mask` marks valid context positions; it is extended with an
    /// all-ones block so the latents can always attend to themselves (the KV
    /// sequence in `PerceiverAttention` is [context; latents]).
    fn forward(&self, context: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
        // Broadcast the learned latents across the batch dimension.
        let mut s = vec![context.dim(0)?];
        s.extend(self.latents.dims());
        let latents = self.latents.unsqueeze(0)?.expand(s)?;

        let latent_attention_mask = Tensor::ones(
            (attention_mask.dim(0)?, latents.dim(1)?),
            attention_mask.dtype(),
            attention_mask.device(),
        )?;
        let attention_mask = Tensor::cat(&[attention_mask, &latent_attention_mask], D::Minus1)?;
        // Expand to the broadcastable attention-mask shape for n_latents
        // query positions.
        let attention_mask =
            CausalMasker.expand_mask(&attention_mask, latents.dtype(), Some(self.n_latents))?;

        // Each layer refines the latents while attending over the same
        // original context.
        let mut compressed_context = latents;
        for perceiver_layer in &self.layers {
            compressed_context =
                perceiver_layer.forward(&compressed_context, context, &attention_mask)?;
        }
        self.norm.forward(&compressed_context)
    }

    /// Exports non-quantized weights under the checkpoint layout.
    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("norm").add(&self.norm);
        uvb.add_tensor("latents", self.latents.clone());

        for (i, layer) in self.layers.iter().enumerate() {
            let uvb_l = uvb.pp("layers").pp(i);

            uvb_l
                .pp("input_latents_norm")
                .add(&layer.input_latents_norm);
            uvb_l
                .pp("input_context_norm")
                .add(&layer.input_context_norm);
            uvb_l
                .pp("post_attention_layernorm")
                .add(&layer.post_attn_norm);
            uvb_l.pp("mlp").extend(layer.mlp.residual_tensors());
            uvb_l
                .pp("self_attn")
                .extend(layer.self_attn.residual_tensors());
        }

        uvb.to_safetensors()
    }
}
954
/// Bridges the vision encoder and the language model: projects vision
/// features into the text hidden size, then resamples them to a fixed length.
struct Connector {
    // MLP mapping vision hidden size -> text hidden size (see `new`).
    modality_projection: Mlp,
    perceiver_resampler: PerceiverResampler,
}
959
960impl Connector {
961    fn new(config: &Config, vb: ShardedVarBuilder) -> Result<Self> {
962        let modality_projection = Mlp::new(
963            config.vision_config.hidden_size,
964            config.text_config.intermediate_size,
965            config.text_config.hidden_size,
966            config.text_config.hidden_act,
967            vb.pp("modality_projection"),
968        )?;
969        let perceiver_resampler = PerceiverResampler::new(config, vb.pp("perceiver_resampler"))?;
970        Ok(Self {
971            modality_projection,
972            perceiver_resampler,
973        })
974    }
975
976    fn forward(&self, image_hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
977        let image_hidden_states = self.modality_projection.forward(image_hidden_states)?;
978        self.perceiver_resampler
979            .forward(&image_hidden_states, attention_mask)
980    }
981
982    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
983        let uvb = UnVarBuilder::new();
984
985        uvb.pp("modality_projection")
986            .extend(self.modality_projection.residual_tensors());
987        uvb.pp("perceiver_resampler")
988            .extend(self.perceiver_resampler.residual_tensors());
989
990        uvb.to_safetensors()
991    }
992}
993
994// == END CONNECTOR ==
995
996// == START MODEL ==
997
/// Idefics2 vision-language model: a SigLIP-style vision transformer, a
/// perceiver-based connector, and a Mistral language model.
pub struct Idefics2 {
    vision_model: VisionTransformer,
    connector: Connector,
    text_model: Mistral,
    // Parameter dtype from the var builder; pixel values are cast to this
    // before the vision encoder.
    dtype: DType,
    config: Config,
}
1005
impl Idefics2 {
    /// Loads the full model from `vb`. The text model is loaded first; the
    /// vision model and connector are then forced onto the text model's device.
    pub fn new(
        config: &Config,
        vb: ShardedVarBuilder,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        let vb_m = vb.pp("model");
        let text_model = Mistral::new_inner(
            &config.text_config.clone().into(),
            vb_m.pp("text_model"),
            // lm_head lives outside the "model" prefix in the checkpoint.
            vb.pp("lm_head"),
            is_gptx,
            normal_loading_metadata,
            attention_mechanism,
        )?;
        let vision_model = VisionTransformer::new(
            &config.vision_config,
            vb_m.pp("vision_model")
                .set_device(text_model.device().clone()),
        )?;
        let connector = Connector::new(
            config,
            vb_m.pp("connector").set_device(text_model.device().clone()),
        )?;
        Ok(Self {
            vision_model,
            connector,
            text_model,
            dtype: vb.dtype(),
            config: config.clone(),
        })
    }

    /// Splices the image hidden states into the token embedding sequence at
    /// the positions of the `<image>` placeholder tokens. Only batch size 1 is
    /// supported (asserted below).
    fn inputs_merger(
        &self,
        input_ids: &Tensor,
        input_embeds: &Tensor,
        image_hidden_states: &Tensor,
    ) -> Result<Tensor> {
        // Docs copied from Transformers impl
        /*
        This method aims at merging the token embeddings with the image hidden states into one single sequence of vectors that are fed to the transformer LM.
        The merging happens as follows:
        - The text token sequence is: `tok_1 tok_2 tok_3 <fake_token_around_image> <image> <image> ... <image> <fake_token_around_image> tok_4`.
        - We get the image hidden states for the image through the vision encoder (and potentially the perceiver), and that hidden state is then projected into the text embedding space.
        We thus have a sequence of image hidden states of size (1, image_seq_len, hidden_dim), where 1 is for batch_size of 1 image and hidden_dim is the hidden_dim of the LM transformer.
        - The merging happens so that we obtain the following sequence: `vector_tok_1 vector_tok_2 vector_tok_3 vector_fake_tok_around_image {sequence of image_seq_len image hidden states} vector_fake_toke_around_image vector_tok_4`. That sequence is fed to the LM.
        - To fit the format of that sequence, `input_ids`, `input_embeds`, `attention_mask` are all 3 adapted to insert the image hidden states.
        */
        let (_, _, vision_hidden_size) = image_hidden_states.dims3()?;
        let bs = input_ids.dim(0)?;
        // Marks every position holding the `<image>` placeholder token.
        let special_image_token_mask = input_ids.eq(self.config.image_token_id as f64)?;
        let mut new_inputs_embeds = input_embeds.clone();
        let reshaped_image_hidden_states =
            image_hidden_states.reshape((bs, (), vision_hidden_size))?;
        // The per-position scan below only handles a single sequence.
        assert_eq!(input_embeds.dim(0)?, 1);
        assert_eq!(reshaped_image_hidden_states.dim(0)?, 1);
        let special_image_token_mask = special_image_token_mask.i(0)?.to_vec1::<u8>()?;
        // Each `<image>` position consumes the next image hidden state, in order.
        let mut image_hidden_state_i = 0;
        for (i, v) in special_image_token_mask.iter().enumerate() {
            if *v != 0 {
                new_inputs_embeds = new_inputs_embeds.slice_assign(
                    &[&.., &i, &..],
                    &reshaped_image_hidden_states
                        .i((.., image_hidden_state_i, ..))?
                        .unsqueeze(1)?,
                )?;
                image_hidden_state_i += 1;
            }
        }
        Ok(new_inputs_embeds)
    }

    /// Full forward pass. When `pixel_values` is given (valid only on the
    /// prompt pass), images are encoded, resampled, and merged into the text
    /// embeddings, which are then fed to the language model.
    #[allow(clippy::too_many_arguments)]
    fn forward_inner(
        &self,
        input_ids: &Tensor,
        pixel_values: Option<Tensor>,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        pixel_attention_mask: Option<Tensor>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let input_embeds = if let Some(pixel_values) = pixel_values {
            // == START VISUAL INPUTS INTEGRATION ==
            // Flatten (bs, num_images, c, h, w) -> (bs * num_images, c, h, w).
            let (batch_size, num_images, _, _, _) = pixel_values.dims5()?;
            let mut s = vec![batch_size * num_images];
            s.extend(pixel_values.dims()[2..].to_vec());
            let pixel_values = pixel_values.reshape(s)?;

            // Remove padding images which are full of 0s
            // An image is kept iff the number of zero elements (summed over its
            // last three dims) differs from its total element count.
            let nb_values_per_image = pixel_values.dims()[1..].iter().product::<usize>();
            let real_images_inds = pixel_values
                .eq(0.0f64)?
                .sum(vec![
                    pixel_values.dims().len() - 1,
                    pixel_values.dims().len() - 2,
                    pixel_values.dims().len() - 3,
                ])?
                .ne(nb_values_per_image as f64)?;
            let mut batches = Vec::new();
            for (batch, use_it) in pixel_values
                .chunk(pixel_values.dim(0)?, 0)?
                .iter()
                .zip(real_images_inds.chunk(real_images_inds.dim(0)?, 0)?)
            {
                let use_it = use_it.squeeze(0)?.to_scalar::<u8>()? != 0;
                if use_it {
                    batches.push(batch.clone());
                }
            }
            // NOTE(review): if every image is padding, `batches` is empty and
            // this cat errors — presumably callers never pass all-zero images.
            let pixel_values = Tensor::cat(&batches, 0)?;

            // Vision attention mask
            let pixel_attention_mask = if let Some(pixel_attention_mask) = pixel_attention_mask {
                // Keep masks only for the images retained above.
                let pixel_attention_mask = pixel_attention_mask.reshape((
                    batch_size * num_images,
                    pixel_attention_mask.dims()[2],
                    pixel_attention_mask.dims()[3],
                ))?;
                let mut batches = Vec::new();
                for (batch, use_it) in pixel_attention_mask
                    .chunk(pixel_attention_mask.dim(0)?, 0)?
                    .iter()
                    .zip(real_images_inds.chunk(real_images_inds.dim(0)?, 0)?)
                {
                    let use_it = use_it.squeeze(0)?.to_scalar::<u8>()? != 0;
                    if use_it {
                        batches.push(batch.clone());
                    }
                }
                Tensor::cat(&batches, 0)?
            } else {
                // No mask provided: attend to every pixel.
                Tensor::ones(
                    (
                        pixel_values.dims()[0],
                        pixel_values.dims()[2],
                        pixel_values.dims()[3],
                    ),
                    DType::U8,
                    pixel_values.device(),
                )?
            };

            // Tile the pixel mask into patch_size x patch_size windows; a patch
            // is attended iff all of its pixels are unmasked.
            let patch_size = self.config.vision_config.patch_size;
            let patches_subgrid = pixel_attention_mask.unfold(1, patch_size, patch_size)?;
            let patches_subgrid = patches_subgrid.unfold(2, patch_size, patch_size)?;

            let patch_attention_mask = patches_subgrid
                .sum((D::Minus1, D::Minus2))?
                .eq((patch_size * patch_size) as f64)?
                .to_dtype(DType::U8)?;

            let pixel_values = pixel_values.to_dtype(self.dtype)?;

            // Get seq from vision encoder
            let image_hidden_states = self
                .vision_model
                .forward(&pixel_values, Some(&patch_attention_mask))?;

            // Modality proj and perceiver resampling
            let image_hidden_states = self.connector.forward(
                &image_hidden_states,
                &patch_attention_mask.reshape((pixel_values.dim(0)?, ()))?,
            )?;

            // Pixel values are only valid on the initial prompt pass, i.e.
            // while the KV cache is still empty.
            if self.text_model.cache.normal().0[0].current_seq_len() == 0 {
                self.inputs_merger(
                    input_ids,
                    &self.text_model.get_input_embeddings(input_ids)?,
                    &image_hidden_states,
                )?
            } else {
                candle_core::bail!("Pixel values were specified for a non-prompt.")
            }
        } else {
            // Text-only step: plain token embeddings.
            self.text_model.get_input_embeddings(input_ids)?
        };

        self.text_model.forward_embeds(
            input_ids,
            input_embeds,
            seqlen_offsets,
            context_lens,
            metadata,
            flash_params,
        )
    }
}
1198
1199impl IsqModel for Idefics2 {
1200    fn get_layers(
1201        &mut self,
1202    ) -> (
1203        Vec<(
1204            &mut std::sync::Arc<dyn mistralrs_quant::QuantMethod>,
1205            Option<usize>,
1206        )>,
1207        &dyn DeviceMapper,
1208    ) {
1209        self.text_model.get_layers()
1210    }
1211
1212    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
1213        let uvb = UnVarBuilder::new();
1214
1215        let uvb_m = uvb.pp("model");
1216        uvb_m
1217            .pp("text_model")
1218            .extend(self.text_model.residual_tensors());
1219        uvb_m
1220            .pp("vision_model")
1221            .extend(self.vision_model.residual_tensors());
1222        uvb_m
1223            .pp("connector")
1224            .extend(self.connector.residual_tensors());
1225
1226        uvb.to_safetensors()
1227    }
1228}
1229
1230// AnyMoE is forwarded to the base model
1231impl AnyMoeBaseModelMixin for Idefics2 {
1232    fn get_mlps(&self) -> Vec<&dyn MlpLayer> {
1233        self.text_model.get_mlps()
1234    }
1235    fn get_mlps_mut(&mut self) -> Vec<&mut Box<dyn MlpLayer>> {
1236        self.text_model.get_mlps_mut()
1237    }
1238    fn create_anymoe_layers(
1239        &mut self,
1240        additional_vbs: Vec<ShardedVarBuilder>,
1241        config: AnyMoeConfig,
1242        (prefix, mlp): (String, String),
1243        layers: Vec<usize>,
1244        expert_type: AnyMoeExpertType,
1245        gate_vb: Option<ShardedVarBuilder>,
1246    ) -> Result<()> {
1247        self.text_model.create_anymoe_layers(
1248            additional_vbs,
1249            config,
1250            (prefix, mlp),
1251            layers,
1252            expert_type,
1253            gate_vb,
1254        )
1255    }
1256    fn amoe_supported(&self) -> bool {
1257        true
1258    }
1259}
1260
1261impl VisionModel for Idefics2 {
1262    fn forward(
1263        &self,
1264        input_ids: &Tensor,
1265        pixel_values: Option<Tensor>,
1266        seqlen_offsets: &[usize],
1267        context_lens: Vec<(usize, usize)>,
1268        _: Vec<usize>, // Ignore, it is for phi3
1269        model_specific_args: Box<dyn Any>,
1270        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
1271        flash_params: &FlashParams,
1272    ) -> candle_core::Result<Tensor> {
1273        let pixel_attention_mask: Option<Tensor> = *model_specific_args
1274            .downcast()
1275            .expect("Cannot downcast into `Option<Tensor>`");
1276        self.forward_inner(
1277            input_ids,
1278            pixel_values,
1279            seqlen_offsets,
1280            context_lens,
1281            pixel_attention_mask,
1282            metadata,
1283            flash_params,
1284        )
1285    }
1286    fn cache(&self) -> &EitherCache {
1287        self.text_model.cache()
1288    }
1289    fn cache_mut(&mut self) -> &mut EitherCache {
1290        self.text_model.cache_mut()
1291    }
1292    fn device(&self) -> &Device {
1293        self.text_model.device()
1294    }
1295    fn max_seq_len(&self) -> usize {
1296        self.text_model.max_seq_len()
1297    }
1298    fn has_conv2d(&self) -> bool {
1299        true
1300    }
1301    fn config(&self) -> &ModelConfigMetadata {
1302        self.text_model.config()
1303    }
1304    fn default_model_specific_args(&self, _input_ids: &Tensor) -> Box<dyn Any> {
1305        let args: Option<Tensor> = None;
1306        Box::new(args)
1307    }
1308}