mistralrs_core/vision_models/mllama/vision.rs

1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use std::{ops::Mul, sync::Arc};
4
5use candle_core::{DType, Device, Result, Tensor, D};
6use candle_nn::{Conv2d, Conv2dConfig, Embedding, LayerNorm, LayerNormConfig, Module};
7use mistralrs_quant::{ColumnParallelLayer, QuantMethod, RowParallelLayer, ShardedVarBuilder};
8
9use crate::{
10    attention::SdpaParams,
11    layers::{conv2d_no_bias, embedding, layer_norm, GetFloatInfo, Sdpa},
12    pipeline::IsqModel,
13    utils::unvarbuilder::UnVarBuilder,
14};
15
16use super::{MLlamaVisionConfig, VisionActivation};
17
/// Gated positional embedding for the vision tower: a learned per-patch
/// embedding blended with a per-(aspect-ratio, tile) embedding via a single
/// learned scalar gate (`tanh(gate)` vs `1 - tanh(gate)`, see `forward`).
struct MLlamaPrecomputedPositionEmbedding {
    // Scalar gate of shape (1,); tanh-ed before use.
    gate: Tensor,
    // Per-patch positional embedding, shape (num_patches, hidden_size).
    embedding: Tensor,
    // Lookup table keyed by aspect-ratio id; each row reshapes to
    // (max_num_tiles, num_patches, hidden_size).
    tile_embedding: Embedding,
    // (image_size / patch_size)^2 + 1 — the +1 is the class token slot.
    num_patches: usize,
    hidden_size: usize,
    max_num_tiles: usize,
}
26
impl MLlamaPrecomputedPositionEmbedding {
    /// Load `gate`, `embedding` and the flattened `tile_embedding` table from `vb`.
    fn new(cfg: &MLlamaVisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        // +1 accounts for the prepended class token.
        let num_patches = (cfg.image_size / cfg.patch_size).pow(2) + 1;
        Ok(Self {
            gate: vb.get((1,), "gate")?,
            embedding: vb.get((num_patches, cfg.hidden_size), "embedding")?,
            // One flattened row per aspect-ratio id; unpacked in `forward`.
            tile_embedding: embedding(
                cfg.max_aspect_ratio_id() + 1,
                cfg.max_num_tiles * num_patches * cfg.hidden_size,
                vb.pp("tile_embedding"),
            )?,
            num_patches,
            hidden_size: cfg.hidden_size,
            max_num_tiles: cfg.max_num_tiles,
        })
    }

    // https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/models/mllama/modeling_mllama.py#L197
    /// Add gated positional information to `hidden_state`:
    /// first `(1 - tanh(gate)) * embedding` (broadcast over batch and tiles),
    /// then `tanh(gate) * tile_embedding[aspect_ratio_ids]` reshaped to
    /// `(bs, max_num_tiles, num_patches, hidden_size)`.
    fn forward(&self, hidden_state: &Tensor, aspect_ratio_ids: &Tensor) -> Result<Tensor> {
        // position embeddings
        let mut gated_pos_embed = (1. - &self.gate.tanh()?)?.broadcast_mul(&self.embedding)?;
        let hidden_state = hidden_state.broadcast_add(&gated_pos_embed.reshape((
            1,
            1,
            self.num_patches,
            self.hidden_size,
        ))?)?;

        // precomputed tile position embeddings
        let mut tile_position_embedding = self.tile_embedding.forward(aspect_ratio_ids)?;
        let bs = hidden_state.dim(0)?;
        // Unpack each flattened embedding row into (tiles, patches, hidden).
        tile_position_embedding = tile_position_embedding.reshape((
            bs,
            self.max_num_tiles,
            self.num_patches,
            self.hidden_size,
        ))?;
        gated_pos_embed = self.gate.tanh()?.broadcast_mul(&tile_position_embedding)?;

        hidden_state.broadcast_add(&gated_pos_embed)
    }

    /// Tensors excluded from ISQ quantization, keyed for safetensors export.
    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb_gpe = UnVarBuilder::new();

        uvb_gpe.add_tensor("gate", self.gate.clone());
        uvb_gpe.add_tensor("embedding", self.embedding.clone());
        uvb_gpe.pp("tile_embedding").add(&self.tile_embedding);

