#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::sync::Arc;

use candle_core::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{LayerNorm, LayerNormConfig, Linear, Module};
use indicatif::MultiProgress;
use mistralrs_quant::{ColumnParallelLayer, QuantMethod, RowParallelLayer, ShardedVarBuilder};

use crate::{
    attention::SdpaParams,
    layers::{layer_norm, linear_no_bias, Activation, Sdpa},
    ops::RepeatInterleaveOp,
    pipeline::IsqModel,
    utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
};

use super::config::VisionConfig;

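/// Patch embedding implemented as an explicit unfold (im2col) over the input
/// image followed by a bias-free linear projection, rather than a strided
/// convolution.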
struct Llama4UnfoldConvolution {
    linear: Linear,
    kernel_size: usize,
    patch_size: usize,
}

impl Llama4UnfoldConvolution {
    fn new(cfg: &VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        let kernel_size = cfg.patch_size;
        let linear = linear_no_bias(
            cfg.num_channels * kernel_size * kernel_size,
            cfg.hidden_size,
            vb.pp("linear"),
        )?;
        Ok(Self {
            linear,
            kernel_size,
            patch_size: cfg.patch_size,
        })
    }

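    /// Re-implements `torch.nn.functional.unfold` for the configuration used
    /// here (stride == patch size, no padding, dilation 1): extracts
    /// `kernel_size x kernel_size` patches and returns a tensor of shape
    /// `(bs, c * k * k, h_out * w_out)`.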
    fn unfold(&self, xs: &Tensor) -> Result<Tensor> {
        let kernel_size = (self.kernel_size, self.kernel_size);
        let stride = (self.patch_size, self.patch_size);
        let padding = (0, 0);
        let dilation = (1, 1);
        let (bs, c, h, w) = xs.dims4()?;

        let h_out = (h + 2 * padding.0 - dilation.0 * (kernel_size.0 - 1) - 1) / stride.0 + 1;
        let w_out = (w + 2 * padding.1 - dilation.1 * (kernel_size.1 - 1) - 1) / stride.1 + 1;

        let mut blocks = Vec::new();
        for i in (0..h - kernel_size.0 * dilation.0 + 1).step_by(stride.0) {
            for j in (0..w - kernel_size.1 * dilation.1 + 1).step_by(stride.1) {
                let mut block = Vec::new();
                for di in 0..kernel_size.0 {
                    for dj in 0..kernel_size.1 {
                        let h_idx = i + di * dilation.0;
                        let w_idx = j + dj * dilation.1;
                        block.push(xs.i((.., .., h_idx, w_idx))?);
                    }
                }

                let mut block = Tensor::stack(&block, 1)?;
                block = block.permute((0, 2, 1))?;
                blocks.push(block);
            }
        }

        let mut result = Tensor::stack(&blocks, D::Minus1)?;
        result = result.reshape((bs, c * kernel_size.0 * kernel_size.1, h_out * w_out))?;
        Ok(result)
    }

    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
        let mut hidden_states = self.unfold(hidden_states)?;
        hidden_states = hidden_states.transpose(1, 2)?;
        self.linear.forward(&hidden_states)
    }
}

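/// Multi-head self-attention over vision tokens. Rotary position embeddings
/// (interleaved layout) are applied to queries and keys, and all projections
/// are tensor-parallel aware.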
struct Llama4VisionAttention {
    q_proj: Arc<dyn QuantMethod>,
    k_proj: Arc<dyn QuantMethod>,
    v_proj: Arc<dyn QuantMethod>,
    o_proj: Arc<dyn QuantMethod>,
    sdpa_params: SdpaParams,
    head_dim: usize,
    freqs: Llama4VisionRotaryEmbedding,
}

impl Llama4VisionAttention {
    fn new(
        cfg: &VisionConfig,
        vb: ShardedVarBuilder,
        freqs: Llama4VisionRotaryEmbedding,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let head_dim = cfg.hidden_size / cfg.num_attention_heads;
        Ok(Self {
            q_proj: ColumnParallelLayer::new(
                cfg.hidden_size,
                cfg.num_attention_heads * head_dim,
                &None,
                true,
                comm,
                vb.pp("q_proj"),
            )?,
            k_proj: ColumnParallelLayer::new(
                cfg.hidden_size,
                cfg.num_attention_heads * head_dim,
                &None,
                true,
                comm,
                vb.pp("k_proj"),
            )?,
            v_proj: ColumnParallelLayer::new(
                cfg.hidden_size,
                cfg.num_attention_heads * head_dim,
                &None,
                true,
                comm,
                vb.pp("v_proj"),
            )?,
            o_proj: RowParallelLayer::new(
                cfg.hidden_size,
                cfg.num_attention_heads * head_dim,
                &None,
                true,
                comm,
                vb.pp("o_proj"),
            )?,
            sdpa_params: SdpaParams {
                n_kv_groups: 1,
                use_flash_attn: false,
                softcap: None,
                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                sliding_window: None,
            },
            head_dim,
            freqs,
        })
    }

    fn forward(&self, hidden_state: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
        let mut hidden_state = hidden_state.clone();
        let original_dtype = hidden_state.dtype();
        if let Some(t) = self.q_proj.quantized_act_type() {
            hidden_state = hidden_state.to_dtype(t)?;
        }
        let mut q = self.q_proj.forward(&hidden_state)?;
        let mut k = self.k_proj.forward(&hidden_state)?;
        let mut v = self.v_proj.forward(&hidden_state)?;
        if self.q_proj.quantized_act_type().is_some() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

        let (bs, q_sq, _) = q.dims3()?;
        let (_, k_sq, _) = k.dims3()?;

        q = q
            .reshape((bs, q_sq, (), self.head_dim))?
            .transpose(1, 2)?
            .contiguous()?;
        k = k
            .reshape((bs, k_sq, (), self.head_dim))?
            .transpose(1, 2)?
            .contiguous()?;
        v = v
            .reshape((bs, k_sq, (), self.head_dim))?
            .transpose(1, 2)?
            .contiguous()?;

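        // Apply the interleaved rotary position embeddings to the query and key states.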
        {
            q = candle_nn::rotary_emb::rope_i(&q, &self.freqs.cos, &self.freqs.sin)?;
            k = candle_nn::rotary_emb::rope_i(&k, &self.freqs.cos, &self.freqs.sin)?;
        }

        let mut attn_output = Sdpa
            .run_attention(&q, &k, &v, attention_mask, None, &self.sdpa_params)?
            .transpose(1, 2)?
            .contiguous()?
            .reshape((bs, q_sq, ()))?
            .to_dtype(q.dtype())?;

        if let Some(t) = self.q_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        let mut res = self.o_proj.forward(&attn_output)?;
        if self.q_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

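/// Two-layer feed-forward block (`fc1 -> activation -> fc2`) used inside each
/// vision encoder layer.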
struct Llama4Mlp {
    act: Activation,
    fc1: Arc<dyn QuantMethod>,
    fc2: Arc<dyn QuantMethod>,
}

impl Llama4Mlp {
    fn new(
        cfg: &VisionConfig,
        vb: ShardedVarBuilder,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        Ok(Self {
            act: cfg.hidden_act,
            fc1: ColumnParallelLayer::new(
                cfg.hidden_size,
                cfg.intermediate_size,
                &None,
                true,
                comm,
                vb.pp("fc1"),
            )?,
            fc2: RowParallelLayer::new(
                cfg.intermediate_size,
                cfg.hidden_size,
                &None,
                true,
                comm,
                vb.pp("fc2"),
            )?,
        })
    }

    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
        let original_dtype = hidden_states.dtype();
        let mut hidden_states = hidden_states.clone();
        if let Some(t) = self.fc1.quantized_act_type() {
            hidden_states = hidden_states.to_dtype(t)?;
        }
        hidden_states = self.fc1.forward(&hidden_states)?;
        hidden_states = self.act.forward(&hidden_states)?;
        hidden_states = self.fc2.forward(&hidden_states)?;
        if self.fc1.quantized_act_type().is_some() {
            hidden_states = hidden_states.to_dtype(original_dtype)?;
        }
        Ok(hidden_states)
    }
}

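/// Pre-norm transformer encoder layer: self-attention and MLP, each preceded
/// by a `LayerNorm` and wrapped in a residual connection.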
struct Llama4VisionEncoderLayer {
    self_attn: Llama4VisionAttention,
    mlp: Llama4Mlp,
    input_layernorm: LayerNorm,
    post_attention_layernorm: LayerNorm,
}

impl Llama4VisionEncoderLayer {
    fn new(
        cfg: &VisionConfig,
        vb: ShardedVarBuilder,
        freqs: Llama4VisionRotaryEmbedding,
        real_dev: &Device,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let self_attn = Llama4VisionAttention::new(cfg, vb.pp("self_attn"), freqs, comm)?;
        let mlp = Llama4Mlp::new(cfg, vb.pp("mlp"), comm)?;

        let input_layernorm = layer_norm(
            cfg.hidden_size,
            cfg.norm_eps,
            vb.pp("input_layernorm").set_device(real_dev.clone()),
        )?;
        let post_attention_layernorm = layer_norm(
            cfg.hidden_size,
            cfg.norm_eps,
            vb.pp("post_attention_layernorm")
                .set_device(real_dev.clone()),
        )?;

        Ok(Self {
            self_attn,
            mlp,
            input_layernorm,
            post_attention_layernorm,
        })
    }

    fn forward(&self, hidden_state: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
        let residual = hidden_state;
        let mut hidden_state = self.input_layernorm.forward(hidden_state)?;

        hidden_state = self.self_attn.forward(&hidden_state, attention_mask)?;
        hidden_state = (residual + hidden_state)?;

        let residual = hidden_state.clone();
        hidden_state = self.post_attention_layernorm.forward(&hidden_state)?;

        hidden_state = self.mlp.forward(&hidden_state)?;
        residual + hidden_state
    }
}

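/// Stack of `Llama4VisionEncoderLayer`s applied sequentially to the patch
/// embeddings.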
struct Llama4VisionEncoder {
    layers: Vec<Llama4VisionEncoderLayer>,
}

impl Llama4VisionEncoder {
    fn new(
        cfg: &VisionConfig,
        num_layers: usize,
        vb: ShardedVarBuilder,
        freqs: Llama4VisionRotaryEmbedding,
        real_dev: &Device,
        comm: &Arc<mistralrs_quant::Comm>,
        multi_progress: &Arc<MultiProgress>,
    ) -> Result<Self> {
        let mut layers = Vec::with_capacity(num_layers);
        let layers_vb = vb.pp("layers");
        for i in NiceProgressBar::<_, 'b'>(
            0..num_layers,
            "Loading vision repeating layers",
            multi_progress,
        ) {
            layers.push(Llama4VisionEncoderLayer::new(
                cfg,
                layers_vb.pp(i),
                freqs.clone(),
                real_dev,
                comm,
            )?);
        }
        Ok(Self { layers })
    }

    fn forward_with_states(
        &self,
        hidden_state: &Tensor,
        attention_mask: Option<&Tensor>,
    ) -> Result<Tensor> {
        let mut hidden_state = hidden_state.clone();
        for layer in self.layers.iter() {
            hidden_state = layer.forward(&hidden_state, attention_mask)?;
        }
        Ok(hidden_state)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb_t = UnVarBuilder::new();

        for (i, layer) in self.layers.iter().enumerate() {
            let uvb_l = uvb_t.pp("layers").pp(i);
            uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
            uvb_l
                .pp("post_attention_layernorm")
                .add(&layer.post_attention_layernorm);
        }

        uvb_t.to_safetensors()
    }
}

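/// GELU MLP (`intermediate_size -> projector_input_dim -> projector_output_dim`)
/// applied to the pixel-shuffled patch features.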
struct Llama4VisionPixelShuffleMLP {
    act: Activation,
    fc1: Arc<dyn QuantMethod>,
    fc2: Arc<dyn QuantMethod>,
}

impl Llama4VisionPixelShuffleMLP {
    fn new(
        cfg: &VisionConfig,
        vb: ShardedVarBuilder,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        Ok(Self {
            act: Activation::Gelu,
            fc1: ColumnParallelLayer::new(
                cfg.intermediate_size,
                cfg.projector_input_dim,
                &None,
                false,
                comm,
                vb.pp("fc1"),
            )?,
            fc2: RowParallelLayer::new(
                cfg.projector_input_dim,
                cfg.projector_output_dim,
                &None,
                false,
                comm,
                vb.pp("fc2"),
            )?,
        })
    }

    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
        let original_dtype = hidden_states.dtype();
        let mut hidden_states = hidden_states.clone();
        if let Some(t) = self.fc1.quantized_act_type() {
            hidden_states = hidden_states.to_dtype(t)?;
        }
        hidden_states = self.act.forward(
            &self
                .fc2
                .forward(&self.act.forward(&self.fc1.forward(&hidden_states)?)?)?,
        )?;
        if self.fc1.quantized_act_type().is_some() {
            hidden_states = hidden_states.to_dtype(original_dtype)?;
        }
        Ok(hidden_states)
    }
}

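/// Vision adapter: pixel-shuffles the encoded patches (trading spatial
/// resolution for channel depth) and projects them through
/// `Llama4VisionPixelShuffleMLP`.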
struct Llama4VisionPixelShuffle {
    mlp: Llama4VisionPixelShuffleMLP,
    pixel_shuffle_ratio: f32,
}

impl Llama4VisionPixelShuffle {
    fn new(
        cfg: &VisionConfig,
        vb: ShardedVarBuilder,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let mlp = Llama4VisionPixelShuffleMLP::new(cfg, vb.pp("mlp"), comm)?;
        Ok(Self {
            mlp,
            pixel_shuffle_ratio: cfg.pixel_shuffle_ratio,
        })
    }

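    /// Rearranges `(bs, num_patches, c)` over the square patch grid so that
    /// each spatial dimension is scaled by `pixel_shuffle_ratio` and the
    /// channel dimension by `1 / pixel_shuffle_ratio^2`.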
    fn pixel_shuffle(&self, xs: &Tensor) -> Result<Tensor> {
        let (bs, num_patches, _c) = xs.dims3()?;
        let patch_size = (num_patches as f32).sqrt() as usize;

        let mut xs = xs.reshape((bs, patch_size, patch_size, ()))?;
        let (_bs, h, w, c) = xs.dims4()?;

        xs = xs.reshape((
            bs,
            h,
            (w as f32 * self.pixel_shuffle_ratio) as usize,
            (c as f32 / self.pixel_shuffle_ratio) as usize,
        ))?;
        xs = xs.permute((0, 2, 1, 3))?.contiguous()?;

        xs = xs.reshape((
            bs,
            (h as f32 * self.pixel_shuffle_ratio) as usize,
            (w as f32 * self.pixel_shuffle_ratio) as usize,
            (c as f32 / self.pixel_shuffle_ratio.powi(2)) as usize,
        ))?;
        xs = xs.permute((0, 2, 1, 3))?.contiguous()?;

        xs.reshape((bs, (), xs.dim(D::Minus1)?))
    }

    fn forward(&self, encoded_patches: &Tensor) -> Result<Tensor> {
        let encoded_patches = self.pixel_shuffle(encoded_patches)?;
        self.mlp.forward(&encoded_patches)
    }
}

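/// Precomputed cos/sin tables for 2D rotary position embeddings over the patch
/// grid. Frequencies are derived from each patch's x and y coordinates, and
/// the appended CLS position is masked out.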
#[derive(Clone)]
struct Llama4VisionRotaryEmbedding {
    cos: Tensor,
    sin: Tensor,
}

impl Llama4VisionRotaryEmbedding {
    fn new(cfg: &VisionConfig, device: &Device, dtype: DType) -> Result<Self> {
        let idx = cfg.image_size / cfg.patch_size;
        let mut img_idx =
            Tensor::arange(0f32, idx.pow(2) as f32, device)?.reshape((idx.pow(2), 1))?;
        img_idx = Tensor::cat(&[&img_idx, &img_idx.narrow(0, 0, 1)?], 0)?;
        img_idx = img_idx.slice_assign(
            &[
                &(img_idx.dim(0)? - 1..img_idx.dim(0)?),
                &(img_idx.dim(1)? - 1..img_idx.dim(1)?),
            ],
            &Tensor::new(-2f32, device)?.reshape((1, 1))?,
        )?;
        let img_ids_flat = img_idx.flatten_all()?.to_vec1::<f32>()?;
        let frequencies_x = {
            let frequencies_x = img_ids_flat
                .iter()
                .map(|x| x % idx as f32)
                .collect::<Vec<_>>();
            Tensor::from_vec(frequencies_x, img_idx.shape().clone(), device)?
        };
        let frequencies_y = {
            let frequencies_y = img_ids_flat
                .iter()
                .map(|x| (x / idx as f32).floor())
                .collect::<Vec<_>>();
            Tensor::from_vec(frequencies_y, img_idx.shape().clone(), device)?
        };
        let rope_freq = {
            let freq_dim = cfg.hidden_size / cfg.num_attention_heads / 2;
            let freqs: Vec<_> = (0..freq_dim)
                .step_by(2)
                .take(freq_dim / 2)
                .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / freq_dim as f32))
                .collect();
            let freqs_len = freqs.len();
            Tensor::from_vec(freqs, freqs_len, device)?
        };
        let freqs_x = (frequencies_x + 1.)?
            .unsqueeze(D::Minus1)?
            .broadcast_mul(&rope_freq.unsqueeze(0)?.unsqueeze(0)?)?
            .repeat_interleave(2, D::Minus1)?;
        let freqs_y = (frequencies_y + 1.)?
            .unsqueeze(D::Minus1)?
            .broadcast_mul(&rope_freq.unsqueeze(0)?.unsqueeze(0)?)?
            .repeat_interleave(2, D::Minus1)?;
        let mut freqs = {
            let freqs = Tensor::cat(&[freqs_x, freqs_y], D::Minus1)?.contiguous()?;
            let indices_every_two = Tensor::new(
                (0..freqs.dim(D::Minus1)?)
                    .step_by(2)
                    .map(|x| x as u32)
                    .collect::<Vec<_>>(),
                device,
            )?;
            freqs.index_select(&indices_every_two, D::Minus1)?
        };
        freqs = freqs.squeeze(1)?;
        freqs = freqs.lt(0.)?.where_cond(&freqs.zeros_like()?, &freqs)?;

        Ok(Self {
            cos: freqs.cos()?.to_dtype(dtype)?,
            sin: freqs.sin()?.to_dtype(dtype)?,
        })
    }
}

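/// The Llama 4 vision tower: unfold-based patch embedding, learned class and
/// positional embeddings, pre/post layer norms, the transformer encoder, and
/// the pixel-shuffle vision adapter.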
pub(super) struct Llama4VisionModel {
    patch_embedding: Llama4UnfoldConvolution,
    class_embedding: Tensor,
    positional_embedding_vlm: Tensor,
    layernorm_pre: LayerNorm,
    layernorm_post: LayerNorm,
    model: Llama4VisionEncoder,
    vision_adapter: Llama4VisionPixelShuffle,
}

impl Llama4VisionModel {
    pub(super) fn new(
        cfg: &VisionConfig,
        vb: ShardedVarBuilder,
        real_dev: &Device,
        comm: &Arc<mistralrs_quant::Comm>,
        multi_progress: &Arc<MultiProgress>,
    ) -> Result<Self> {
        let patch_embedding = Llama4UnfoldConvolution::new(
            cfg,
            vb.pp("patch_embedding").set_device(real_dev.clone()),
        )?;

        let class_embedding = vb
            .get((cfg.hidden_size,), "class_embedding")?
            .to_device(real_dev)?;
        let num_patches = cfg.num_patches();
        let positional_embedding_vlm = vb
            .get((num_patches, cfg.hidden_size), "positional_embedding_vlm")?
            .to_device(real_dev)?;

        let layernorm_pre = layer_norm(
            cfg.hidden_size,
            LayerNormConfig::default(),
            vb.pp("layernorm_pre").set_device(real_dev.clone()),
        )?;
        let layernorm_post = layer_norm(
            cfg.hidden_size,
            LayerNormConfig::default(),
            vb.pp("layernorm_post").set_device(real_dev.clone()),
        )?;

        let rotary_embedding = Llama4VisionRotaryEmbedding::new(cfg, real_dev, vb.dtype())?;
        let model = Llama4VisionEncoder::new(
            cfg,
            cfg.num_hidden_layers,
            vb.pp("model"),
            rotary_embedding,
            real_dev,
            comm,
            multi_progress,
        )?;

        let vision_adapter = Llama4VisionPixelShuffle::new(cfg, vb.pp("vision_adapter"), comm)?;

        assert_eq!(cfg.vision_feature_layer, -1);

        Ok(Self {
            patch_embedding,
            class_embedding,
            positional_embedding_vlm,
            layernorm_post,
            layernorm_pre,
            model,
            vision_adapter,
        })
    }

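    /// Embeds `pixel_values` of shape `(bs * num_tiles, channels, height, width)`
    /// into patch tokens, appends the class token, adds positional embeddings,
    /// runs the encoder, drops the class token, and projects the result through
    /// the vision adapter.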
    pub(super) fn forward(&self, pixel_values: &Tensor) -> Result<Tensor> {
        let pixel_values = pixel_values.to_dtype(self.class_embedding.dtype())?;

        let (bs_times_num_tiles, _num_channels, _height, _width) = pixel_values.dims4()?;
        let num_concurrent_media = 1;

        let mut hidden_state = self.patch_embedding.forward(&pixel_values)?;
        let (_, mut num_patches, hidden_dim) = hidden_state.dims3()?;

        hidden_state = hidden_state.reshape((
            bs_times_num_tiles * num_concurrent_media,
            num_patches,
            hidden_dim,
        ))?;
        let class_embedding =
            self.class_embedding
                .expand((hidden_state.dim(0)?, 1, hidden_state.dim(D::Minus1)?))?;
        hidden_state = Tensor::cat(&[hidden_state, class_embedding], 1)?;
        num_patches += 1;

        hidden_state = hidden_state.reshape((
            bs_times_num_tiles * num_concurrent_media,
            num_patches,
            hidden_dim,
        ))?;
        hidden_state = hidden_state.broadcast_add(&self.positional_embedding_vlm)?;

        hidden_state = self.layernorm_pre.forward(&hidden_state)?;

        hidden_state = hidden_state.reshape((bs_times_num_tiles, (), hidden_dim))?;

        hidden_state =
            hidden_state.reshape((bs_times_num_tiles * num_concurrent_media, (), hidden_dim))?;
        hidden_state = self.model.forward_with_states(&hidden_state, None)?;

        hidden_state = self.layernorm_post.forward(&hidden_state)?;

        hidden_state = hidden_state.narrow(1, 0, hidden_state.dim(1)? - 1)?;

        self.vision_adapter.forward(&hidden_state)
    }

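    /// Collects mutable references to every linear projection in the vision
    /// tower (per-layer attention q/k/v/o and MLP fc1/fc2, plus the adapter
    /// MLP).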
    pub fn get_isq_layers(&mut self) -> Vec<&mut std::sync::Arc<dyn mistralrs_quant::QuantMethod>> {
        let mut layers = Vec::new();
        for layer in &mut self.model.layers {
            layers.push(&mut layer.self_attn.q_proj);
            layers.push(&mut layer.self_attn.k_proj);
            layers.push(&mut layer.self_attn.v_proj);
            layers.push(&mut layer.self_attn.o_proj);

            layers.push(&mut layer.mlp.fc1);
            layers.push(&mut layer.mlp.fc2);
        }
        layers.push(&mut self.vision_adapter.mlp.fc1);
        layers.push(&mut self.vision_adapter.mlp.fc2);
        layers
    }
}

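// `get_layers` is intentionally unreachable: the vision tower is not quantized
// through this trait, and only `residual_tensors` is used.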
impl IsqModel for Llama4VisionModel {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(
            &mut std::sync::Arc<dyn mistralrs_quant::QuantMethod>,
            Option<usize>,
        )>,
        &dyn crate::device_map::DeviceMapper,
    ) {
        unreachable!("Llama4Vision model cannot be quantized.");
    }
    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("patch_embedding")
            .pp("linear")
            .add(&self.patch_embedding.linear);
        uvb.add_tensor("class_embedding", self.class_embedding.clone());
        uvb.add_tensor(
            "positional_embedding_vlm",
            self.positional_embedding_vlm.clone(),
        );

        uvb.pp("layernorm_pre").add(&self.layernorm_pre);
        uvb.pp("layernorm_post").add(&self.layernorm_post);

        uvb.pp("model").extend(self.model.residual_tensors());

        uvb.to_safetensors()
    }
}