// mistralrs_core/vision_models/mllama/vision.rs

1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use std::{ops::Mul, sync::Arc};
4
5use candle_core::{DType, Device, Result, Tensor, D};
6use candle_nn::{Conv2d, Conv2dConfig, Embedding, LayerNorm, LayerNormConfig, Module};
7use mistralrs_quant::{ColumnParallelLayer, QuantMethod, RowParallelLayer, ShardedVarBuilder};
8
9use crate::{
10    attention::SdpaParams,
11    layers::{conv2d_no_bias, embedding, layer_norm, GetFloatInfo, Sdpa},
12    pipeline::IsqModel,
13    utils::unvarbuilder::UnVarBuilder,
14};
15
16use super::{MLlamaVisionConfig, VisionActivation};
17
18struct MLlamaPrecomputedPositionEmbedding {
19    gate: Tensor,
20    embedding: Tensor,
21    tile_embedding: Embedding,
22    num_patches: usize,
23    hidden_size: usize,
24    max_num_tiles: usize,
25}
26
27impl MLlamaPrecomputedPositionEmbedding {
28    fn new(cfg: &MLlamaVisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
29        let num_patches = (cfg.image_size / cfg.patch_size).pow(2) + 1;
30        Ok(Self {
31            gate: vb.get((1,), "gate")?,
32            embedding: vb.get((num_patches, cfg.hidden_size), "embedding")?,
33            tile_embedding: embedding(
34                cfg.max_aspect_ratio_id() + 1,
35                cfg.max_num_tiles * num_patches * cfg.hidden_size,
36                vb.pp("tile_embedding"),
37                &None,
38            )?,
39            num_patches,
40            hidden_size: cfg.hidden_size,
41            max_num_tiles: cfg.max_num_tiles,
42        })
43    }
44
45    // https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/models/mllama/modeling_mllama.py#L197
46    fn forward(&self, hidden_state: &Tensor, aspect_ratio_ids: &Tensor) -> Result<Tensor> {
47        // position embeddings
48        let mut gated_pos_embed = (1. - &self.gate.tanh()?)?.broadcast_mul(&self.embedding)?;
49        let hidden_state = hidden_state.broadcast_add(&gated_pos_embed.reshape((
50            1,
51            1,
52            self.num_patches,
53            self.hidden_size,
54        ))?)?;
55
56        // precomputed tile position embeddings
57        let mut tile_position_embedding = self.tile_embedding.forward(aspect_ratio_ids)?;
58        let bs = hidden_state.dim(0)?;
59        tile_position_embedding = tile_position_embedding.reshape((
60            bs,
61            self.max_num_tiles,
62            self.num_patches,
63            self.hidden_size,
64        ))?;
65        gated_pos_embed = self.gate.tanh()?.broadcast_mul(&tile_position_embedding)?;
66
67        hidden_state.broadcast_add(&gated_pos_embed)
68    }
69
70    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
71        let uvb_gpe = UnVarBuilder::new();
72
73        uvb_gpe.add_tensor("gate", self.gate.clone());
74        uvb_gpe.add_tensor("embedding", self.embedding.clone());
75        uvb_gpe.pp("tile_embedding").add(&self.tile_embedding);
76
77        uvb_gpe.to_safetensors()
78    }
79}
80
81struct MLlamaPrecomputedAspectRatioEmbedding {
82    embedding: Embedding,
83    gate: Option<Tensor>,
84    max_num_tiles: usize,
85    hidden_size: usize,
86}
87
88impl MLlamaPrecomputedAspectRatioEmbedding {
89    fn new<const GATED: bool>(cfg: &MLlamaVisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
90        Ok(Self {
91            embedding: embedding(
92                cfg.max_aspect_ratio_id() + 1,
93                cfg.max_num_tiles * cfg.hidden_size,
94                vb.pp("embedding"),
95                &None,
96            )?,
97            gate: if GATED {
98                Some(vb.get((1,), "gate")?)
99            } else {
100                None
101            },
102            max_num_tiles: cfg.max_num_tiles,
103            hidden_size: cfg.hidden_size,
104        })
105    }
106
107    fn forward(&self, hidden_state: &Tensor, aspect_ratio_ids: &Tensor) -> Result<Tensor> {
108        let mut embeddings = self.embedding.forward(aspect_ratio_ids)?;
109        embeddings = embeddings.reshape(((), self.max_num_tiles, 1, self.hidden_size))?;
110
111        if let Some(gate) = &self.gate {
112            embeddings = embeddings.broadcast_mul(&gate.tanh()?)?;
113        }
114
115        hidden_state.broadcast_add(&embeddings)
116    }
117
118    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
119        let uvb_ptpe = UnVarBuilder::new();
120
121        if let Some(gate) = self.gate.clone() {
122            uvb_ptpe.add_tensor("gate", gate);
123        }
124        uvb_ptpe.pp("embedding").add(&self.embedding);
125
126        uvb_ptpe.to_safetensors()
127    }
128}
129
130struct MLlamaVisionAttention {
131    q_proj: Arc<dyn QuantMethod>,
132    k_proj: Arc<dyn QuantMethod>,
133    v_proj: Arc<dyn QuantMethod>,
134    o_proj: Arc<dyn QuantMethod>,
135    sdpa_params: SdpaParams,
136    num_heads: usize,
137    head_dim: usize,
138}
139
140impl MLlamaVisionAttention {
141    fn new(
142        cfg: &MLlamaVisionConfig,
143        vb: ShardedVarBuilder,
144        comm: &Arc<mistralrs_quant::Comm>,
145    ) -> Result<Self> {
146        let head_dim = cfg.hidden_size / cfg.num_attention_heads;
147        Ok(Self {
148            q_proj: ColumnParallelLayer::new(
149                cfg.hidden_size,
150                cfg.num_attention_heads * head_dim,
151                &None,
152                false,
153                comm,
154                vb.pp("q_proj"),
155            )?,
156            k_proj: ColumnParallelLayer::new(
157                cfg.hidden_size,
158                cfg.num_attention_heads * head_dim,
159                &None,
160                false,
161                comm,
162                vb.pp("k_proj"),
163            )?,
164            v_proj: ColumnParallelLayer::new(
165                cfg.hidden_size,
166                cfg.num_attention_heads * head_dim,
167                &None,
168                false,
169                comm,
170                vb.pp("v_proj"),
171            )?,
172            o_proj: RowParallelLayer::new(
173                cfg.hidden_size,
174                cfg.num_attention_heads * head_dim,
175                &None,
176                false,
177                comm,
178                vb.pp("o_proj"),
179            )?,
180            sdpa_params: SdpaParams {
181                n_kv_groups: 1,
182                softcap: None,
183                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
184                sliding_window: None,
185            },
186            num_heads: cfg.num_attention_heads,
187            head_dim,
188        })
189    }
190
191    // https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/models/mllama/modeling_mllama.py#L243
192    fn forward(&self, hidden_state: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
193        let mut hidden_state = hidden_state.clone();
194        let original_dtype = hidden_state.dtype();
195        if let Some(t) = self.q_proj.quantized_act_type() {
196            hidden_state = hidden_state.to_dtype(t)?;
197        }
198        let mut q = self.q_proj.forward(&hidden_state)?;
199        let mut k = self.k_proj.forward(&hidden_state)?;
200        let mut v = self.v_proj.forward(&hidden_state)?;
201        if self.q_proj.quantized_act_type().is_some() {
202            q = q.to_dtype(original_dtype)?;
203            k = k.to_dtype(original_dtype)?;
204            v = v.to_dtype(original_dtype)?;
205        }
206
207        // Should be same, no caching...
208        let (bs, q_sq, _) = q.dims3()?;
209        let (_, k_sq, _) = k.dims3()?;
210
211        q = q
212            .reshape((bs, q_sq, self.num_heads, self.head_dim))?
213            .transpose(1, 2)?;
214        k = k
215            .reshape((bs, k_sq, self.num_heads, self.head_dim))?
216            .transpose(1, 2)?;
217        v = v
218            .reshape((bs, k_sq, self.num_heads, self.head_dim))?
219            .transpose(1, 2)?;
220
221        let mut attn_output = Sdpa
222            .run_attention(
223                &q.contiguous()?,
224                &k.contiguous()?,
225                &v.contiguous()?,
226                attention_mask,
227                None,
228                &self.sdpa_params,
229            )?
230            .transpose(1, 2)?
231            .contiguous()?
232            .reshape((bs, q_sq, ()))?
233            .to_dtype(q.dtype())?;
234
235        if let Some(t) = self.q_proj.quantized_act_type() {
236            attn_output = attn_output.to_dtype(t)?;
237        }
238        let mut res = self.o_proj.forward(&attn_output)?;
239        if self.q_proj.quantized_act_type().is_some() {
240            res = res.to_dtype(original_dtype)?;
241        }
242        Ok(res)
243    }
244}
245
246struct MLlamaMlp {
247    act: VisionActivation,
248    fc1: Arc<dyn QuantMethod>,
249    fc2: Arc<dyn QuantMethod>,
250}
251
252impl MLlamaMlp {
253    fn new(
254        cfg: &MLlamaVisionConfig,
255        vb: ShardedVarBuilder,
256        comm: &Arc<mistralrs_quant::Comm>,
257    ) -> Result<Self> {
258        Ok(Self {
259            act: cfg.hidden_act,
260            fc1: ColumnParallelLayer::new(
261                cfg.hidden_size,
262                cfg.intermediate_size,
263                &None,
264                true,
265                comm,
266                vb.pp("fc1"),
267            )?,
268            fc2: RowParallelLayer::new(
269                cfg.intermediate_size,
270                cfg.hidden_size,
271                &None,
272                true,
273                comm,
274                vb.pp("fc2"),
275            )?,
276        })
277    }
278
279    // https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/models/mllama/modeling_mllama.py#L223
280    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
281        let original_dtype = hidden_states.dtype();
282        let mut hidden_states = hidden_states.clone();
283        if let Some(t) = self.fc1.quantized_act_type() {
284            hidden_states = hidden_states.to_dtype(t)?;
285        }
286        hidden_states = self
287            .fc2
288            .forward(&self.act.forward(&self.fc1.forward(&hidden_states)?)?)?;
289        if self.fc1.quantized_act_type().is_some() {
290            hidden_states = hidden_states.to_dtype(original_dtype)?;
291        }
292        Ok(hidden_states)
293    }
294}
295
296struct MLlamaVisionEncoderLayer {
297    self_attn: MLlamaVisionAttention,
298    mlp: MLlamaMlp,
299    input_layernorm: LayerNorm,
300    post_attention_layernorm: LayerNorm,
301    gate_attn: Option<Tensor>,
302    gate_ffn: Option<Tensor>,
303}
304
305impl MLlamaVisionEncoderLayer {
306    fn new<const GATED: bool>(
307        cfg: &MLlamaVisionConfig,
308        vb: ShardedVarBuilder,
309        real_dev: &Device,
310        comm: &Arc<mistralrs_quant::Comm>,
311    ) -> Result<Self> {
312        let self_attn = MLlamaVisionAttention::new(cfg, vb.pp("self_attn"), comm)?;
313        let mlp = MLlamaMlp::new(cfg, vb.pp("mlp"), comm)?;
314
315        let input_layernorm = layer_norm(
316            cfg.hidden_size,
317            cfg.norm_eps,
318            vb.pp("input_layernorm").set_device(real_dev.clone()),
319        )?;
320        let post_attention_layernorm = layer_norm(
321            cfg.hidden_size,
322            cfg.norm_eps,
323            vb.pp("post_attention_layernorm")
324                .set_device(real_dev.clone()),
325        )?;
326
327        if GATED {
328            Ok(Self {
329                self_attn,
330                mlp,
331                input_layernorm,
332                post_attention_layernorm,
333                gate_attn: Some(vb.get((1,), "gate_attn")?.to_device(real_dev)?),
334                gate_ffn: Some(vb.get((1,), "gate_ffn")?.to_device(real_dev)?),
335            })
336        } else {
337            Ok(Self {
338                self_attn,
339                mlp,
340                input_layernorm,
341                post_attention_layernorm,
342                gate_attn: None,
343                gate_ffn: None,
344            })
345        }
346    }
347
348    // https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/models/mllama/modeling_mllama.py#L348
349    fn forward(&self, hidden_state: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
350        // Self attn
351        let residual = hidden_state;
352        let mut hidden_state = self.input_layernorm.forward(hidden_state)?;
353
354        hidden_state = self.self_attn.forward(&hidden_state, attention_mask)?;
355
356        if let Some(gate) = &self.gate_attn {
357            hidden_state = gate.broadcast_mul(&hidden_state.tanh()?)?;
358        }
359        hidden_state = (residual + hidden_state)?;
360
361        // FF
362        let residual = hidden_state.clone();
363        hidden_state = self.post_attention_layernorm.forward(&hidden_state)?;
364
365        hidden_state = self.mlp.forward(&hidden_state)?;
366
367        if let Some(gate) = &self.gate_ffn {
368            hidden_state = gate.broadcast_mul(&hidden_state.tanh()?)?;
369        }
370        residual + hidden_state
371    }
372}
373
374struct MLlamaVisionEncoder {
375    layers: Vec<MLlamaVisionEncoderLayer>,
376}
377
378impl MLlamaVisionEncoder {
379    fn new<const GATED: bool>(
380        cfg: &MLlamaVisionConfig,
381        num_layers: usize,
382        vb: ShardedVarBuilder,
383        real_dev: &Device,
384        comm: &Arc<mistralrs_quant::Comm>,
385    ) -> Result<Self> {
386        let mut layers = Vec::with_capacity(num_layers);
387        let layers_vb = vb.pp("layers");
388        for i in 0..num_layers {
389            layers.push(MLlamaVisionEncoderLayer::new::<GATED>(
390                cfg,
391                layers_vb.pp(i),
392                real_dev,
393                comm,
394            )?);
395        }
396        Ok(Self { layers })
397    }
398
399    // https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/models/mllama/modeling_mllama.py#L394
400    /// Also (optionally) return hidden states at some indices
401    fn forward_with_states(
402        &self,
403        hidden_state: &Tensor,
404        attention_mask: Option<&Tensor>,
405        intermediate_layers_indices: Option<&[usize]>,
406    ) -> Result<(Tensor, Vec<Tensor>)> {
407        let mut hidden_state = hidden_state.clone();
408        let mut hidden_states = Vec::new();
409        for (i, layer) in self.layers.iter().enumerate() {
410            if intermediate_layers_indices.is_some_and(|indices: &[usize]| indices.contains(&i)) {
411                hidden_states.push(hidden_state.clone());
412            }
413            hidden_state = layer.forward(&hidden_state, attention_mask)?;
414        }
415        Ok((hidden_state, hidden_states))
416    }
417
418    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
419        let uvb_t = UnVarBuilder::new();
420
421        for (i, layer) in self.layers.iter().enumerate() {
422            let uvb_l = uvb_t.pp("layers").pp(i);
423            uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
424            uvb_l
425                .pp("post_attention_layernorm")
426                .add(&layer.post_attention_layernorm);
427            if let Some(gate) = layer.gate_attn.clone() {
428                uvb_l.add_tensor("gate_attn", gate);
429            }
430            if let Some(gate) = layer.gate_ffn.clone() {
431                uvb_l.add_tensor("gate_ffn", gate);
432            }
433        }
434
435        uvb_t.to_safetensors()
436    }
437}
438
439fn _prepare_aspect_ratio_attention_mask(
440    aspect_ratio_mask: &Tensor,
441    num_patches: usize,
442    target_length: usize,
443    dtype: DType,
444    _num_attn_heads: usize,
445) -> Result<Tensor> {
446    let (bs, max_num_tiles) = aspect_ratio_mask.dims2()?;
447    let mut attention_mask = aspect_ratio_mask
448        .reshape((bs, max_num_tiles, 1, 1))?
449        .repeat((1, 1, target_length, 1))?;
450
451    // Mask padding patches
452    let pad_patches = target_length - num_patches;
453    let (bs, d1, d2, d3) = attention_mask.dims4()?;
454    attention_mask = attention_mask.slice_assign(
455        &[&.., &.., &(d2 - pad_patches..), &..],
456        &Tensor::zeros(
457            (bs, d1, pad_patches, d3),
458            attention_mask.dtype(),
459            attention_mask.device(),
460        )?,
461    )?;
462
463    // Invert the mask
464    attention_mask = (1. - attention_mask.to_dtype(DType::F32)?.to_dtype(dtype)?)?;
465
466    // Reshape to 2d and create 4d attn mask
467    // (batch_size, 1, max_num_tiles * target_length, max_num_tiles * target_length)
468    let neg_inf_value = dtype.finfo()?.min;
469    attention_mask = attention_mask.reshape((bs, max_num_tiles * target_length, 1))?;
470    attention_mask.matmul(
471        &attention_mask
472            .transpose(D::Minus1, D::Minus2)?
473            .mul(neg_inf_value)?,
474    )
475}
476
477pub(super) struct MLlamaVisionModel {
478    patch_embedding: Conv2d,
479    class_embedding: Tensor,
480    gated_positional_embedding: MLlamaPrecomputedPositionEmbedding,
481    pre_tile_positional_embedding: MLlamaPrecomputedAspectRatioEmbedding,
482    post_tile_positional_embedding: MLlamaPrecomputedAspectRatioEmbedding,
483    layernorm_pre: LayerNorm,
484    layernorm_post: LayerNorm,
485    transformer: MLlamaVisionEncoder,
486    global_transformer: MLlamaVisionEncoder,
487    pub(super) num_patches: usize,
488    intermediate_layers_indices: Vec<usize>,
489    num_attn_heads: usize,
490}
491
492impl MLlamaVisionModel {
493    pub(super) fn new(
494        cfg: &MLlamaVisionConfig,
495        vb: ShardedVarBuilder,
496        real_dev: &Device,
497        comm: &Arc<mistralrs_quant::Comm>,
498    ) -> Result<Self> {
499        let patch_embedding = conv2d_no_bias(
500            cfg.num_channels,
501            cfg.hidden_size,
502            cfg.patch_size,
503            Conv2dConfig {
504                stride: cfg.patch_size,
505                ..Default::default()
506            },
507            vb.pp("patch_embedding").set_device(real_dev.clone()),
508        )?;
509
510        let class_embedding = vb
511            .get((cfg.hidden_size,), "class_embedding")?
512            .to_device(real_dev)?;
513        let gated_positional_embedding = MLlamaPrecomputedPositionEmbedding::new(
514            cfg,
515            vb.pp("gated_positional_embedding")
516                .set_device(real_dev.clone()),
517        )?;
518
519        let pre_tile_positional_embedding = MLlamaPrecomputedAspectRatioEmbedding::new::<true>(
520            cfg,
521            vb.pp("pre_tile_positional_embedding")
522                .set_device(real_dev.clone()),
523        )?;
524        let post_tile_positional_embedding = MLlamaPrecomputedAspectRatioEmbedding::new::<true>(
525            cfg,
526            vb.pp("post_tile_positional_embedding")
527                .set_device(real_dev.clone()),
528        )?;
529
530        // layer norms
531        let layernorm_pre = layer_norm(
532            cfg.hidden_size,
533            LayerNormConfig::default(),
534            vb.pp("layernorm_pre").set_device(real_dev.clone()),
535        )?;
536        let layernorm_post = layer_norm(
537            cfg.hidden_size,
538            LayerNormConfig::default(),
539            vb.pp("layernorm_post").set_device(real_dev.clone()),
540        )?;
541
542        // encoders
543        let transformer = MLlamaVisionEncoder::new::<false>(
544            cfg,
545            cfg.num_hidden_layers,
546            vb.pp("transformer"),
547            real_dev,
548            comm,
549        )?;
550        let global_transformer = MLlamaVisionEncoder::new::<true>(
551            cfg,
552            cfg.num_global_layers,
553            vb.pp("global_transformer"),
554            real_dev,
555            comm,
556        )?;
557
558        Ok(Self {
559            patch_embedding,
560            class_embedding,
561            gated_positional_embedding,
562            pre_tile_positional_embedding,
563            post_tile_positional_embedding,
564            layernorm_post,
565            layernorm_pre,
566            transformer,
567            global_transformer,
568            num_patches: (cfg.image_size / cfg.patch_size).pow(2) + 1,
569            intermediate_layers_indices: cfg.intermediate_layers_indices.clone(),
570            num_attn_heads: cfg.num_attention_heads,
571        })
572    }
573
574    // https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/models/mllama/modeling_mllama.py#L1425
575    pub(super) fn forward(
576        &self,
577        pixel_values: &Tensor,
578        aspect_ratio_ids: &Tensor,
579        aspect_ratio_mask: &Tensor,
580    ) -> Result<Tensor> {
581        let pixel_values = pixel_values.to_dtype(self.class_embedding.dtype())?;
582
583        let bs = pixel_values.dim(0)?;
584        let num_concurrent_media = pixel_values.dim(1)?;
585        let num_tiles = pixel_values.dim(2)?;
586        let num_channels = pixel_values.dim(3)?;
587        let height = pixel_values.dim(4)?;
588        let width = pixel_values.dim(5)?;
589
590        let pixel_values = pixel_values.reshape((
591            bs * num_concurrent_media * num_tiles,
592            num_channels,
593            height,
594            width,
595        ))?;
596        let aspect_ratio_ids = aspect_ratio_ids.reshape((bs * num_concurrent_media, ()))?;
597
598        // Patch embedding
599        let patch_embeds = self.patch_embedding.forward(&pixel_values)?;
600        let mut hidden_state = patch_embeds.flatten_from(2)?.transpose(1, 2)?;
601
602        // Tile embeddings
603        let (_, mut num_patches, dim) = hidden_state.dims3()?;
604        hidden_state = hidden_state.reshape((bs * num_concurrent_media, num_tiles, (), dim))?;
605        hidden_state = self
606            .pre_tile_positional_embedding
607            .forward(&hidden_state, &aspect_ratio_ids)?;
608
609        // Add cls token
610        hidden_state =
611            hidden_state.reshape((bs * num_concurrent_media * num_tiles, num_patches, dim))?;
612        hidden_state = self.apply_class_embedding(&hidden_state)?;
613        num_patches += 1;
614
615        // Position embeddings
616        hidden_state =
617            hidden_state.reshape((bs * num_concurrent_media, num_tiles, num_patches, dim))?;
618        hidden_state = self
619            .gated_positional_embedding
620            .forward(&hidden_state, &aspect_ratio_ids)?;
621
622        hidden_state = self.layernorm_pre.forward(&hidden_state)?;
623
624        // Compute the number of tokens to pad
625        let num_padding_patches = (8 - (hidden_state.dim(D::Minus2)? as isize % 8)) % 8;
626        // Compute padding tuple for pad function
627        // (pad_left, pad_right, pad_left for dim -2, pad_right for dim -2)
628        let _padding = (0usize, 0usize, 0usize, num_padding_patches);
629        if num_padding_patches >= 0 {
630            hidden_state =
631                hidden_state.pad_with_zeros(D::Minus2, 0, num_padding_patches as usize)?;
632        } else {
633            hidden_state = hidden_state.narrow(
634                D::Minus2,
635                0,
636                (hidden_state.dim(2)? as isize + num_padding_patches) as usize,
637            )?;
638        }
639
640        // Prepare attention mask
641        let mut attention_mask = aspect_ratio_mask.reshape((bs * num_concurrent_media, ()))?;
642        attention_mask = _prepare_aspect_ratio_attention_mask(
643            &attention_mask,
644            self.num_patches,
645            hidden_state.dim(2)?,
646            hidden_state.dtype(),
647            self.num_attn_heads,
648        )?;
649        if attention_mask.dim(0)? != 1 {
650            attention_mask = attention_mask.unsqueeze(1)?;
651        }
652
653        // Apply encoder
654        hidden_state = hidden_state.reshape((bs * num_concurrent_media, (), dim))?;
655        let (mut hidden_state, all_intermediate_hidden_states) =
656            self.transformer.forward_with_states(
657                &hidden_state,
658                Some(&attention_mask),
659                Some(&self.intermediate_layers_indices),
660            )?;
661
662        // Collect intermediate layer outputs from encoder output
663        let mut intermediate_hidden_states =
664            Tensor::stack(&all_intermediate_hidden_states, D::Minus1)?;
665        drop(all_intermediate_hidden_states);
666
667        hidden_state = self.layernorm_post.forward(&hidden_state)?;
668
669        // Apply global encoder
670        hidden_state = hidden_state.reshape((
671            bs * num_concurrent_media,
672            num_tiles,
673            (num_patches as isize + num_padding_patches) as usize,
674            dim,
675        ))?;
676        hidden_state = self
677            .post_tile_positional_embedding
678            .forward(&hidden_state, &aspect_ratio_ids)?;
679        hidden_state = hidden_state.reshape((
680            bs * num_concurrent_media,
681            num_tiles * (num_patches as isize + num_padding_patches) as usize,
682            dim,
683        ))?;
684        (hidden_state, _) = self.global_transformer.forward_with_states(
685            &hidden_state,
686            Some(&attention_mask),
687            None,
688        )?;
689
690        // Remove padding from hidden state
691        hidden_state = hidden_state.reshape((
692            bs * num_concurrent_media,
693            num_tiles,
694            (num_patches as isize + num_padding_patches) as usize,
695            dim,
696        ))?;
697        hidden_state = hidden_state.narrow(
698            2,
699            0,
700            (hidden_state.dims()[2] as isize - num_padding_patches) as usize,
701        )?;
702        hidden_state =
703            hidden_state.reshape((bs, num_concurrent_media, num_tiles, num_patches, dim))?;
704
705        // Remove padding from intermediate hidden states
706        intermediate_hidden_states = intermediate_hidden_states.reshape((
707            bs * num_concurrent_media,
708            num_tiles,
709            (num_patches as isize + num_padding_patches) as usize,
710            (),
711        ))?;
712        intermediate_hidden_states = intermediate_hidden_states.narrow(
713            2,
714            0,
715            (intermediate_hidden_states.dims()[2] as isize - num_padding_patches) as usize,
716        )?;
717        intermediate_hidden_states = intermediate_hidden_states.reshape((
718            bs,
719            num_concurrent_media,
720            num_tiles,
721            num_patches,
722            (),
723        ))?;
724
725        // Concatenate final hidden state and intermediate hidden states
726        Tensor::cat(&[hidden_state, intermediate_hidden_states], D::Minus1)
727    }
728
729    fn apply_class_embedding(&self, hidden_state: &Tensor) -> Result<Tensor> {
730        let (bs, _, hidden_size) = hidden_state.dims3()?;
731        let class_embedding = self.class_embedding.expand((bs, 1, hidden_size))?;
732        Tensor::cat(&[class_embedding, hidden_state.clone()], 1)
733    }
734
735    pub fn get_isq_layers(&mut self) -> Vec<&mut std::sync::Arc<dyn mistralrs_quant::QuantMethod>> {
736        let mut layers = Vec::new();
737        for layer in &mut self.global_transformer.layers {
738            layers.push(&mut layer.self_attn.q_proj);
739            layers.push(&mut layer.self_attn.k_proj);
740            layers.push(&mut layer.self_attn.v_proj);
741            layers.push(&mut layer.self_attn.o_proj);
742
743            layers.push(&mut layer.mlp.fc1);
744            layers.push(&mut layer.mlp.fc2);
745        }
746        for layer in &mut self.transformer.layers {
747            layers.push(&mut layer.self_attn.q_proj);
748            layers.push(&mut layer.self_attn.k_proj);
749            layers.push(&mut layer.self_attn.v_proj);
750            layers.push(&mut layer.self_attn.o_proj);
751
752            layers.push(&mut layer.mlp.fc1);
753            layers.push(&mut layer.mlp.fc2);
754        }
755        layers
756    }
757}
758
759impl IsqModel for MLlamaVisionModel {
760    fn get_layers(
761        &mut self,
762    ) -> (
763        Vec<(
764            &mut std::sync::Arc<dyn mistralrs_quant::QuantMethod>,
765            Option<usize>,
766        )>,
767        &dyn crate::device_map::DeviceMapper,
768    ) {
769        unreachable!("MLlamaVision model cannot be quantized.");
770    }
771    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
772        let uvb = UnVarBuilder::new();
773
774        uvb.pp("patch_embedding").add(&self.patch_embedding);
775        uvb.add_tensor("class_embedding", self.class_embedding.clone());
776
777        // gated_positional_embedding
778        uvb.pp("gated_positional_embedding")
779            .extend(self.gated_positional_embedding.residual_tensors());
780
781        // pre_tile_positional_embedding
782        uvb.pp("pre_tile_positional_embedding")
783            .extend(self.pre_tile_positional_embedding.residual_tensors());
784
785        // post_tile_positional_embedding
786        uvb.pp("post_tile_positional_embedding")
787            .extend(self.post_tile_positional_embedding.residual_tensors());
788
789        uvb.pp("layernorm_pre").add(&self.layernorm_pre);
790        uvb.pp("layernorm_post").add(&self.layernorm_post);
791
792        // transformer
793        uvb.pp("transformer")
794            .extend(self.transformer.residual_tensors());
795
796        // global_transformer
797        uvb.pp("global_transformer")
798            .extend(self.global_transformer.residual_tensors());
799
800        uvb.to_safetensors()
801    }
802}