        uvb_gpe.to_safetensors()
    }
}
79
/// Per-tile embedding looked up by aspect-ratio id, optionally tanh-gated.
/// Used both before (`pre_tile_positional_embedding`) and after
/// (`post_tile_positional_embedding`) the local transformer.
struct MLlamaPrecomputedAspectRatioEmbedding {
    // Lookup table: one flattened (max_num_tiles * hidden_size) row per id.
    embedding: Embedding,
    // Present only when constructed with GATED = true; shape (1,).
    gate: Option<Tensor>,
    max_num_tiles: usize,
    hidden_size: usize,
}
86
87impl MLlamaPrecomputedAspectRatioEmbedding {
88    fn new<const GATED: bool>(cfg: &MLlamaVisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
89        Ok(Self {
90            embedding: embedding(
91                cfg.max_aspect_ratio_id() + 1,
92                cfg.max_num_tiles * cfg.hidden_size,
93                vb.pp("embedding"),
94            )?,
95            gate: if GATED {
96                Some(vb.get((1,), "gate")?)
97            } else {
98                None
99            },
100            max_num_tiles: cfg.max_num_tiles,
101            hidden_size: cfg.hidden_size,
102        })
103    }
104
105    fn forward(&self, hidden_state: &Tensor, aspect_ratio_ids: &Tensor) -> Result<Tensor> {
106        let mut embeddings = self.embedding.forward(aspect_ratio_ids)?;
107        embeddings = embeddings.reshape(((), self.max_num_tiles, 1, self.hidden_size))?;
108
109        if let Some(gate) = &self.gate {
110            embeddings = embeddings.broadcast_mul(&gate.tanh()?)?;
111        }
112
113        hidden_state.broadcast_add(&embeddings)
114    }
115
116    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
117        let uvb_ptpe = UnVarBuilder::new();
118
119        if let Some(gate) = self.gate.clone() {
120            uvb_ptpe.add_tensor("gate", gate);
121        }
122        uvb_ptpe.pp("embedding").add(&self.embedding);
123
124        uvb_ptpe.to_safetensors()
125    }
126}
127
/// Multi-head self-attention for the vision encoder (no KV cache, no RoPE).
struct MLlamaVisionAttention {
    q_proj: Arc<dyn QuantMethod>,
    k_proj: Arc<dyn QuantMethod>,
    v_proj: Arc<dyn QuantMethod>,
    o_proj: Arc<dyn QuantMethod>,
    // Scale = 1/sqrt(head_dim), no KV grouping, no sliding window.
    sdpa_params: SdpaParams,
    num_heads: usize,
    head_dim: usize,
}
137
impl MLlamaVisionAttention {
    /// Build the q/k/v/o projections (bias-free, tensor-parallel over `comm`)
    /// and the SDPA parameters.
    fn new(
        cfg: &MLlamaVisionConfig,
        vb: ShardedVarBuilder,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let head_dim = cfg.hidden_size / cfg.num_attention_heads;
        Ok(Self {
            q_proj: ColumnParallelLayer::new(
                cfg.hidden_size,
                cfg.num_attention_heads * head_dim,
                &None,
                false,
                comm,
                vb.pp("q_proj"),
            )?,
            k_proj: ColumnParallelLayer::new(
                cfg.hidden_size,
                cfg.num_attention_heads * head_dim,
                &None,
                false,
                comm,
                vb.pp("k_proj"),
            )?,
            v_proj: ColumnParallelLayer::new(
                cfg.hidden_size,
                cfg.num_attention_heads * head_dim,
                &None,
                false,
                comm,
                vb.pp("v_proj"),
            )?,
            // NOTE(review): the argument order here looks like
            // (hidden_size, num_heads * head_dim) whereas an o_proj usually maps
            // num_heads * head_dim -> hidden_size. It is benign only because the
            // two values are equal — confirm against RowParallelLayer::new's
            // (in_dim, out_dim) parameter order.
            o_proj: RowParallelLayer::new(
                cfg.hidden_size,
                cfg.num_attention_heads * head_dim,
                &None,
                false,
                comm,
                vb.pp("o_proj"),
            )?,
            sdpa_params: SdpaParams {
                n_kv_groups: 1,
                use_flash_attn: false,
                softcap: None,
                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                sliding_window: None,
            },
            num_heads: cfg.num_attention_heads,
            head_dim,
        })
    }

