mistralrs_core/vision_models/llama4/
vision.rs

1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use std::sync::Arc;
4
5use candle_core::{DType, Device, IndexOp, Result, Tensor, D};
6use candle_nn::{LayerNorm, LayerNormConfig, Linear, Module};
7use indicatif::MultiProgress;
8use mistralrs_quant::{ColumnParallelLayer, QuantMethod, RowParallelLayer, ShardedVarBuilder};
9
10use crate::{
11    attention::SdpaParams,
12    layers::{layer_norm, linear_no_bias, Activation, Sdpa},
13    ops::RepeatInterleaveOp,
14    pipeline::IsqModel,
15    utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
16};
17
18use super::config::VisionConfig;
19
20struct Llama4UnfoldConvolution {
21    linear: Linear,
22    kernel_size: usize,
23    patch_size: usize,
24}
25
26impl Llama4UnfoldConvolution {
27    fn new(cfg: &VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
28        let kernel_size = cfg.patch_size;
29        let linear = linear_no_bias(
30            cfg.num_channels * kernel_size * kernel_size,
31            cfg.hidden_size,
32            vb.pp("linear"),
33        )?;
34        Ok(Self {
35            linear,
36            kernel_size,
37            patch_size: cfg.patch_size,
38        })
39    }
40
41    fn unfold(&self, xs: &Tensor) -> Result<Tensor> {
42        // In original code
43        let kernel_size = (self.kernel_size, self.kernel_size);
44        let stride = (self.patch_size, self.patch_size);
45        let padding = (0, 0);
46        let dilation = (1, 1);
47        let (bs, c, h, w) = xs.dims4()?;
48
49        let h_out = (h + 2 * padding.0 - dilation.0 * (kernel_size.0 - 1) - 1) / stride.0 + 1;
50        let w_out = (w + 2 * padding.1 - dilation.1 * (kernel_size.1 - 1) - 1) / stride.1 + 1;
51
52        // Extract blocks
53        let mut blocks = Vec::new();
54        for i in (0..h - kernel_size.0 * dilation.0 + 1).step_by(stride.0) {
55            for j in (0..w - kernel_size.1 * dilation.1 + 1).step_by(stride.1) {
56                let mut block = Vec::new();
57                for di in 0..kernel_size.0 {
58                    for dj in 0..kernel_size.1 {
59                        let h_idx = i + di * dilation.0;
60                        let w_idx = j + dj * dilation.0;
61                        // Get the block for all channels and add to our list
62                        block.push(xs.i((.., .., h_idx, w_idx))?);
63                    }
64                }
65
66                // Stack the channel-blocks
67                // (b, k*k, c)
68                let mut block = Tensor::stack(&block, 1)?;
69                block = block.permute((0, 2, 1))?;
70                blocks.push(block);
71            }
72        }
73
74        // (b, c, k*k, l)
75        let mut result = Tensor::stack(&blocks, D::Minus1)?;
76        // (b, c*k*k, l)
77        result = result.reshape((bs, c * kernel_size.0 * kernel_size.1, h_out * w_out))?;
78        Ok(result)
79    }
80
81    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
82        // let hidden_states = {
83        //     let mut patches = hidden_states
84        //         .unfold(2, self.kernel_size, self.patch_size)?
85        //         .unfold(3, self.kernel_size, self.patch_size)?;
86        //     patches = patches.contiguous()?.permute((0, 2, 3, 1, 4, 5))?;
87        //     let b = patches.dim(0)?;
88        //     let out_h = patches.dim(1)?;
89        //     let out_w = patches.dim(2)?;
90        //     let c = patches.dim(3)?;
91        //     patches.reshape((b, out_h * out_w, c * self.kernel_size * self.kernel_size))?
92        // };
93
94        let mut hidden_states = self.unfold(hidden_states)?;
95        hidden_states = hidden_states.transpose(1, 2)?;
96        self.linear.forward(&hidden_states)
97    }
98}
99
100struct Llama4VisionAttention {
101    q_proj: Arc<dyn QuantMethod>,
102    k_proj: Arc<dyn QuantMethod>,
103    v_proj: Arc<dyn QuantMethod>,
104    o_proj: Arc<dyn QuantMethod>,
105    sdpa_params: SdpaParams,
106    head_dim: usize,
107    freqs: Llama4VisionRotaryEmbedding,
108}
109
110impl Llama4VisionAttention {
111    fn new(
112        cfg: &VisionConfig,
113        vb: ShardedVarBuilder,
114        freqs: Llama4VisionRotaryEmbedding,
115        comm: &Arc<mistralrs_quant::Comm>,
116    ) -> Result<Self> {
117        let head_dim = cfg.hidden_size / cfg.num_attention_heads;
118        Ok(Self {
119            q_proj: ColumnParallelLayer::new(
120                cfg.hidden_size,
121                cfg.num_attention_heads * head_dim,
122                &None,
123                true,
124                comm,
125                vb.pp("q_proj"),
126            )?,
127            k_proj: ColumnParallelLayer::new(
128                cfg.hidden_size,
129                cfg.num_attention_heads * head_dim,
130                &None,
131                true,
132                comm,
133                vb.pp("k_proj"),
134            )?,
135            v_proj: ColumnParallelLayer::new(
136                cfg.hidden_size,
137                cfg.num_attention_heads * head_dim,
138                &None,
139                true,
140                comm,
141                vb.pp("v_proj"),
142            )?,
143            o_proj: RowParallelLayer::new(
144                cfg.hidden_size,
145                cfg.num_attention_heads * head_dim,
146                &None,
147                true,
148                comm,
149                vb.pp("o_proj"),
150            )?,
151            sdpa_params: SdpaParams {
152                n_kv_groups: 1,
153                use_flash_attn: false,
154                softcap: None,
155                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
156                sliding_window: None,
157            },
158            head_dim,
159            freqs,
160        })
161    }
162
163    fn forward(&self, hidden_state: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
164        let mut hidden_state = hidden_state.clone();
165        let original_dtype = hidden_state.dtype();
166        if let Some(t) = self.q_proj.quantized_act_type() {
167            hidden_state = hidden_state.to_dtype(t)?;
168        }
169        let mut q = self.q_proj.forward(&hidden_state)?;
170        let mut k = self.k_proj.forward(&hidden_state)?;
171        let mut v = self.v_proj.forward(&hidden_state)?;
172        if self.q_proj.quantized_act_type().is_some() {
173            q = q.to_dtype(original_dtype)?;
174            k = k.to_dtype(original_dtype)?;
175            v = v.to_dtype(original_dtype)?;
176        }
177
178        // Should be same, no caching...
179        let (bs, q_sq, _) = q.dims3()?;
180        let (_, k_sq, _) = k.dims3()?;
181
182        q = q
183            .reshape((bs, q_sq, (), self.head_dim))?
184            .transpose(1, 2)?
185            .contiguous()?;
186        k = k
187            .reshape((bs, k_sq, (), self.head_dim))?
188            .transpose(1, 2)?
189            .contiguous()?;
190        v = v
191            .reshape((bs, k_sq, (), self.head_dim))?
192            .transpose(1, 2)?
193            .contiguous()?;
194
195        // Apply rope
196        {
197            q = candle_nn::rotary_emb::rope_i(&q, &self.freqs.cos, &self.freqs.sin)?;
198            k = candle_nn::rotary_emb::rope_i(&k, &self.freqs.cos, &self.freqs.sin)?;
199        }
200
201        let mut attn_output = Sdpa
202            .run_attention(&q, &k, &v, attention_mask, None, &self.sdpa_params)?
203            .transpose(1, 2)?
204            .contiguous()?
205            .reshape((bs, q_sq, ()))?
206            .to_dtype(q.dtype())?;
207
208        if let Some(t) = self.q_proj.quantized_act_type() {
209            attn_output = attn_output.to_dtype(t)?;
210        }
211        let mut res = self.o_proj.forward(&attn_output)?;
212        if self.q_proj.quantized_act_type().is_some() {
213            res = res.to_dtype(original_dtype)?;
214        }
215        Ok(res)
216    }
217}
218
219struct Llama4Mlp {
220    act: Activation,
221    fc1: Arc<dyn QuantMethod>,
222    fc2: Arc<dyn QuantMethod>,
223}
224
225impl Llama4Mlp {
226    fn new(
227        cfg: &VisionConfig,
228        vb: ShardedVarBuilder,
229        comm: &Arc<mistralrs_quant::Comm>,
230    ) -> Result<Self> {
231        Ok(Self {
232            act: cfg.hidden_act,
233            fc1: ColumnParallelLayer::new(
234                cfg.hidden_size,
235                cfg.intermediate_size,
236                &None,
237                true,
238                comm,
239                vb.pp("fc1"),
240            )?,
241            fc2: RowParallelLayer::new(
242                cfg.intermediate_size,
243                cfg.hidden_size,
244                &None,
245                true,
246                comm,
247                vb.pp("fc2"),
248            )?,
249        })
250    }
251
252    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
253        let original_dtype = hidden_states.dtype();
254        let mut hidden_states = hidden_states.clone();
255        if let Some(t) = self.fc1.quantized_act_type() {
256            hidden_states = hidden_states.to_dtype(t)?;
257        }
258        hidden_states = self.fc1.forward(&hidden_states)?;
259        hidden_states = self.act.forward(&hidden_states)?;
260        hidden_states = self.fc2.forward(&hidden_states)?;
261        if self.fc1.quantized_act_type().is_some() {
262            hidden_states = hidden_states.to_dtype(original_dtype)?;
263        }
264        Ok(hidden_states)
265    }
266}
267
268struct Llama4VisionEncoderLayer {
269    self_attn: Llama4VisionAttention,
270    mlp: Llama4Mlp,
271    input_layernorm: LayerNorm,
272    post_attention_layernorm: LayerNorm,
273}
274
275impl Llama4VisionEncoderLayer {
276    fn new(
277        cfg: &VisionConfig,
278        vb: ShardedVarBuilder,
279        freqs: Llama4VisionRotaryEmbedding,
280        real_dev: &Device,
281        comm: &Arc<mistralrs_quant::Comm>,
282    ) -> Result<Self> {
283        let self_attn = Llama4VisionAttention::new(cfg, vb.pp("self_attn"), freqs, comm)?;
284        let mlp = Llama4Mlp::new(cfg, vb.pp("mlp"), comm)?;
285
286        let input_layernorm = layer_norm(
287            cfg.hidden_size,
288            cfg.norm_eps,
289            vb.pp("input_layernorm").set_device(real_dev.clone()),
290        )?;
291        let post_attention_layernorm = layer_norm(
292            cfg.hidden_size,
293            cfg.norm_eps,
294            vb.pp("post_attention_layernorm")
295                .set_device(real_dev.clone()),
296        )?;
297
298        Ok(Self {
299            self_attn,
300            mlp,
301            input_layernorm,
302            post_attention_layernorm,
303        })
304    }
305
306    fn forward(&self, hidden_state: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
307        // Self attn
308        let residual = hidden_state;
309        let mut hidden_state = self.input_layernorm.forward(hidden_state)?;
310
311        hidden_state = self.self_attn.forward(&hidden_state, attention_mask)?;
312        hidden_state = (residual + hidden_state)?;
313
314        // FF
315        let residual = hidden_state.clone();
316        hidden_state = self.post_attention_layernorm.forward(&hidden_state)?;
317
318        hidden_state = self.mlp.forward(&hidden_state)?;
319        residual + hidden_state
320    }
321}
322
323struct Llama4VisionEncoder {
324    layers: Vec<Llama4VisionEncoderLayer>,
325}
326
327impl Llama4VisionEncoder {
328    fn new(
329        cfg: &VisionConfig,
330        num_layers: usize,
331        vb: ShardedVarBuilder,
332        freqs: Llama4VisionRotaryEmbedding,
333        real_dev: &Device,
334        comm: &Arc<mistralrs_quant::Comm>,
335        multi_progress: &Arc<MultiProgress>,
336    ) -> Result<Self> {
337        let mut layers = Vec::with_capacity(num_layers);
338        let layers_vb = vb.pp("layers");
339        for i in NiceProgressBar::<_, 'b'>(
340            0..num_layers,
341            "Loading vision repeating layers",
342            multi_progress,
343        ) {
344            layers.push(Llama4VisionEncoderLayer::new(
345                cfg,
346                layers_vb.pp(i),
347                freqs.clone(),
348                real_dev,
349                comm,
350            )?);
351        }
352        Ok(Self { layers })
353    }
354
355    fn forward_with_states(
356        &self,
357        hidden_state: &Tensor,
358        attention_mask: Option<&Tensor>,
359    ) -> Result<Tensor> {
360        let mut hidden_state = hidden_state.clone();
361        for layer in self.layers.iter() {
362            hidden_state = layer.forward(&hidden_state, attention_mask)?;
363        }
364        Ok(hidden_state)
365    }
366
367    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
368        let uvb_t = UnVarBuilder::new();
369
370        for (i, layer) in self.layers.iter().enumerate() {
371            let uvb_l = uvb_t.pp("layers").pp(i);
372            uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
373            uvb_l
374                .pp("post_attention_layernorm")
375                .add(&layer.post_attention_layernorm);
376        }
377
378        uvb_t.to_safetensors()
379    }
380}
381
382struct Llama4VisionPixelShuffleMLP {
383    act: Activation,
384    fc1: Arc<dyn QuantMethod>,
385    fc2: Arc<dyn QuantMethod>,
386}
387
388impl Llama4VisionPixelShuffleMLP {
389    fn new(
390        cfg: &VisionConfig,
391        vb: ShardedVarBuilder,
392        comm: &Arc<mistralrs_quant::Comm>,
393    ) -> Result<Self> {
394        Ok(Self {
395            act: Activation::Gelu,
396            fc1: ColumnParallelLayer::new(
397                cfg.intermediate_size,
398                cfg.projector_input_dim,
399                &None,
400                false,
401                comm,
402                vb.pp("fc1"),
403            )?,
404            fc2: RowParallelLayer::new(
405                cfg.projector_input_dim,
406                cfg.projector_output_dim,
407                &None,
408                false,
409                comm,
410                vb.pp("fc2"),
411            )?,
412        })
413    }
414
415    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
416        let original_dtype = hidden_states.dtype();
417        let mut hidden_states = hidden_states.clone();
418        if let Some(t) = self.fc1.quantized_act_type() {
419            hidden_states = hidden_states.to_dtype(t)?;
420        }
421        hidden_states = self.act.forward(
422            &self
423                .fc2
424                .forward(&self.act.forward(&self.fc1.forward(&hidden_states)?)?)?,
425        )?;
426        if self.fc1.quantized_act_type().is_some() {
427            hidden_states = hidden_states.to_dtype(original_dtype)?;
428        }
429        Ok(hidden_states)
430    }
431}
432
433struct Llama4VisionPixelShuffle {
434    mlp: Llama4VisionPixelShuffleMLP,
435    pixel_shuffle_ratio: f32,
436}
437
438impl Llama4VisionPixelShuffle {
439    fn new(
440        cfg: &VisionConfig,
441        vb: ShardedVarBuilder,
442        comm: &Arc<mistralrs_quant::Comm>,
443    ) -> Result<Self> {
444        let mlp = Llama4VisionPixelShuffleMLP::new(cfg, vb.pp("mlp"), comm)?;
445        Ok(Self {
446            mlp,
447            pixel_shuffle_ratio: cfg.pixel_shuffle_ratio,
448        })
449    }
450
451    fn pixel_shuffle(&self, xs: &Tensor) -> Result<Tensor> {
452        let (bs, num_patches, _c) = xs.dims3()?;
453        let patch_size = (num_patches as f32).sqrt() as usize;
454
455        let mut xs = xs.reshape((bs, patch_size, patch_size, ()))?;
456        let (_bs, h, w, c) = xs.dims4()?;
457
458        xs = xs.reshape((
459            bs,
460            h,
461            (w as f32 * self.pixel_shuffle_ratio) as usize,
462            (c as f32 / self.pixel_shuffle_ratio) as usize,
463        ))?;
464        xs = xs.permute((0, 2, 1, 3))?.contiguous()?;
465
466        xs = xs.reshape((
467            bs,
468            (h as f32 * self.pixel_shuffle_ratio) as usize,
469            (w as f32 * self.pixel_shuffle_ratio) as usize,
470            (c as f32 / self.pixel_shuffle_ratio.powi(2)) as usize,
471        ))?;
472        xs = xs.permute((0, 2, 1, 3))?.contiguous()?;
473
474        xs.reshape((bs, (), xs.dim(D::Minus1)?))
475    }
476
477    fn forward(&self, encoded_patches: &Tensor) -> Result<Tensor> {
478        let encoded_patches = self.pixel_shuffle(encoded_patches)?;
479        self.mlp.forward(&encoded_patches)
480    }
481}
482
483#[derive(Clone)]
484struct Llama4VisionRotaryEmbedding {
485    cos: Tensor,
486    sin: Tensor,
487}
488
489impl Llama4VisionRotaryEmbedding {
490    fn new(cfg: &VisionConfig, device: &Device, dtype: DType) -> Result<Self> {
491        let idx = cfg.image_size / cfg.patch_size;
492        let mut img_idx =
493            Tensor::arange(0f32, idx.pow(2) as f32, device)?.reshape((idx.pow(2), 1))?;
494        img_idx = Tensor::cat(&[&img_idx, &img_idx.narrow(0, 0, 1)?], 0)?;
495        // Insert ID_CLS_TOKEN in the bottom right
496        img_idx = img_idx.slice_assign(
497            &[
498                &(img_idx.dim(0)? - 1..img_idx.dim(0)?),
499                &(img_idx.dim(1)? - 1..img_idx.dim(1)?),
500            ],
501            &Tensor::new(-2f32, device)?.reshape((1, 1))?,
502        )?;
503        let img_ids_flat = img_idx.flatten_all()?.to_vec1::<f32>()?;
504        // frequencies_x = img_idx % idx
505        // get the coordinates of the 2d matrix along x
506        let frequencies_x = {
507            let frequencies_x = img_ids_flat
508                .iter()
509                .map(|x| x % idx as f32)
510                .collect::<Vec<_>>();
511            Tensor::from_vec(frequencies_x, img_idx.shape().clone(), device)?
512        };
513        // frequencies_y = img_idx // idx
514        // get the coordinates of the 2d matrix along y
515        let frequencies_y = {
516            let frequencies_y = img_ids_flat
517                .iter()
518                .map(|x| x / idx as f32)
519                .collect::<Vec<_>>();
520            Tensor::from_vec(frequencies_y, img_idx.shape().clone(), device)?
521        };
522        let rope_freq = {
523            let freq_dim = cfg.hidden_size / cfg.num_attention_heads / 2;
524            let freqs: Vec<_> = (0..freq_dim)
525                .step_by(2)
526                .take(freq_dim / 2)
527                .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / freq_dim as f32))
528                .collect();
529            let freqs_len = freqs.len();
530            Tensor::from_vec(freqs, freqs_len, device)?
531        };
532        let freqs_x = (frequencies_x + 1.)?
533            .unsqueeze(D::Minus1)?
534            .broadcast_mul(&rope_freq.unsqueeze(0)?.unsqueeze(0)?)?
535            .repeat_interleave(2, D::Minus1)?;
536        let freqs_y = (frequencies_y + 1.)?
537            .unsqueeze(D::Minus1)?
538            .broadcast_mul(&rope_freq.unsqueeze(0)?.unsqueeze(0)?)?
539            .repeat_interleave(2, D::Minus1)?;
540        let mut freqs = {
541            let freqs = Tensor::cat(&[freqs_x, freqs_y], D::Minus1)?.contiguous()?;
542            // This implements [..., ::2]
543            let indices_every_two = Tensor::new(
544                (0..freqs.dim(D::Minus1)?)
545                    .step_by(2)
546                    .map(|x| x as u32)
547                    .collect::<Vec<_>>(),
548                device,
549            )?;
550            freqs.index_select(&indices_every_two, D::Minus1)?
551        };
552        freqs = freqs.squeeze(1)?;
553        freqs = freqs.lt(0.)?.where_cond(&freqs.zeros_like()?, &freqs)?;
554
555        Ok(Self {
556            cos: freqs.cos()?.to_dtype(dtype)?,
557            sin: freqs.sin()?.to_dtype(dtype)?,
558        })
559    }
560}
561
562pub(super) struct Llama4VisionModel {
563    patch_embedding: Llama4UnfoldConvolution,
564    class_embedding: Tensor,
565    positional_embedding_vlm: Tensor,
566    layernorm_pre: LayerNorm,
567    layernorm_post: LayerNorm,
568    model: Llama4VisionEncoder,
569    vision_adapter: Llama4VisionPixelShuffle,
570}
571
572impl Llama4VisionModel {
573    pub(super) fn new(
574        cfg: &VisionConfig,
575        vb: ShardedVarBuilder,
576        real_dev: &Device,
577        comm: &Arc<mistralrs_quant::Comm>,
578        multi_progress: &Arc<MultiProgress>,
579    ) -> Result<Self> {
580        let patch_embedding = Llama4UnfoldConvolution::new(
581            cfg,
582            vb.pp("patch_embedding").set_device(real_dev.clone()),
583        )?;
584
585        let class_embedding = vb
586            .get((cfg.hidden_size,), "class_embedding")?
587            .to_device(real_dev)?;
588        let num_patches = cfg.num_patches();
589        let positional_embedding_vlm = vb
590            .get((num_patches, cfg.hidden_size), "positional_embedding_vlm")?
591            .to_device(real_dev)?;
592
593        // layer norms
594        let layernorm_pre = layer_norm(
595            cfg.hidden_size,
596            LayerNormConfig::default(),
597            vb.pp("layernorm_pre").set_device(real_dev.clone()),
598        )?;
599        let layernorm_post = layer_norm(
600            cfg.hidden_size,
601            LayerNormConfig::default(),
602            vb.pp("layernorm_post").set_device(real_dev.clone()),
603        )?;
604
605        let rotary_embedding = Llama4VisionRotaryEmbedding::new(cfg, real_dev, vb.dtype())?;
606        let model = Llama4VisionEncoder::new(
607            cfg,
608            cfg.num_hidden_layers,
609            vb.pp("model"),
610            rotary_embedding,
611            real_dev,
612            comm,
613            multi_progress,
614        )?;
615
616        let vision_adapter = Llama4VisionPixelShuffle::new(cfg, vb.pp("vision_adapter"), comm)?;
617
618        assert_eq!(cfg.vision_feature_layer, -1);
619
620        Ok(Self {
621            patch_embedding,
622            class_embedding,
623            positional_embedding_vlm,
624            layernorm_post,
625            layernorm_pre,
626            model,
627            vision_adapter,
628        })
629    }
630
631    pub(super) fn forward(&self, pixel_values: &Tensor) -> Result<Tensor> {
632        let pixel_values = pixel_values.to_dtype(self.class_embedding.dtype())?;
633
634        let (bs_times_num_tiles, _num_channels, _height, _width) = pixel_values.dims4()?;
635        let num_concurrent_media = 1;
636
637        // Patch embedding
638        let mut hidden_state = self.patch_embedding.forward(&pixel_values)?;
639        let (_, mut num_patches, hidden_dim) = hidden_state.dims3()?;
640
641        // Add cls token
642        hidden_state = hidden_state.reshape((
643            bs_times_num_tiles * num_concurrent_media,
644            num_patches,
645            hidden_dim,
646        ))?;
647        let class_embedding =
648            self.class_embedding
649                .expand((hidden_state.dim(0)?, 1, hidden_state.dim(D::Minus1)?))?;
650        hidden_state = Tensor::cat(&[hidden_state, class_embedding], 1)?;
651        num_patches += 1;
652
653        // Position embeddings
654        hidden_state = hidden_state.reshape((
655            bs_times_num_tiles * num_concurrent_media,
656            num_patches,
657            hidden_dim,
658        ))?;
659        hidden_state = hidden_state.broadcast_add(&self.positional_embedding_vlm)?;
660
661        hidden_state = self.layernorm_pre.forward(&hidden_state)?;
662
663        hidden_state = hidden_state.reshape((bs_times_num_tiles, (), hidden_dim))?;
664
665        // Apply encoder
666        hidden_state =
667            hidden_state.reshape((bs_times_num_tiles * num_concurrent_media, (), hidden_dim))?;
668        hidden_state = self.model.forward_with_states(&hidden_state, None)?;
669
670        hidden_state = self.layernorm_post.forward(&hidden_state)?;
671
672        hidden_state = hidden_state.narrow(1, 0, hidden_state.dim(1)? - 1)?;
673
674        self.vision_adapter.forward(&hidden_state)
675    }
676
677    pub fn get_isq_layers(&mut self) -> Vec<&mut std::sync::Arc<dyn mistralrs_quant::QuantMethod>> {
678        let mut layers = Vec::new();
679        for layer in &mut self.model.layers {
680            layers.push(&mut layer.self_attn.q_proj);
681            layers.push(&mut layer.self_attn.k_proj);
682            layers.push(&mut layer.self_attn.v_proj);
683            layers.push(&mut layer.self_attn.o_proj);
684
685            layers.push(&mut layer.mlp.fc1);
686            layers.push(&mut layer.mlp.fc2);
687        }
688        layers.push(&mut self.vision_adapter.mlp.fc1);
689        layers.push(&mut self.vision_adapter.mlp.fc2);
690        layers
691    }
692}
693
694impl IsqModel for Llama4VisionModel {
695    fn get_layers(
696        &mut self,
697    ) -> (
698        Vec<(
699            &mut std::sync::Arc<dyn mistralrs_quant::QuantMethod>,
700            Option<usize>,
701        )>,
702        &dyn crate::device_map::DeviceMapper,
703    ) {
704        unreachable!("Llama4Vision model cannot be quantized.");
705    }
706    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
707        let uvb = UnVarBuilder::new();
708
709        uvb.pp("patch_embedding")
710            .pp("linear")
711            .add(&self.patch_embedding.linear);
712        uvb.add_tensor("class_embedding", self.class_embedding.clone());
713        uvb.add_tensor(
714            "positional_embedding_vlm",
715            self.positional_embedding_vlm.clone(),
716        );
717
718        uvb.pp("layernorm_pre").add(&self.layernorm_pre);
719        uvb.pp("layernorm_post").add(&self.layernorm_post);
720
721        uvb.pp("model").extend(self.model.residual_tensors());
722
723        uvb.to_safetensors()
724    }
725}