mistralrs_core/vision_models/mllama/vision.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::{ops::Mul, sync::Arc};

use candle_core::{DType, Device, Result, Tensor, D};
use candle_nn::{Conv2d, Conv2dConfig, Embedding, LayerNorm, LayerNormConfig, Module};
use mistralrs_quant::{
    ColumnParallelLayer, Convolution, QuantMethod, RowParallelLayer, ShardedVarBuilder,
};

use crate::{
    attention::SdpaParams,
    layers::{conv2d_no_bias, embedding, layer_norm, GetFloatInfo, Sdpa},
    pipeline::IsqModel,
    utils::unvarbuilder::UnVarBuilder,
};

use super::{MLlamaVisionConfig, VisionActivation};

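/// Gated positional embedding: a learned per-patch embedding plus a precomputed
/// per-tile embedding selected by aspect ratio ID, each scaled via a learned tanh gate.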
struct MLlamaPrecomputedPositionEmbedding {
    gate: Tensor,
    embedding: Tensor,
    tile_embedding: Embedding,
    num_patches: usize,
    hidden_size: usize,
    max_num_tiles: usize,
}

impl MLlamaPrecomputedPositionEmbedding {
    fn new(cfg: &MLlamaVisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        let num_patches = (cfg.image_size / cfg.patch_size).pow(2) + 1;
        Ok(Self {
            gate: vb.get((1,), "gate")?,
            embedding: vb.get((num_patches, cfg.hidden_size), "embedding")?,
            tile_embedding: embedding(
                cfg.max_aspect_ratio_id() + 1,
                cfg.max_num_tiles * num_patches * cfg.hidden_size,
                vb.pp("tile_embedding"),
                &None,
            )?,
            num_patches,
            hidden_size: cfg.hidden_size,
            max_num_tiles: cfg.max_num_tiles,
        })
    }

    // https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/models/mllama/modeling_mllama.py#L197
    fn forward(&self, hidden_state: &Tensor, aspect_ratio_ids: &Tensor) -> Result<Tensor> {
        // position embeddings
        let mut gated_pos_embed = (1. - &self.gate.tanh()?)?.broadcast_mul(&self.embedding)?;
        let hidden_state = hidden_state.broadcast_add(&gated_pos_embed.reshape((
            1,
            1,
            self.num_patches,
            self.hidden_size,
        ))?)?;

        // precomputed tile position embeddings
        let mut tile_position_embedding = self.tile_embedding.forward(aspect_ratio_ids)?;
        let bs = hidden_state.dim(0)?;
        tile_position_embedding = tile_position_embedding.reshape((
            bs,
            self.max_num_tiles,
            self.num_patches,
            self.hidden_size,
        ))?;
        gated_pos_embed = self.gate.tanh()?.broadcast_mul(&tile_position_embedding)?;

        hidden_state.broadcast_add(&gated_pos_embed)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb_gpe = UnVarBuilder::new();

        uvb_gpe.add_tensor("gate", self.gate.clone());
        uvb_gpe.add_tensor("embedding", self.embedding.clone());
        uvb_gpe.pp("tile_embedding").add(&self.tile_embedding);

        uvb_gpe.to_safetensors()
    }
}

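/// Precomputed tile position embedding looked up by aspect ratio ID; in the gated
/// variant (`GATED = true`) it is scaled by a learned tanh gate before being added.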
struct MLlamaPrecomputedAspectRatioEmbedding {
    embedding: Embedding,
    gate: Option<Tensor>,
    max_num_tiles: usize,
    hidden_size: usize,
}

impl MLlamaPrecomputedAspectRatioEmbedding {
    fn new<const GATED: bool>(cfg: &MLlamaVisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        Ok(Self {
            embedding: embedding(
                cfg.max_aspect_ratio_id() + 1,
                cfg.max_num_tiles * cfg.hidden_size,
                vb.pp("embedding"),
                &None,
            )?,
            gate: if GATED {
                Some(vb.get((1,), "gate")?)
            } else {
                None
            },
            max_num_tiles: cfg.max_num_tiles,
            hidden_size: cfg.hidden_size,
        })
    }

    fn forward(&self, hidden_state: &Tensor, aspect_ratio_ids: &Tensor) -> Result<Tensor> {
        let mut embeddings = self.embedding.forward(aspect_ratio_ids)?;
        embeddings = embeddings.reshape(((), self.max_num_tiles, 1, self.hidden_size))?;

        if let Some(gate) = &self.gate {
            embeddings = embeddings.broadcast_mul(&gate.tanh()?)?;
        }

        hidden_state.broadcast_add(&embeddings)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb_ptpe = UnVarBuilder::new();

        if let Some(gate) = self.gate.clone() {
            uvb_ptpe.add_tensor("gate", gate);
        }
        uvb_ptpe.pp("embedding").add(&self.embedding);

        uvb_ptpe.to_safetensors()
    }
}

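/// Non-causal multi-head self-attention over vision tokens, with tensor-parallel
/// projections (column-parallel q/k/v, row-parallel output).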
struct MLlamaVisionAttention {
    q_proj: Arc<dyn QuantMethod>,
    k_proj: Arc<dyn QuantMethod>,
    v_proj: Arc<dyn QuantMethod>,
    o_proj: Arc<dyn QuantMethod>,
    sdpa_params: SdpaParams,
    num_heads: usize,
    head_dim: usize,
}

impl MLlamaVisionAttention {
    fn new(
        cfg: &MLlamaVisionConfig,
        vb: ShardedVarBuilder,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let head_dim = cfg.hidden_size / cfg.num_attention_heads;
        Ok(Self {
            q_proj: ColumnParallelLayer::new(
                cfg.hidden_size,
                cfg.num_attention_heads * head_dim,
                &None,
                false,
                comm,
                vb.pp("q_proj"),
            )?,
            k_proj: ColumnParallelLayer::new(
                cfg.hidden_size,
                cfg.num_attention_heads * head_dim,
                &None,
                false,
                comm,
                vb.pp("k_proj"),
            )?,
            v_proj: ColumnParallelLayer::new(
                cfg.hidden_size,
                cfg.num_attention_heads * head_dim,
                &None,
                false,
                comm,
                vb.pp("v_proj"),
            )?,
            o_proj: RowParallelLayer::new(
                cfg.hidden_size,
                cfg.num_attention_heads * head_dim,
                &None,
                false,
                comm,
                vb.pp("o_proj"),
            )?,
            sdpa_params: SdpaParams {
                n_kv_groups: 1,
                softcap: None,
                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                sliding_window: None,
            },
            num_heads: cfg.num_attention_heads,
            head_dim,
        })
    }

    // https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/models/mllama/modeling_mllama.py#L243
    fn forward(&self, hidden_state: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
        let mut hidden_state = hidden_state.clone();
        let original_dtype = hidden_state.dtype();
        if let Some(t) = self.q_proj.quantized_act_type() {
            hidden_state = hidden_state.to_dtype(t)?;
        }
        let mut q = self.q_proj.forward(&hidden_state)?;
        let mut k = self.k_proj.forward(&hidden_state)?;
        let mut v = self.v_proj.forward(&hidden_state)?;
        if self.q_proj.quantized_act_type().is_some() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

        // Should be same, no caching...
        let (bs, q_sq, _) = q.dims3()?;
        let (_, k_sq, _) = k.dims3()?;

        q = q
            .reshape((bs, q_sq, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;
        k = k
            .reshape((bs, k_sq, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;
        v = v
            .reshape((bs, k_sq, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;

        let mut attn_output = Sdpa
            .run_attention(
                &q.contiguous()?,
                &k.contiguous()?,
                &v.contiguous()?,
                attention_mask,
                None,
                &self.sdpa_params,
            )?
            .transpose(1, 2)?
            .contiguous()?
            .reshape((bs, q_sq, ()))?
            .to_dtype(q.dtype())?;

        if let Some(t) = self.q_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        let mut res = self.o_proj.forward(&attn_output)?;
        if self.q_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

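/// Two-layer feed-forward block (fc1 -> activation -> fc2) with tensor-parallel
/// projections and biases.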
struct MLlamaMlp {
    act: VisionActivation,
    fc1: Arc<dyn QuantMethod>,
    fc2: Arc<dyn QuantMethod>,
}

impl MLlamaMlp {
    fn new(
        cfg: &MLlamaVisionConfig,
        vb: ShardedVarBuilder,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        Ok(Self {
            act: cfg.hidden_act,
            fc1: ColumnParallelLayer::new(
                cfg.hidden_size,
                cfg.intermediate_size,
                &None,
                true,
                comm,
                vb.pp("fc1"),
            )?,
            fc2: RowParallelLayer::new(
                cfg.intermediate_size,
                cfg.hidden_size,
                &None,
                true,
                comm,
                vb.pp("fc2"),
            )?,
        })
    }

    // https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/models/mllama/modeling_mllama.py#L223
    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
        let original_dtype = hidden_states.dtype();
        let mut hidden_states = hidden_states.clone();
        if let Some(t) = self.fc1.quantized_act_type() {
            hidden_states = hidden_states.to_dtype(t)?;
        }
        hidden_states = self
            .fc2
            .forward(&self.act.forward(&self.fc1.forward(&hidden_states)?)?)?;
        if self.fc1.quantized_act_type().is_some() {
            hidden_states = hidden_states.to_dtype(original_dtype)?;
        }
        Ok(hidden_states)
    }
}

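/// Pre-norm transformer layer. In the gated variant (used by the global encoder), the
/// attention and feed-forward branch outputs are passed through tanh and scaled by
/// learned gates before the residual add.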
struct MLlamaVisionEncoderLayer {
    self_attn: MLlamaVisionAttention,
    mlp: MLlamaMlp,
    input_layernorm: LayerNorm,
    post_attention_layernorm: LayerNorm,
    gate_attn: Option<Tensor>,
    gate_ffn: Option<Tensor>,
}

impl MLlamaVisionEncoderLayer {
    fn new<const GATED: bool>(
        cfg: &MLlamaVisionConfig,
        vb: ShardedVarBuilder,
        real_dev: &Device,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let self_attn = MLlamaVisionAttention::new(cfg, vb.pp("self_attn"), comm)?;
        let mlp = MLlamaMlp::new(cfg, vb.pp("mlp"), comm)?;

        let input_layernorm = layer_norm(
            cfg.hidden_size,
            cfg.norm_eps,
            vb.pp("input_layernorm").set_device(real_dev.clone()),
        )?;
        let post_attention_layernorm = layer_norm(
            cfg.hidden_size,
            cfg.norm_eps,
            vb.pp("post_attention_layernorm")
                .set_device(real_dev.clone()),
        )?;

        if GATED {
            Ok(Self {
                self_attn,
                mlp,
                input_layernorm,
                post_attention_layernorm,
                gate_attn: Some(vb.get((1,), "gate_attn")?.to_device(real_dev)?),
                gate_ffn: Some(vb.get((1,), "gate_ffn")?.to_device(real_dev)?),
            })
        } else {
            Ok(Self {
                self_attn,
                mlp,
                input_layernorm,
                post_attention_layernorm,
                gate_attn: None,
                gate_ffn: None,
            })
        }
    }

    // https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/models/mllama/modeling_mllama.py#L348
    fn forward(&self, hidden_state: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
        // Self attn
        let residual = hidden_state;
        let mut hidden_state = self.input_layernorm.forward(hidden_state)?;

        hidden_state = self.self_attn.forward(&hidden_state, attention_mask)?;

        if let Some(gate) = &self.gate_attn {
            hidden_state = gate.broadcast_mul(&hidden_state.tanh()?)?;
        }
        hidden_state = (residual + hidden_state)?;

        // FF
        let residual = hidden_state.clone();
        hidden_state = self.post_attention_layernorm.forward(&hidden_state)?;

        hidden_state = self.mlp.forward(&hidden_state)?;

        if let Some(gate) = &self.gate_ffn {
            hidden_state = gate.broadcast_mul(&hidden_state.tanh()?)?;
        }
        residual + hidden_state
    }
}

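/// Stack of encoder layers; `forward_with_states` can additionally return the hidden
/// states entering selected intermediate layers.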
struct MLlamaVisionEncoder {
    layers: Vec<MLlamaVisionEncoderLayer>,
}

impl MLlamaVisionEncoder {
    fn new<const GATED: bool>(
        cfg: &MLlamaVisionConfig,
        num_layers: usize,
        vb: ShardedVarBuilder,
        real_dev: &Device,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let mut layers = Vec::with_capacity(num_layers);
        let layers_vb = vb.pp("layers");
        for i in 0..num_layers {
            layers.push(MLlamaVisionEncoderLayer::new::<GATED>(
                cfg,
                layers_vb.pp(i),
                real_dev,
                comm,
            )?);
        }
        Ok(Self { layers })
    }

    // https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/models/mllama/modeling_mllama.py#L394
    /// Also (optionally) return hidden states at some indices
    fn forward_with_states(
        &self,
        hidden_state: &Tensor,
        attention_mask: Option<&Tensor>,
        intermediate_layers_indices: Option<&[usize]>,
    ) -> Result<(Tensor, Vec<Tensor>)> {
        let mut hidden_state = hidden_state.clone();
        let mut hidden_states = Vec::new();
        for (i, layer) in self.layers.iter().enumerate() {
            if intermediate_layers_indices.is_some_and(|indices: &[usize]| indices.contains(&i)) {
                hidden_states.push(hidden_state.clone());
            }
            hidden_state = layer.forward(&hidden_state, attention_mask)?;
        }
        Ok((hidden_state, hidden_states))
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb_t = UnVarBuilder::new();

        for (i, layer) in self.layers.iter().enumerate() {
            let uvb_l = uvb_t.pp("layers").pp(i);
            uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
            uvb_l
                .pp("post_attention_layernorm")
                .add(&layer.post_attention_layernorm);
            if let Some(gate) = layer.gate_attn.clone() {
                uvb_l.add_tensor("gate_attn", gate);
            }
            if let Some(gate) = layer.gate_ffn.clone() {
                uvb_l.add_tensor("gate_ffn", gate);
            }
        }

        uvb_t.to_safetensors()
    }
}

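// Expands the (bs, max_num_tiles) tile mask to patch granularity, zeroes out the padding
// patches, inverts it, and forms an additive attention bias via an outer product scaled
// by the dtype's most negative finite value.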
fn _prepare_aspect_ratio_attention_mask(
    aspect_ratio_mask: &Tensor,
    num_patches: usize,
    target_length: usize,
    dtype: DType,
    _num_attn_heads: usize,
) -> Result<Tensor> {
    let (bs, max_num_tiles) = aspect_ratio_mask.dims2()?;
    let mut attention_mask = aspect_ratio_mask
        .reshape((bs, max_num_tiles, 1, 1))?
        .repeat((1, 1, target_length, 1))?;

    // Mask padding patches
    let pad_patches = target_length - num_patches;
    let (bs, d1, d2, d3) = attention_mask.dims4()?;
    attention_mask = attention_mask.slice_assign(
        &[0..bs, 0..d1, (d2 - pad_patches)..d2, 0..d3],
        &Tensor::zeros(
            (bs, d1, pad_patches, d3),
            attention_mask.dtype(),
            attention_mask.device(),
        )?,
    )?;

    // Invert the mask
    attention_mask = (1. - attention_mask.to_dtype(DType::F32)?.to_dtype(dtype)?)?;

    // Reshape to 2d and create 4d attn mask
    // (batch_size, 1, max_num_tiles * target_length, max_num_tiles * target_length)
    let neg_inf_value = dtype.finfo()?.min;
    attention_mask = attention_mask.reshape((bs, max_num_tiles * target_length, 1))?;
    attention_mask.matmul(
        &attention_mask
            .transpose(D::Minus1, D::Minus2)?
            .mul(neg_inf_value)?,
    )
}

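/// The full Mllama vision tower: patch embedding, tile/position embeddings, a local
/// encoder, and a gated global encoder whose output is concatenated with selected
/// intermediate hidden states of the local encoder.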
pub(super) struct MLlamaVisionModel {
    patch_embedding: Conv2d,
    class_embedding: Tensor,
    gated_positional_embedding: MLlamaPrecomputedPositionEmbedding,
    pre_tile_positional_embedding: MLlamaPrecomputedAspectRatioEmbedding,
    post_tile_positional_embedding: MLlamaPrecomputedAspectRatioEmbedding,
    layernorm_pre: LayerNorm,
    layernorm_post: LayerNorm,
    transformer: MLlamaVisionEncoder,
    global_transformer: MLlamaVisionEncoder,
    pub(super) num_patches: usize,
    intermediate_layers_indices: Vec<usize>,
    num_attn_heads: usize,
}

impl MLlamaVisionModel {
    pub(super) fn new(
        cfg: &MLlamaVisionConfig,
        vb: ShardedVarBuilder,
        real_dev: &Device,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let patch_embedding = conv2d_no_bias(
            cfg.num_channels,
            cfg.hidden_size,
            cfg.patch_size,
            Conv2dConfig {
                stride: cfg.patch_size,
                ..Default::default()
            },
            vb.pp("patch_embedding").set_device(real_dev.clone()),
        )?;

        let class_embedding = vb
            .get((cfg.hidden_size,), "class_embedding")?
            .to_device(real_dev)?;
        let gated_positional_embedding = MLlamaPrecomputedPositionEmbedding::new(
            cfg,
            vb.pp("gated_positional_embedding")
                .set_device(real_dev.clone()),
        )?;

        let pre_tile_positional_embedding = MLlamaPrecomputedAspectRatioEmbedding::new::<true>(
            cfg,
            vb.pp("pre_tile_positional_embedding")
                .set_device(real_dev.clone()),
        )?;
        let post_tile_positional_embedding = MLlamaPrecomputedAspectRatioEmbedding::new::<true>(
            cfg,
            vb.pp("post_tile_positional_embedding")
                .set_device(real_dev.clone()),
        )?;

        // layer norms
        let layernorm_pre = layer_norm(
            cfg.hidden_size,
            LayerNormConfig::default(),
            vb.pp("layernorm_pre").set_device(real_dev.clone()),
        )?;
        let layernorm_post = layer_norm(
            cfg.hidden_size,
            LayerNormConfig::default(),
            vb.pp("layernorm_post").set_device(real_dev.clone()),
        )?;

        // encoders
        let transformer = MLlamaVisionEncoder::new::<false>(
            cfg,
            cfg.num_hidden_layers,
            vb.pp("transformer"),
            real_dev,
            comm,
        )?;
        let global_transformer = MLlamaVisionEncoder::new::<true>(
            cfg,
            cfg.num_global_layers,
            vb.pp("global_transformer"),
            real_dev,
            comm,
        )?;

        Ok(Self {
            patch_embedding,
            class_embedding,
            gated_positional_embedding,
            pre_tile_positional_embedding,
            post_tile_positional_embedding,
            layernorm_post,
            layernorm_pre,
            transformer,
            global_transformer,
            num_patches: (cfg.image_size / cfg.patch_size).pow(2) + 1,
            intermediate_layers_indices: cfg.intermediate_layers_indices.clone(),
            num_attn_heads: cfg.num_attention_heads,
        })
    }

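    // `pixel_values` is 6-D: (bs, num_concurrent_media, num_tiles, channels, height, width);
    // the aspect ratio tensors are flattened over (bs, num_concurrent_media) before use.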
    // https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/models/mllama/modeling_mllama.py#L1425
    pub(super) fn forward(
        &self,
        pixel_values: &Tensor,
        aspect_ratio_ids: &Tensor,
        aspect_ratio_mask: &Tensor,
    ) -> Result<Tensor> {
        let pixel_values = pixel_values.to_dtype(self.class_embedding.dtype())?;

        let bs = pixel_values.dim(0)?;
        let num_concurrent_media = pixel_values.dim(1)?;
        let num_tiles = pixel_values.dim(2)?;
        let num_channels = pixel_values.dim(3)?;
        let height = pixel_values.dim(4)?;
        let width = pixel_values.dim(5)?;

        let pixel_values = pixel_values.reshape((
            bs * num_concurrent_media * num_tiles,
            num_channels,
            height,
            width,
        ))?;
        let aspect_ratio_ids = aspect_ratio_ids.reshape((bs * num_concurrent_media, ()))?;

        // Patch embedding
        let patch_embeds = Convolution.forward_2d(&self.patch_embedding, &pixel_values)?;
        let mut hidden_state = patch_embeds.flatten_from(2)?.transpose(1, 2)?;

        // Tile embeddings
        let (_, mut num_patches, dim) = hidden_state.dims3()?;
        hidden_state = hidden_state.reshape((bs * num_concurrent_media, num_tiles, (), dim))?;
        hidden_state = self
            .pre_tile_positional_embedding
            .forward(&hidden_state, &aspect_ratio_ids)?;

        // Add cls token
        hidden_state =
            hidden_state.reshape((bs * num_concurrent_media * num_tiles, num_patches, dim))?;
        hidden_state = self.apply_class_embedding(&hidden_state)?;
        num_patches += 1;

        // Position embeddings
        hidden_state =
            hidden_state.reshape((bs * num_concurrent_media, num_tiles, num_patches, dim))?;
        hidden_state = self
            .gated_positional_embedding
            .forward(&hidden_state, &aspect_ratio_ids)?;

        hidden_state = self.layernorm_pre.forward(&hidden_state)?;

        // Compute the number of tokens to pad
        let num_padding_patches = (8 - (hidden_state.dim(D::Minus2)? as isize % 8)) % 8;
        // Compute padding tuple for pad function
        // (pad_left, pad_right, pad_left for dim -2, pad_right for dim -2)
        let _padding = (0usize, 0usize, 0usize, num_padding_patches);
        if num_padding_patches >= 0 {
            hidden_state =
                hidden_state.pad_with_zeros(D::Minus2, 0, num_padding_patches as usize)?;
        } else {
            hidden_state = hidden_state.narrow(
                D::Minus2,
                0,
                (hidden_state.dim(2)? as isize + num_padding_patches) as usize,
            )?;
        }

        // Prepare attention mask
        let mut attention_mask = aspect_ratio_mask.reshape((bs * num_concurrent_media, ()))?;
        attention_mask = _prepare_aspect_ratio_attention_mask(
            &attention_mask,
            self.num_patches,
            hidden_state.dim(2)?,
            hidden_state.dtype(),
            self.num_attn_heads,
        )?;
        if attention_mask.dim(0)? != 1 {
            attention_mask = attention_mask.unsqueeze(1)?;
        }

        // Apply encoder
        hidden_state = hidden_state.reshape((bs * num_concurrent_media, (), dim))?;
        let (mut hidden_state, all_intermediate_hidden_states) =
            self.transformer.forward_with_states(
                &hidden_state,
                Some(&attention_mask),
                Some(&self.intermediate_layers_indices),
            )?;

        // Collect intermediate layer outputs from encoder output
        let mut intermediate_hidden_states =
            Tensor::stack(&all_intermediate_hidden_states, D::Minus1)?;
        drop(all_intermediate_hidden_states);

        hidden_state = self.layernorm_post.forward(&hidden_state)?;

        // Apply global encoder
        hidden_state = hidden_state.reshape((
            bs * num_concurrent_media,
            num_tiles,
            (num_patches as isize + num_padding_patches) as usize,
            dim,
        ))?;
        hidden_state = self
            .post_tile_positional_embedding
            .forward(&hidden_state, &aspect_ratio_ids)?;
        hidden_state = hidden_state.reshape((
            bs * num_concurrent_media,
            num_tiles * (num_patches as isize + num_padding_patches) as usize,
            dim,
        ))?;
        (hidden_state, _) = self.global_transformer.forward_with_states(
            &hidden_state,
            Some(&attention_mask),
            None,
        )?;

        // Remove padding from hidden state
        hidden_state = hidden_state.reshape((
            bs * num_concurrent_media,
            num_tiles,
            (num_patches as isize + num_padding_patches) as usize,
            dim,
        ))?;
        hidden_state = hidden_state.narrow(
            2,
            0,
            (hidden_state.dims()[2] as isize - num_padding_patches) as usize,
        )?;
        hidden_state =
            hidden_state.reshape((bs, num_concurrent_media, num_tiles, num_patches, dim))?;

        // Remove padding from intermediate hidden states
        intermediate_hidden_states = intermediate_hidden_states.reshape((
            bs * num_concurrent_media,
            num_tiles,
            (num_patches as isize + num_padding_patches) as usize,
            (),
        ))?;
        intermediate_hidden_states = intermediate_hidden_states.narrow(
            2,
            0,
            (intermediate_hidden_states.dims()[2] as isize - num_padding_patches) as usize,
        )?;
        intermediate_hidden_states = intermediate_hidden_states.reshape((
            bs,
            num_concurrent_media,
            num_tiles,
            num_patches,
            (),
        ))?;

        // Concatenate final hidden state and intermediate hidden states
        Tensor::cat(&[hidden_state, intermediate_hidden_states], D::Minus1)
    }

    fn apply_class_embedding(&self, hidden_state: &Tensor) -> Result<Tensor> {
        let (bs, _, hidden_size) = hidden_state.dims3()?;
        let class_embedding = self.class_embedding.expand((bs, 1, hidden_size))?;
        Tensor::cat(&[class_embedding, hidden_state.clone()], 1)
    }

    pub fn get_isq_layers(&mut self) -> Vec<&mut std::sync::Arc<dyn mistralrs_quant::QuantMethod>> {
        let mut layers = Vec::new();
        for layer in &mut self.global_transformer.layers {
            layers.push(&mut layer.self_attn.q_proj);
            layers.push(&mut layer.self_attn.k_proj);
            layers.push(&mut layer.self_attn.v_proj);
            layers.push(&mut layer.self_attn.o_proj);

            layers.push(&mut layer.mlp.fc1);
            layers.push(&mut layer.mlp.fc2);
        }
        for layer in &mut self.transformer.layers {
            layers.push(&mut layer.self_attn.q_proj);
            layers.push(&mut layer.self_attn.k_proj);
            layers.push(&mut layer.self_attn.v_proj);
            layers.push(&mut layer.self_attn.o_proj);

            layers.push(&mut layer.mlp.fc1);
            layers.push(&mut layer.mlp.fc2);
        }
        layers
    }
}

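// ISQ is not applied to the vision model on its own (`get_layers` is unreachable here,
// per its panic message); only the residual tensors below are exported.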
impl IsqModel for MLlamaVisionModel {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(
            &mut std::sync::Arc<dyn mistralrs_quant::QuantMethod>,
            Option<usize>,
        )>,
        &dyn crate::device_map::DeviceMapper,
    ) {
        unreachable!("MLlamaVision model cannot be quantized.");
    }
    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("patch_embedding").add(&self.patch_embedding);
        uvb.add_tensor("class_embedding", self.class_embedding.clone());

        // gated_positional_embedding
        uvb.pp("gated_positional_embedding")
            .extend(self.gated_positional_embedding.residual_tensors());

        // pre_tile_positional_embedding
        uvb.pp("pre_tile_positional_embedding")
            .extend(self.pre_tile_positional_embedding.residual_tensors());

        // post_tile_positional_embedding
        uvb.pp("post_tile_positional_embedding")
            .extend(self.post_tile_positional_embedding.residual_tensors());

        uvb.pp("layernorm_pre").add(&self.layernorm_pre);
        uvb.pp("layernorm_post").add(&self.layernorm_post);

        // transformer
        uvb.pp("transformer")
            .extend(self.transformer.residual_tensors());

        // global_transformer
        uvb.pp("global_transformer")
            .extend(self.global_transformer.residual_tensors());

        uvb.to_safetensors()
    }
}