    // https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/models/mllama/modeling_mllama.py#L243
    /// Standard SDPA over `hidden_state` (shape `(bs, seq, hidden)`), with an
    /// optional additive `attention_mask`. When the projections are quantized,
    /// activations are cast to the quantized dtype for the matmuls and restored
    /// to the caller's dtype afterwards.
    fn forward(&self, hidden_state: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
        let mut hidden_state = hidden_state.clone();
        let original_dtype = hidden_state.dtype();
        if let Some(t) = self.q_proj.quantized_act_type() {
            hidden_state = hidden_state.to_dtype(t)?;
        }
        let mut q = self.q_proj.forward(&hidden_state)?;
        let mut k = self.k_proj.forward(&hidden_state)?;
        let mut v = self.v_proj.forward(&hidden_state)?;
        if self.q_proj.quantized_act_type().is_some() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

        // Should be same, no caching...
        let (bs, q_sq, _) = q.dims3()?;
        let (_, k_sq, _) = k.dims3()?;

        // Split heads: (bs, seq, hidden) -> (bs, heads, seq, head_dim).
        q = q
            .reshape((bs, q_sq, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;
        k = k
            .reshape((bs, k_sq, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;
        v = v
            .reshape((bs, k_sq, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;

        // Merge heads back: (bs, heads, seq, head_dim) -> (bs, seq, hidden).
        let mut attn_output = Sdpa
            .run_attention(
                &q.contiguous()?,
                &k.contiguous()?,
                &v.contiguous()?,
                attention_mask,
                None,
                &self.sdpa_params,
            )?
            .transpose(1, 2)?
            .contiguous()?
            .reshape((bs, q_sq, ()))?
            .to_dtype(q.dtype())?;

        if let Some(t) = self.q_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        let mut res = self.o_proj.forward(&attn_output)?;
        if self.q_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}
244
/// Two-layer feed-forward block: `fc2(act(fc1(x)))`.
struct MLlamaMlp {
    act: VisionActivation,
    // hidden_size -> intermediate_size (column-parallel, with bias).
    fc1: Arc<dyn QuantMethod>,
    // intermediate_size -> hidden_size (row-parallel, with bias).
    fc2: Arc<dyn QuantMethod>,
}
250
251impl MLlamaMlp {
252    fn new(
253        cfg: &MLlamaVisionConfig,
254        vb: ShardedVarBuilder,
255        comm: &Arc<mistralrs_quant::Comm>,
256    ) -> Result<Self> {
257        Ok(Self {
258            act: cfg.hidden_act,
259            fc1: ColumnParallelLayer::new(
260                cfg.hidden_size,
261                cfg.intermediate_size,
262                &None,
263                true,
264                comm,
265                vb.pp("fc1"),
266            )?,
267            fc2: RowParallelLayer::new(
268                cfg.intermediate_size,
269                cfg.hidden_size,
270                &None,
271                true,
272                comm,
273                vb.pp("fc2"),
274            )?,
275        })
276    }
277
278    // https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/models/mllama/modeling_mllama.py#L223
279    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
280        let original_dtype = hidden_states.dtype();
281        let mut hidden_states = hidden_states.clone();
282        if let Some(t) = self.fc1.quantized_act_type() {
283            hidden_states = hidden_states.to_dtype(t)?;
284        }
285        hidden_states = self
286            .fc2
287            .forward(&self.act.forward(&self.fc1.forward(&hidden_states)?)?)?;
288        if self.fc1.quantized_act_type().is_some() {
289            hidden_states = hidden_states.to_dtype(original_dtype)?;
290        }
291        Ok(hidden_states)
292    }
293}
294
/// One pre-norm transformer layer: attention + MLP, each with a residual
/// connection that is tanh-gated when the layer belongs to the gated
/// (global) encoder.
struct MLlamaVisionEncoderLayer {
    self_attn: MLlamaVisionAttention,
    mlp: MLlamaMlp,
    input_layernorm: LayerNorm,
    post_attention_layernorm: LayerNorm,
    // Present only for GATED layers; shape (1,), tanh-ed before use.
    gate_attn: Option<Tensor>,
    gate_ffn: Option<Tensor>,
}
303
304impl MLlamaVisionEncoderLayer {
305    fn new<const GATED: bool>(
306        cfg: &MLlamaVisionConfig,
307        vb: ShardedVarBuilder,
308        real_dev: &Device,
309        comm: &Arc<mistralrs_quant::Comm>,
310    ) -> Result<Self> {
311        let self_attn = MLlamaVisionAttention::new(cfg, vb.pp("self_attn"), comm)?;
312        let mlp = MLlamaMlp::new(cfg, vb.pp("mlp"), comm)?;
313
314        let input_layernorm = layer_norm(
315            cfg.hidden_size,
316            cfg.norm_eps,
317            vb.pp("input_layernorm").set_device(real_dev.clone()),
318        )?;
319        let post_attention_layernorm = layer_norm(
320            cfg.hidden_size,
321            cfg.norm_eps,
322            vb.pp("post_attention_layernorm")
323                .set_device(real_dev.clone()),
324        )?;
325
326        if GATED {
327            Ok(Self {
328                self_attn,
329                mlp,
330                input_layernorm,
331                post_attention_layernorm,
332                gate_attn: Some(vb.get((1,), "gate_attn")?.to_device(real_dev)?),
333                gate_ffn: Some(vb.get((1,), "gate_ffn")?.to_device(real_dev)?),
334            })
335        } else {
336            Ok(Self {
337                self_attn,
338                mlp,
339                input_layernorm,
340                post_attention_layernorm,
341                gate_attn: None,
342                gate_ffn: None,
343            })
344        }
345    }
346
347    // https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/models/mllama/modeling_mllama.py#L348
348    fn forward(&self, hidden_state: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
349        // Self attn
350        let residual = hidden_state;
351        let mut hidden_state = self.input_layernorm.forward(hidden_state)?;
352
353        hidden_state = self.self_attn.forward(&hidden_state, attention_mask)?;
354
355        if let Some(gate) = &self.gate_attn {
356            hidden_state = gate.broadcast_mul(&hidden_state.tanh()?)?;
357        }
358        hidden_state = (residual + hidden_state)?;
359
360        // FF
361        let residual = hidden_state.clone();
362        hidden_state = self.post_attention_layernorm.forward(&hidden_state)?;
363
364        hidden_state = self.mlp.forward(&hidden_state)?;
365
366        if let Some(gate) = &self.gate_ffn {
367            hidden_state = gate.broadcast_mul(&hidden_state.tanh()?)?;
368        }
369        residual + hidden_state
370    }
371}
372
/// A stack of [`MLlamaVisionEncoderLayer`]s (used for both the local
/// `transformer` and the gated `global_transformer`).
struct MLlamaVisionEncoder {
    layers: Vec<MLlamaVisionEncoderLayer>,
}
376
377impl MLlamaVisionEncoder {
378    fn new<const GATED: bool>(
379        cfg: &MLlamaVisionConfig,
380        num_layers: usize,
381        vb: ShardedVarBuilder,
382        real_dev: &Device,
383        comm: &Arc<mistralrs_quant::Comm>,
384    ) -> Result<Self> {
385        let mut layers = Vec::with_capacity(num_layers);
386        let layers_vb = vb.pp("layers");
387        for i in 0..num_layers {
388            layers.push(MLlamaVisionEncoderLayer::new::<GATED>(
389                cfg,
390                layers_vb.pp(i),
391                real_dev,
392                comm,
393            )?);
394        }
395        Ok(Self { layers })
396    }
397
398    // https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/models/mllama/modeling_mllama.py#L394
399    /// Also (optionally) return hidden states at some indices
400    fn forward_with_states(
401        &self,
402        hidden_state: &Tensor,
403        attention_mask: Option<&Tensor>,
404        intermediate_layers_indices: Option<&[usize]>,
405    ) -> Result<(Tensor, Vec<Tensor>)> {
406        let mut hidden_state = hidden_state.clone();
407        let mut hidden_states = Vec::new();
408        for (i, layer) in self.layers.iter().enumerate() {
409            if intermediate_layers_indices.is_some_and(|indices: &[usize]| indices.contains(&i)) {
410                hidden_states.push(hidden_state.clone());
411            }
412            hidden_state = layer.forward(&hidden_state, attention_mask)?;
413        }
414        Ok((hidden_state, hidden_states))
415    }
416
417    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
418        let uvb_t = UnVarBuilder::new();
419
420        for (i, layer) in self.layers.iter().enumerate() {
421            let uvb_l = uvb_t.pp("layers").pp(i);
422            uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
423            uvb_l
424                .pp("post_attention_layernorm")
425                .add(&layer.post_attention_layernorm);
426            if let Some(gate) = layer.gate_attn.clone() {
427                uvb_l.add_tensor("gate_attn", gate);
428            }
429            if let Some(gate) = layer.gate_ffn.clone() {
430                uvb_l.add_tensor("gate_ffn", gate);
431            }
432        }
433
434        uvb_t.to_safetensors()
435    }
436}
437
/// Build the encoder attention mask from the per-tile `aspect_ratio_mask`
/// (shape `(bs, max_num_tiles)`, 1 = real tile, 0 = padding tile).
///
/// Each tile's flag is expanded to `target_length` token positions, the last
/// `target_length - num_patches` (padding) positions are zeroed, the mask is
/// inverted, and an outer product with a `-inf`-scaled copy of itself produces
/// a pairwise (query, key) additive mask.
fn _prepare_aspect_ratio_attention_mask(
    aspect_ratio_mask: &Tensor,
    num_patches: usize,
    target_length: usize,
    dtype: DType,
    _num_attn_heads: usize,
) -> Result<Tensor> {
    let (bs, max_num_tiles) = aspect_ratio_mask.dims2()?;
    // Expand each tile flag to one entry per token position.
    let mut attention_mask = aspect_ratio_mask
        .reshape((bs, max_num_tiles, 1, 1))?
        .repeat((1, 1, target_length, 1))?;

    // Mask padding patches
    // assumes target_length >= num_patches (the hidden state was padded up) —
    // otherwise this subtraction underflows; TODO confirm with the caller.
    let pad_patches = target_length - num_patches;
    let (bs, d1, d2, d3) = attention_mask.dims4()?;
    attention_mask = attention_mask.slice_assign(
        &[&.., &.., &(d2 - pad_patches..), &..],
        &Tensor::zeros(
            (bs, d1, pad_patches, d3),
            attention_mask.dtype(),
            attention_mask.device(),
        )?,
    )?;

    // Invert the mask
    // (1 -> keep becomes 0, 0 -> masked becomes 1); round-trip through F32
    // because the integer input cannot represent the subtraction directly.
    attention_mask = (1. - attention_mask.to_dtype(DType::F32)?.to_dtype(dtype)?)?;

    // Reshape to 2d and create 4d attn mask
    // (batch_size, 1, max_num_tiles * target_length, max_num_tiles * target_length)
    // NOTE(review): the matmul below actually yields a 3d (bs, L, L) tensor with
    // L = max_num_tiles * target_length; the head dim mentioned above is only
    // inserted by the caller (unsqueeze when bs != 1) — confirm this is intended.
    let neg_inf_value = dtype.finfo()?.min;
    attention_mask = attention_mask.reshape((bs, max_num_tiles * target_length, 1))?;
    // mask @ (mask^T * -inf): any (query, key) pair where either side is masked
    // receives a large negative additive bias.
    attention_mask.matmul(
        &attention_mask
            .transpose(D::Minus1, D::Minus2)?
            .mul(neg_inf_value)?,
    )
}
475
/// The mllama vision tower: conv patch embedding, tile/positional embeddings,
/// a local `transformer` stack and a gated `global_transformer` stack.
pub(super) struct MLlamaVisionModel {
    // Conv2d with stride == kernel == patch_size (non-overlapping patches).
    patch_embedding: Conv2d,
    // Learned class-token embedding, shape (hidden_size,).
    class_embedding: Tensor,
    gated_positional_embedding: MLlamaPrecomputedPositionEmbedding,
    pre_tile_positional_embedding: MLlamaPrecomputedAspectRatioEmbedding,
    post_tile_positional_embedding: MLlamaPrecomputedAspectRatioEmbedding,
    layernorm_pre: LayerNorm,
    layernorm_post: LayerNorm,
    // Local (ungated) encoder; intermediate states are captured from it.
    transformer: MLlamaVisionEncoder,
    // Gated encoder applied after the local one.
    global_transformer: MLlamaVisionEncoder,
    // (image_size / patch_size)^2 + 1 (class token included).
    pub(super) num_patches: usize,
    // Local-transformer layer indices whose inputs are returned alongside
    // the final hidden state.
    intermediate_layers_indices: Vec<usize>,
    num_attn_heads: usize,
}
490
impl MLlamaVisionModel {
    /// Load every sub-module from `vb`. Non-quantizable parts (conv, norms,
    /// embeddings, gates) are forced onto `real_dev`.
    pub(super) fn new(
        cfg: &MLlamaVisionConfig,
        vb: ShardedVarBuilder,
        real_dev: &Device,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        // Non-overlapping patchification: kernel == stride == patch_size.
        let patch_embedding = conv2d_no_bias(
            cfg.num_channels,
            cfg.hidden_size,
            cfg.patch_size,
            Conv2dConfig {
                stride: cfg.patch_size,
                ..Default::default()
            },
            vb.pp("patch_embedding").set_device(real_dev.clone()),
        )?;

        let class_embedding = vb
            .get((cfg.hidden_size,), "class_embedding")?
            .to_device(real_dev)?;
        let gated_positional_embedding = MLlamaPrecomputedPositionEmbedding::new(
            cfg,
            vb.pp("gated_positional_embedding")
                .set_device(real_dev.clone()),
        )?;

        // Both tile embeddings are gated (GATED = true).
        let pre_tile_positional_embedding = MLlamaPrecomputedAspectRatioEmbedding::new::<true>(
            cfg,
            vb.pp("pre_tile_positional_embedding")
                .set_device(real_dev.clone()),
        )?;
        let post_tile_positional_embedding = MLlamaPrecomputedAspectRatioEmbedding::new::<true>(
            cfg,
            vb.pp("post_tile_positional_embedding")
                .set_device(real_dev.clone()),
        )?;

        // layer norms
        let layernorm_pre = layer_norm(
            cfg.hidden_size,
            LayerNormConfig::default(),
            vb.pp("layernorm_pre").set_device(real_dev.clone()),
        )?;
        let layernorm_post = layer_norm(
            cfg.hidden_size,
            LayerNormConfig::default(),
            vb.pp("layernorm_post").set_device(real_dev.clone()),
        )?;

        // encoders
        // Local encoder: ungated; global encoder: gated residuals.
        let transformer = MLlamaVisionEncoder::new::<false>(
            cfg,
            cfg.num_hidden_layers,
            vb.pp("transformer"),
            real_dev,
            comm,
        )?;
        let global_transformer = MLlamaVisionEncoder::new::<true>(
            cfg,
            cfg.num_global_layers,
            vb.pp("global_transformer"),
            real_dev,
            comm,
        )?;

        Ok(Self {
            patch_embedding,
            class_embedding,
            gated_positional_embedding,
            pre_tile_positional_embedding,
            post_tile_positional_embedding,
            layernorm_post,
            layernorm_pre,
            transformer,
            global_transformer,
            // +1 for the class token prepended in apply_class_embedding.
            num_patches: (cfg.image_size / cfg.patch_size).pow(2) + 1,
            intermediate_layers_indices: cfg.intermediate_layers_indices.clone(),
            num_attn_heads: cfg.num_attention_heads,
        })
    }

    // https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/models/mllama/modeling_mllama.py#L1425
    /// Run the vision tower.
    ///
    /// `pixel_values` has 6 dims: `(bs, num_concurrent_media, num_tiles,
    /// num_channels, height, width)`. Returns the final global-encoder output
    /// concatenated on the last dim with the captured intermediate states,
    /// shaped `(bs, num_concurrent_media, num_tiles, num_patches, …)`.
    pub(super) fn forward(
        &self,
        pixel_values: &Tensor,
        aspect_ratio_ids: &Tensor,
        aspect_ratio_mask: &Tensor,
    ) -> Result<Tensor> {
        // Match the model parameter dtype before any math.
        let pixel_values = pixel_values.to_dtype(self.class_embedding.dtype())?;

        let bs = pixel_values.dim(0)?;
        let num_concurrent_media = pixel_values.dim(1)?;
        let num_tiles = pixel_values.dim(2)?;
        let num_channels = pixel_values.dim(3)?;
        let height = pixel_values.dim(4)?;
        let width = pixel_values.dim(5)?;

        // Fold (bs, media, tiles) into one batch dim for the conv.
        let pixel_values = pixel_values.reshape((
            bs * num_concurrent_media * num_tiles,
            num_channels,
            height,
            width,
        ))?;
        let aspect_ratio_ids = aspect_ratio_ids.reshape((bs * num_concurrent_media, ()))?;

        // Patch embedding
        let patch_embeds = self.patch_embedding.forward(&pixel_values)?;
        // (N, hidden, h', w') -> (N, h'*w', hidden)
        let mut hidden_state = patch_embeds.flatten_from(2)?.transpose(1, 2)?;

        // Tile embeddings
        let (_, mut num_patches, dim) = hidden_state.dims3()?;
        hidden_state = hidden_state.reshape((bs * num_concurrent_media, num_tiles, (), dim))?;
        hidden_state = self
            .pre_tile_positional_embedding
            .forward(&hidden_state, &aspect_ratio_ids)?;

        // Add cls token
        hidden_state =
            hidden_state.reshape((bs * num_concurrent_media * num_tiles, num_patches, dim))?;
        hidden_state = self.apply_class_embedding(&hidden_state)?;
        num_patches += 1;

        // Position embeddings
        hidden_state =
            hidden_state.reshape((bs * num_concurrent_media, num_tiles, num_patches, dim))?;
        hidden_state = self
            .gated_positional_embedding
            .forward(&hidden_state, &aspect_ratio_ids)?;

        hidden_state = self.layernorm_pre.forward(&hidden_state)?;

        // Compute the number of tokens to pad
        // Pads the patch dim up to a multiple of 8; the result is always in 0..=7.
        let num_padding_patches = (8 - (hidden_state.dim(D::Minus2)? as isize % 8)) % 8;
        // Compute padding tuple for pad function
        // (pad_left, pad_right, pad_left for dim -2, pad_right for dim -2)
        // NOTE(review): `_padding` is unused — kept only to mirror the upstream
        // Python; consider removing.
        let _padding = (0usize, 0usize, 0usize, num_padding_patches);
        // NOTE(review): `num_padding_patches` is always >= 0 (see above), so the
        // `else` (narrow) arm below is unreachable dead code.
        if num_padding_patches >= 0 {
            hidden_state =
                hidden_state.pad_with_zeros(D::Minus2, 0, num_padding_patches as usize)?;
        } else {
            hidden_state = hidden_state.narrow(
                D::Minus2,
                0,
                (hidden_state.dim(2)? as isize + num_padding_patches) as usize,
            )?;
        }

        // Prepare attention mask
        let mut attention_mask = aspect_ratio_mask.reshape((bs * num_concurrent_media, ()))?;
        attention_mask = _prepare_aspect_ratio_attention_mask(
            &attention_mask,
            self.num_patches,
            hidden_state.dim(2)?,
            hidden_state.dtype(),
            self.num_attn_heads,
        )?;
        // The mask comes back 3d (b, L, L). When b == 1 it broadcasts over the
        // head dim as-is; otherwise insert an explicit head dim -> (b, 1, L, L).
        if attention_mask.dim(0)? != 1 {
            attention_mask = attention_mask.unsqueeze(1)?;
        }

        // Apply encoder
        // Flatten tiles into the sequence dim for the local transformer.
        hidden_state = hidden_state.reshape((bs * num_concurrent_media, (), dim))?;
        let (mut hidden_state, all_intermediate_hidden_states) =
            self.transformer.forward_with_states(
                &hidden_state,
                Some(&attention_mask),
                Some(&self.intermediate_layers_indices),
            )?;

        // Collect intermediate layer outputs from encoder output
        // Stacked on a new trailing dim: one slice per captured layer.
        let mut intermediate_hidden_states =
            Tensor::stack(&all_intermediate_hidden_states, D::Minus1)?;
        drop(all_intermediate_hidden_states);

        hidden_state = self.layernorm_post.forward(&hidden_state)?;

        // Apply global encoder
        hidden_state = hidden_state.reshape((
            bs * num_concurrent_media,
            num_tiles,
            (num_patches as isize + num_padding_patches) as usize,
            dim,
        ))?;
        hidden_state = self
            .post_tile_positional_embedding
            .forward(&hidden_state, &aspect_ratio_ids)?;
        hidden_state = hidden_state.reshape((
            bs * num_concurrent_media,
            num_tiles * (num_patches as isize + num_padding_patches) as usize,
            dim,
        ))?;
        // Global encoder output only; no intermediate capture (None).
        (hidden_state, _) = self.global_transformer.forward_with_states(
            &hidden_state,
            Some(&attention_mask),
            None,
        )?;

        // Remove padding from hidden state
        hidden_state = hidden_state.reshape((
            bs * num_concurrent_media,
            num_tiles,
            (num_patches as isize + num_padding_patches) as usize,
            dim,
        ))?;
        hidden_state = hidden_state.narrow(
            2,
            0,
            (hidden_state.dims()[2] as isize - num_padding_patches) as usize,
        )?;
        hidden_state =
            hidden_state.reshape((bs, num_concurrent_media, num_tiles, num_patches, dim))?;

        // Remove padding from intermediate hidden states
        intermediate_hidden_states = intermediate_hidden_states.reshape((
            bs * num_concurrent_media,
            num_tiles,
            (num_patches as isize + num_padding_patches) as usize,
            (),
        ))?;
        intermediate_hidden_states = intermediate_hidden_states.narrow(
            2,
            0,
            (intermediate_hidden_states.dims()[2] as isize - num_padding_patches) as usize,
        )?;
        intermediate_hidden_states = intermediate_hidden_states.reshape((
            bs,
            num_concurrent_media,
            num_tiles,
            num_patches,
            (),
        ))?;

        // Concatenate final hidden state and intermediate hidden states
        Tensor::cat(&[hidden_state, intermediate_hidden_states], D::Minus1)
    }

    /// Prepend the learned class token to every sequence:
    /// (bs, seq, hidden) -> (bs, seq + 1, hidden).
    fn apply_class_embedding(&self, hidden_state: &Tensor) -> Result<Tensor> {
        let (bs, _, hidden_size) = hidden_state.dims3()?;
        let class_embedding = self.class_embedding.expand((bs, 1, hidden_size))?;
        Tensor::cat(&[class_embedding, hidden_state.clone()], 1)
    }

    /// Mutable handles to every attention/MLP projection in both encoders,
    /// for in-situ quantization (global transformer first).
    pub fn get_isq_layers(&mut self) -> Vec<&mut std::sync::Arc<dyn mistralrs_quant::QuantMethod>> {
        let mut layers = Vec::new();
        for layer in &mut self.global_transformer.layers {
            layers.push(&mut layer.self_attn.q_proj);
            layers.push(&mut layer.self_attn.k_proj);
            layers.push(&mut layer.self_attn.v_proj);
            layers.push(&mut layer.self_attn.o_proj);

            layers.push(&mut layer.mlp.fc1);
            layers.push(&mut layer.mlp.fc2);
        }
        for layer in &mut self.transformer.layers {
            layers.push(&mut layer.self_attn.q_proj);
            layers.push(&mut layer.self_attn.k_proj);
            layers.push(&mut layer.self_attn.v_proj);
            layers.push(&mut layer.self_attn.o_proj);

            layers.push(&mut layer.mlp.fc1);
            layers.push(&mut layer.mlp.fc2);
        }
        layers
    }
}
757
758impl IsqModel for MLlamaVisionModel {
759    fn get_layers(
760        &mut self,
761    ) -> (
762        Vec<(
763            &mut std::sync::Arc<dyn mistralrs_quant::QuantMethod>,
764            Option<usize>,
765        )>,
766        &dyn crate::device_map::DeviceMapper,
767    ) {
768        unreachable!("MLlamaVision model cannot be quantized.");
769    }
770    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
771        let uvb = UnVarBuilder::new();
772
773        uvb.pp("patch_embedding").add(&self.patch_embedding);
774        uvb.add_tensor("class_embedding", self.class_embedding.clone());
775
776        // gated_positional_embedding
777        uvb.pp("gated_positional_embedding")
778            .extend(self.gated_positional_embedding.residual_tensors());
779
780        // pre_tile_positional_embedding
781        uvb.pp("pre_tile_positional_embedding")
782            .extend(self.pre_tile_positional_embedding.residual_tensors());
783
784        // post_tile_positional_embedding
785        uvb.pp("post_tile_positional_embedding")
786            .extend(self.post_tile_positional_embedding.residual_tensors());
787
788        uvb.pp("layernorm_pre").add(&self.layernorm_pre);
789        uvb.pp("layernorm_post").add(&self.layernorm_post);
790
791        // transformer
792        uvb.pp("transformer")
793            .extend(self.transformer.residual_tensors());
794
795        // global_transformer
796        uvb.pp("global_transformer")
797            .extend(self.global_transformer.residual_tensors());
798
799        uvb.to_safetensors()
800    }
801}