mistralrs_core/vision_models/mllama/vision.rs
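// MLlama vision encoder. The implementation closely follows the Hugging Face
// `transformers` mllama modeling code; see the permalinks on individual methods.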

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::{ops::Mul, sync::Arc};

use candle_core::{DType, Device, Result, Tensor, D};
use candle_nn::{Conv2d, Conv2dConfig, Embedding, LayerNorm, LayerNormConfig, Module};
use mistralrs_quant::{ColumnParallelLayer, QuantMethod, RowParallelLayer, ShardedVarBuilder};

use crate::{
    attention::SdpaParams,
    layers::{conv2d_no_bias, embedding, layer_norm, GetFloatInfo, Sdpa},
    pipeline::IsqModel,
    utils::unvarbuilder::UnVarBuilder,
};

use super::{MLlamaVisionConfig, VisionActivation};

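/// Gated positional embedding: blends a learned per-patch embedding with a
/// precomputed per-tile embedding (selected by aspect-ratio id) via a single
/// learned tanh gate.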
struct MLlamaPrecomputedPositionEmbedding {
    gate: Tensor,
    embedding: Tensor,
    tile_embedding: Embedding,
    num_patches: usize,
    hidden_size: usize,
    max_num_tiles: usize,
}

impl MLlamaPrecomputedPositionEmbedding {
    fn new(cfg: &MLlamaVisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        let num_patches = (cfg.image_size / cfg.patch_size).pow(2) + 1;
        Ok(Self {
            gate: vb.get((1,), "gate")?,
            embedding: vb.get((num_patches, cfg.hidden_size), "embedding")?,
            tile_embedding: embedding(
                cfg.max_aspect_ratio_id() + 1,
                cfg.max_num_tiles * num_patches * cfg.hidden_size,
                vb.pp("tile_embedding"),
                &None,
            )?,
            num_patches,
            hidden_size: cfg.hidden_size,
            max_num_tiles: cfg.max_num_tiles,
        })
    }

    // https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/models/mllama/modeling_mllama.py#L197
    fn forward(&self, hidden_state: &Tensor, aspect_ratio_ids: &Tensor) -> Result<Tensor> {
        // position embeddings
        let mut gated_pos_embed = (1. - &self.gate.tanh()?)?.broadcast_mul(&self.embedding)?;
        let hidden_state = hidden_state.broadcast_add(&gated_pos_embed.reshape((
            1,
            1,
            self.num_patches,
            self.hidden_size,
        ))?)?;

        // precomputed tile position embeddings
        let mut tile_position_embedding = self.tile_embedding.forward(aspect_ratio_ids)?;
        let bs = hidden_state.dim(0)?;
        tile_position_embedding = tile_position_embedding.reshape((
            bs,
            self.max_num_tiles,
            self.num_patches,
            self.hidden_size,
        ))?;
        gated_pos_embed = self.gate.tanh()?.broadcast_mul(&tile_position_embedding)?;

        hidden_state.broadcast_add(&gated_pos_embed)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb_gpe = UnVarBuilder::new();

        uvb_gpe.add_tensor("gate", self.gate.clone());
        uvb_gpe.add_tensor("embedding", self.embedding.clone());
        uvb_gpe.pp("tile_embedding").add(&self.tile_embedding);

        uvb_gpe.to_safetensors()
    }
}

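/// Per-tile embedding selected by aspect-ratio id, optionally scaled by a
/// learned tanh gate (`GATED`), and broadcast-added to the hidden state.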
struct MLlamaPrecomputedAspectRatioEmbedding {
    embedding: Embedding,
    gate: Option<Tensor>,
    max_num_tiles: usize,
    hidden_size: usize,
}

impl MLlamaPrecomputedAspectRatioEmbedding {
    fn new<const GATED: bool>(cfg: &MLlamaVisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        Ok(Self {
            embedding: embedding(
                cfg.max_aspect_ratio_id() + 1,
                cfg.max_num_tiles * cfg.hidden_size,
                vb.pp("embedding"),
                &None,
            )?,
            gate: if GATED {
                Some(vb.get((1,), "gate")?)
            } else {
                None
            },
            max_num_tiles: cfg.max_num_tiles,
            hidden_size: cfg.hidden_size,
        })
    }

    fn forward(&self, hidden_state: &Tensor, aspect_ratio_ids: &Tensor) -> Result<Tensor> {
        let mut embeddings = self.embedding.forward(aspect_ratio_ids)?;
        embeddings = embeddings.reshape(((), self.max_num_tiles, 1, self.hidden_size))?;

        if let Some(gate) = &self.gate {
            embeddings = embeddings.broadcast_mul(&gate.tanh()?)?;
        }

        hidden_state.broadcast_add(&embeddings)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb_ptpe = UnVarBuilder::new();

        if let Some(gate) = self.gate.clone() {
            uvb_ptpe.add_tensor("gate", gate);
        }
        uvb_ptpe.pp("embedding").add(&self.embedding);

        uvb_ptpe.to_safetensors()
    }
}

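/// Multi-head self-attention over patch tokens. There is no KV cache; Q/K/V use
/// column-parallel projections and the output uses a row-parallel projection.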
struct MLlamaVisionAttention {
    q_proj: Arc<dyn QuantMethod>,
    k_proj: Arc<dyn QuantMethod>,
    v_proj: Arc<dyn QuantMethod>,
    o_proj: Arc<dyn QuantMethod>,
    sdpa_params: SdpaParams,
    num_heads: usize,
    head_dim: usize,
}

impl MLlamaVisionAttention {
    fn new(
        cfg: &MLlamaVisionConfig,
        vb: ShardedVarBuilder,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let head_dim = cfg.hidden_size / cfg.num_attention_heads;
        Ok(Self {
            q_proj: ColumnParallelLayer::new(
                cfg.hidden_size,
                cfg.num_attention_heads * head_dim,
                &None,
                false,
                comm,
                vb.pp("q_proj"),
            )?,
            k_proj: ColumnParallelLayer::new(
                cfg.hidden_size,
                cfg.num_attention_heads * head_dim,
                &None,
                false,
                comm,
                vb.pp("k_proj"),
            )?,
            v_proj: ColumnParallelLayer::new(
                cfg.hidden_size,
                cfg.num_attention_heads * head_dim,
                &None,
                false,
                comm,
                vb.pp("v_proj"),
            )?,
            o_proj: RowParallelLayer::new(
                cfg.hidden_size,
                cfg.num_attention_heads * head_dim,
                &None,
                false,
                comm,
                vb.pp("o_proj"),
            )?,
            sdpa_params: SdpaParams {
                n_kv_groups: 1,
                use_flash_attn: false,
                softcap: None,
                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                sliding_window: None,
            },
            num_heads: cfg.num_attention_heads,
            head_dim,
        })
    }

    // https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/models/mllama/modeling_mllama.py#L243
    fn forward(&self, hidden_state: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
        let mut hidden_state = hidden_state.clone();
        let original_dtype = hidden_state.dtype();
        if let Some(t) = self.q_proj.quantized_act_type() {
            hidden_state = hidden_state.to_dtype(t)?;
        }
        let mut q = self.q_proj.forward(&hidden_state)?;
        let mut k = self.k_proj.forward(&hidden_state)?;
        let mut v = self.v_proj.forward(&hidden_state)?;
        if self.q_proj.quantized_act_type().is_some() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

        // Query and key sequence lengths are the same here; there is no KV caching.
        let (bs, q_sq, _) = q.dims3()?;
        let (_, k_sq, _) = k.dims3()?;

        q = q
            .reshape((bs, q_sq, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;
        k = k
            .reshape((bs, k_sq, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;
        v = v
            .reshape((bs, k_sq, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;

        let mut attn_output = Sdpa
            .run_attention(
                &q.contiguous()?,
                &k.contiguous()?,
                &v.contiguous()?,
                attention_mask,
                None,
                &self.sdpa_params,
            )?
            .transpose(1, 2)?
            .contiguous()?
            .reshape((bs, q_sq, ()))?
            .to_dtype(q.dtype())?;

        if let Some(t) = self.q_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        let mut res = self.o_proj.forward(&attn_output)?;
        if self.q_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

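/// Two-layer feed-forward block (`fc1` -> activation -> `fc2`) with
/// tensor-parallel projections.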
struct MLlamaMlp {
    act: VisionActivation,
    fc1: Arc<dyn QuantMethod>,
    fc2: Arc<dyn QuantMethod>,
}

impl MLlamaMlp {
    fn new(
        cfg: &MLlamaVisionConfig,
        vb: ShardedVarBuilder,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        Ok(Self {
            act: cfg.hidden_act,
            fc1: ColumnParallelLayer::new(
                cfg.hidden_size,
                cfg.intermediate_size,
                &None,
                true,
                comm,
                vb.pp("fc1"),
            )?,
            fc2: RowParallelLayer::new(
                cfg.intermediate_size,
                cfg.hidden_size,
                &None,
                true,
                comm,
                vb.pp("fc2"),
            )?,
        })
    }

    // https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/models/mllama/modeling_mllama.py#L223
    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
        let original_dtype = hidden_states.dtype();
        let mut hidden_states = hidden_states.clone();
        if let Some(t) = self.fc1.quantized_act_type() {
            hidden_states = hidden_states.to_dtype(t)?;
        }
        hidden_states = self
            .fc2
            .forward(&self.act.forward(&self.fc1.forward(&hidden_states)?)?)?;
        if self.fc1.quantized_act_type().is_some() {
            hidden_states = hidden_states.to_dtype(original_dtype)?;
        }
        Ok(hidden_states)
    }
}

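/// Pre-norm transformer block. When `GATED`, the attention and MLP branch
/// outputs are scaled by learned tanh gates (as used by the global encoder).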
struct MLlamaVisionEncoderLayer {
    self_attn: MLlamaVisionAttention,
    mlp: MLlamaMlp,
    input_layernorm: LayerNorm,
    post_attention_layernorm: LayerNorm,
    gate_attn: Option<Tensor>,
    gate_ffn: Option<Tensor>,
}

impl MLlamaVisionEncoderLayer {
    fn new<const GATED: bool>(
        cfg: &MLlamaVisionConfig,
        vb: ShardedVarBuilder,
        real_dev: &Device,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let self_attn = MLlamaVisionAttention::new(cfg, vb.pp("self_attn"), comm)?;
        let mlp = MLlamaMlp::new(cfg, vb.pp("mlp"), comm)?;

        let input_layernorm = layer_norm(
            cfg.hidden_size,
            cfg.norm_eps,
            vb.pp("input_layernorm").set_device(real_dev.clone()),
        )?;
        let post_attention_layernorm = layer_norm(
            cfg.hidden_size,
            cfg.norm_eps,
            vb.pp("post_attention_layernorm")
                .set_device(real_dev.clone()),
        )?;

        if GATED {
            Ok(Self {
                self_attn,
                mlp,
                input_layernorm,
                post_attention_layernorm,
                gate_attn: Some(vb.get((1,), "gate_attn")?.to_device(real_dev)?),
                gate_ffn: Some(vb.get((1,), "gate_ffn")?.to_device(real_dev)?),
            })
        } else {
            Ok(Self {
                self_attn,
                mlp,
                input_layernorm,
                post_attention_layernorm,
                gate_attn: None,
                gate_ffn: None,
            })
        }
    }

    // https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/models/mllama/modeling_mllama.py#L348
    fn forward(&self, hidden_state: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
        // Self attn
        let residual = hidden_state;
        let mut hidden_state = self.input_layernorm.forward(hidden_state)?;

        hidden_state = self.self_attn.forward(&hidden_state, attention_mask)?;

        if let Some(gate) = &self.gate_attn {
            // Per the referenced HF implementation, the tanh is applied to the learned gate.
            hidden_state = gate.tanh()?.broadcast_mul(&hidden_state)?;
        }
        hidden_state = (residual + hidden_state)?;

        // FF
        let residual = hidden_state.clone();
        hidden_state = self.post_attention_layernorm.forward(&hidden_state)?;

        hidden_state = self.mlp.forward(&hidden_state)?;

        if let Some(gate) = &self.gate_ffn {
            hidden_state = gate.tanh()?.broadcast_mul(&hidden_state)?;
        }
        residual + hidden_state
    }
}

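/// A stack of encoder layers which can also return the hidden states observed
/// at selected intermediate layer indices.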
struct MLlamaVisionEncoder {
    layers: Vec<MLlamaVisionEncoderLayer>,
}

impl MLlamaVisionEncoder {
    fn new<const GATED: bool>(
        cfg: &MLlamaVisionConfig,
        num_layers: usize,
        vb: ShardedVarBuilder,
        real_dev: &Device,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let mut layers = Vec::with_capacity(num_layers);
        let layers_vb = vb.pp("layers");
        for i in 0..num_layers {
            layers.push(MLlamaVisionEncoderLayer::new::<GATED>(
                cfg,
                layers_vb.pp(i),
                real_dev,
                comm,
            )?);
        }
        Ok(Self { layers })
    }

    // https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/models/mllama/modeling_mllama.py#L394
    /// Also (optionally) return hidden states at some indices
    fn forward_with_states(
        &self,
        hidden_state: &Tensor,
        attention_mask: Option<&Tensor>,
        intermediate_layers_indices: Option<&[usize]>,
    ) -> Result<(Tensor, Vec<Tensor>)> {
        let mut hidden_state = hidden_state.clone();
        let mut hidden_states = Vec::new();
        for (i, layer) in self.layers.iter().enumerate() {
            if intermediate_layers_indices.is_some_and(|indices: &[usize]| indices.contains(&i)) {
                hidden_states.push(hidden_state.clone());
            }
            hidden_state = layer.forward(&hidden_state, attention_mask)?;
        }
        Ok((hidden_state, hidden_states))
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb_t = UnVarBuilder::new();

        for (i, layer) in self.layers.iter().enumerate() {
            let uvb_l = uvb_t.pp("layers").pp(i);
            uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
            uvb_l
                .pp("post_attention_layernorm")
                .add(&layer.post_attention_layernorm);
            if let Some(gate) = layer.gate_attn.clone() {
                uvb_l.add_tensor("gate_attn", gate);
            }
            if let Some(gate) = layer.gate_ffn.clone() {
                uvb_l.add_tensor("gate_ffn", gate);
            }
        }

        uvb_t.to_safetensors()
    }
}

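/// Mirrors the Hugging Face `_prepare_aspect_ratio_attention_mask`: expand the
/// `(bs, max_num_tiles)` tile mask to patch granularity, zero out the padding
/// patches, invert it, and form an additive attention bias via an outer product
/// scaled by the dtype's most negative value.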
fn _prepare_aspect_ratio_attention_mask(
    aspect_ratio_mask: &Tensor,
    num_patches: usize,
    target_length: usize,
    dtype: DType,
    _num_attn_heads: usize,
) -> Result<Tensor> {
    let (bs, max_num_tiles) = aspect_ratio_mask.dims2()?;
    let mut attention_mask = aspect_ratio_mask
        .reshape((bs, max_num_tiles, 1, 1))?
        .repeat((1, 1, target_length, 1))?;

    // Mask padding patches
    let pad_patches = target_length - num_patches;
    let (bs, d1, d2, d3) = attention_mask.dims4()?;
    attention_mask = attention_mask.slice_assign(
        &[&.., &.., &(d2 - pad_patches..), &..],
        &Tensor::zeros(
            (bs, d1, pad_patches, d3),
            attention_mask.dtype(),
            attention_mask.device(),
        )?,
    )?;

    // Invert the mask
    attention_mask = (1. - attention_mask.to_dtype(DType::F32)?.to_dtype(dtype)?)?;

    // Build the pairwise additive mask of shape
    // (batch_size, max_num_tiles * target_length, max_num_tiles * target_length);
    // the caller adds the head dimension as needed.
    let neg_inf_value = dtype.finfo()?.min;
    attention_mask = attention_mask.reshape((bs, max_num_tiles * target_length, 1))?;
    attention_mask.matmul(
        &attention_mask
            .transpose(D::Minus1, D::Minus2)?
            .mul(neg_inf_value)?,
    )
}

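/// The MLlama vision tower: patch embedding, tile/positional embeddings, a
/// local transformer (which also yields selected intermediate hidden states),
/// and a gated global transformer. `forward` returns the final hidden state
/// concatenated with the intermediate hidden states along the last dimension.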
pub(super) struct MLlamaVisionModel {
    patch_embedding: Conv2d,
    class_embedding: Tensor,
    gated_positional_embedding: MLlamaPrecomputedPositionEmbedding,
    pre_tile_positional_embedding: MLlamaPrecomputedAspectRatioEmbedding,
    post_tile_positional_embedding: MLlamaPrecomputedAspectRatioEmbedding,
    layernorm_pre: LayerNorm,
    layernorm_post: LayerNorm,
    transformer: MLlamaVisionEncoder,
    global_transformer: MLlamaVisionEncoder,
    pub(super) num_patches: usize,
    intermediate_layers_indices: Vec<usize>,
    num_attn_heads: usize,
}

impl MLlamaVisionModel {
    pub(super) fn new(
        cfg: &MLlamaVisionConfig,
        vb: ShardedVarBuilder,
        real_dev: &Device,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let patch_embedding = conv2d_no_bias(
            cfg.num_channels,
            cfg.hidden_size,
            cfg.patch_size,
            Conv2dConfig {
                stride: cfg.patch_size,
                ..Default::default()
            },
            vb.pp("patch_embedding").set_device(real_dev.clone()),
        )?;

        let class_embedding = vb
            .get((cfg.hidden_size,), "class_embedding")?
            .to_device(real_dev)?;
        let gated_positional_embedding = MLlamaPrecomputedPositionEmbedding::new(
            cfg,
            vb.pp("gated_positional_embedding")
                .set_device(real_dev.clone()),
        )?;

        let pre_tile_positional_embedding = MLlamaPrecomputedAspectRatioEmbedding::new::<true>(
            cfg,
            vb.pp("pre_tile_positional_embedding")
                .set_device(real_dev.clone()),
        )?;
        let post_tile_positional_embedding = MLlamaPrecomputedAspectRatioEmbedding::new::<true>(
            cfg,
            vb.pp("post_tile_positional_embedding")
                .set_device(real_dev.clone()),
        )?;

        // layer norms
        let layernorm_pre = layer_norm(
            cfg.hidden_size,
            LayerNormConfig::default(),
            vb.pp("layernorm_pre").set_device(real_dev.clone()),
        )?;
        let layernorm_post = layer_norm(
            cfg.hidden_size,
            LayerNormConfig::default(),
            vb.pp("layernorm_post").set_device(real_dev.clone()),
        )?;

        // encoders
        let transformer = MLlamaVisionEncoder::new::<false>(
            cfg,
            cfg.num_hidden_layers,
            vb.pp("transformer"),
            real_dev,
            comm,
        )?;
        let global_transformer = MLlamaVisionEncoder::new::<true>(
            cfg,
            cfg.num_global_layers,
            vb.pp("global_transformer"),
            real_dev,
            comm,
        )?;

        Ok(Self {
            patch_embedding,
            class_embedding,
            gated_positional_embedding,
            pre_tile_positional_embedding,
            post_tile_positional_embedding,
            layernorm_post,
            layernorm_pre,
            transformer,
            global_transformer,
            num_patches: (cfg.image_size / cfg.patch_size).pow(2) + 1,
            intermediate_layers_indices: cfg.intermediate_layers_indices.clone(),
            num_attn_heads: cfg.num_attention_heads,
        })
    }

    // https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/models/mllama/modeling_mllama.py#L1425
    pub(super) fn forward(
        &self,
        pixel_values: &Tensor,
        aspect_ratio_ids: &Tensor,
        aspect_ratio_mask: &Tensor,
    ) -> Result<Tensor> {
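        // Expected pixel_values layout:
        // (bs, num_concurrent_media, num_tiles, num_channels, height, width)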
        let pixel_values = pixel_values.to_dtype(self.class_embedding.dtype())?;

        let bs = pixel_values.dim(0)?;
        let num_concurrent_media = pixel_values.dim(1)?;
        let num_tiles = pixel_values.dim(2)?;
        let num_channels = pixel_values.dim(3)?;
        let height = pixel_values.dim(4)?;
        let width = pixel_values.dim(5)?;

        let pixel_values = pixel_values.reshape((
            bs * num_concurrent_media * num_tiles,
            num_channels,
            height,
            width,
        ))?;
        let aspect_ratio_ids = aspect_ratio_ids.reshape((bs * num_concurrent_media, ()))?;

        // Patch embedding
        let patch_embeds = self.patch_embedding.forward(&pixel_values)?;
        let mut hidden_state = patch_embeds.flatten_from(2)?.transpose(1, 2)?;

        // Tile embeddings
        let (_, mut num_patches, dim) = hidden_state.dims3()?;
        hidden_state = hidden_state.reshape((bs * num_concurrent_media, num_tiles, (), dim))?;
        hidden_state = self
            .pre_tile_positional_embedding
            .forward(&hidden_state, &aspect_ratio_ids)?;

        // Add cls token
        hidden_state =
            hidden_state.reshape((bs * num_concurrent_media * num_tiles, num_patches, dim))?;
        hidden_state = self.apply_class_embedding(&hidden_state)?;
        num_patches += 1;

        // Position embeddings
        hidden_state =
            hidden_state.reshape((bs * num_concurrent_media, num_tiles, num_patches, dim))?;
        hidden_state = self
            .gated_positional_embedding
            .forward(&hidden_state, &aspect_ratio_ids)?;

        hidden_state = self.layernorm_pre.forward(&hidden_state)?;

        // Compute the number of patch tokens needed to pad dim -2 up to a multiple of 8
        let num_padding_patches = (8 - (hidden_state.dim(D::Minus2)? as isize % 8)) % 8;
        // Padding tuple in the reference implementation's `F.pad` layout:
        // (pad_left for dim -1, pad_right for dim -1, pad_left for dim -2, pad_right for dim -2).
        // The actual padding is applied just below with `pad_with_zeros`.
        let _padding = (0usize, 0usize, 0usize, num_padding_patches);
        if num_padding_patches >= 0 {
            hidden_state =
                hidden_state.pad_with_zeros(D::Minus2, 0, num_padding_patches as usize)?;
        } else {
            hidden_state = hidden_state.narrow(
                D::Minus2,
                0,
                (hidden_state.dim(2)? as isize + num_padding_patches) as usize,
            )?;
        }

        // Prepare attention mask
        let mut attention_mask = aspect_ratio_mask.reshape((bs * num_concurrent_media, ()))?;
        attention_mask = _prepare_aspect_ratio_attention_mask(
            &attention_mask,
            self.num_patches,
            hidden_state.dim(2)?,
            hidden_state.dtype(),
            self.num_attn_heads,
        )?;
        if attention_mask.dim(0)? != 1 {
            attention_mask = attention_mask.unsqueeze(1)?;
        }

        // Apply encoder
        hidden_state = hidden_state.reshape((bs * num_concurrent_media, (), dim))?;
        let (mut hidden_state, all_intermediate_hidden_states) =
            self.transformer.forward_with_states(
                &hidden_state,
                Some(&attention_mask),
                Some(&self.intermediate_layers_indices),
            )?;

        // Collect intermediate layer outputs from encoder output
        let mut intermediate_hidden_states =
            Tensor::stack(&all_intermediate_hidden_states, D::Minus1)?;
        drop(all_intermediate_hidden_states);

        hidden_state = self.layernorm_post.forward(&hidden_state)?;

        // Apply global encoder
        hidden_state = hidden_state.reshape((
            bs * num_concurrent_media,
            num_tiles,
            (num_patches as isize + num_padding_patches) as usize,
            dim,
        ))?;
        hidden_state = self
            .post_tile_positional_embedding
            .forward(&hidden_state, &aspect_ratio_ids)?;
        hidden_state = hidden_state.reshape((
            bs * num_concurrent_media,
            num_tiles * (num_patches as isize + num_padding_patches) as usize,
            dim,
        ))?;
        (hidden_state, _) = self.global_transformer.forward_with_states(
            &hidden_state,
            Some(&attention_mask),
            None,
        )?;

        // Remove padding from hidden state
        hidden_state = hidden_state.reshape((
            bs * num_concurrent_media,
            num_tiles,
            (num_patches as isize + num_padding_patches) as usize,
            dim,
        ))?;
        hidden_state = hidden_state.narrow(
            2,
            0,
            (hidden_state.dims()[2] as isize - num_padding_patches) as usize,
        )?;
        hidden_state =
            hidden_state.reshape((bs, num_concurrent_media, num_tiles, num_patches, dim))?;

        // Remove padding from intermediate hidden states
        intermediate_hidden_states = intermediate_hidden_states.reshape((
            bs * num_concurrent_media,
            num_tiles,
            (num_patches as isize + num_padding_patches) as usize,
            (),
        ))?;
        intermediate_hidden_states = intermediate_hidden_states.narrow(
            2,
            0,
            (intermediate_hidden_states.dims()[2] as isize - num_padding_patches) as usize,
        )?;
        intermediate_hidden_states = intermediate_hidden_states.reshape((
            bs,
            num_concurrent_media,
            num_tiles,
            num_patches,
            (),
        ))?;

        // Concatenate final hidden state and intermediate hidden states
        Tensor::cat(&[hidden_state, intermediate_hidden_states], D::Minus1)
    }

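    /// Prepend the learned class embedding to each tile's patch sequence.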
    fn apply_class_embedding(&self, hidden_state: &Tensor) -> Result<Tensor> {
        let (bs, _, hidden_size) = hidden_state.dims3()?;
        let class_embedding = self.class_embedding.expand((bs, 1, hidden_size))?;
        Tensor::cat(&[class_embedding, hidden_state.clone()], 1)
    }

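    /// Mutable references to every attention and MLP projection in both
    /// encoders, e.g. for ISQ handling at the parent model level.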
    pub fn get_isq_layers(&mut self) -> Vec<&mut std::sync::Arc<dyn mistralrs_quant::QuantMethod>> {
        let mut layers = Vec::new();
        for layer in &mut self.global_transformer.layers {
            layers.push(&mut layer.self_attn.q_proj);
            layers.push(&mut layer.self_attn.k_proj);
            layers.push(&mut layer.self_attn.v_proj);
            layers.push(&mut layer.self_attn.o_proj);

            layers.push(&mut layer.mlp.fc1);
            layers.push(&mut layer.mlp.fc2);
        }
        for layer in &mut self.transformer.layers {
            layers.push(&mut layer.self_attn.q_proj);
            layers.push(&mut layer.self_attn.k_proj);
            layers.push(&mut layer.self_attn.v_proj);
            layers.push(&mut layer.self_attn.o_proj);

            layers.push(&mut layer.mlp.fc1);
            layers.push(&mut layer.mlp.fc2);
        }
        layers
    }
}

impl IsqModel for MLlamaVisionModel {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(
            &mut std::sync::Arc<dyn mistralrs_quant::QuantMethod>,
            Option<usize>,
        )>,
        &dyn crate::device_map::DeviceMapper,
    ) {
        unreachable!("MLlamaVision model cannot be quantized.");
    }
    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("patch_embedding").add(&self.patch_embedding);
        uvb.add_tensor("class_embedding", self.class_embedding.clone());

        // gated_positional_embedding
        uvb.pp("gated_positional_embedding")
            .extend(self.gated_positional_embedding.residual_tensors());

        // pre_tile_positional_embedding
        uvb.pp("pre_tile_positional_embedding")
            .extend(self.pre_tile_positional_embedding.residual_tensors());

        // post_tile_positional_embedding
        uvb.pp("post_tile_positional_embedding")
            .extend(self.post_tile_positional_embedding.residual_tensors());

        uvb.pp("layernorm_pre").add(&self.layernorm_pre);
        uvb.pp("layernorm_post").add(&self.layernorm_post);

        // transformer
        uvb.pp("transformer")
            .extend(self.transformer.residual_tensors());

        // global_transformer
        uvb.pp("global_transformer")
            .extend(self.global_transformer.residual_tensors());

        uvb.to_safetensors()
    }
}