#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::{ops::Mul, sync::Arc};

use candle_core::{DType, Device, Result, Tensor, D};
use candle_nn::{Conv2d, Conv2dConfig, Embedding, LayerNorm, LayerNormConfig, Module};
use mistralrs_quant::{ColumnParallelLayer, QuantMethod, RowParallelLayer, ShardedVarBuilder};

use crate::{
    attention::SdpaParams,
    layers::{conv2d_no_bias, embedding, layer_norm, GetFloatInfo, Sdpa},
    pipeline::IsqModel,
    utils::unvarbuilder::UnVarBuilder,
};

use super::{MLlamaVisionConfig, VisionActivation};

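/// Gated positional embedding: a learned per-patch embedding plus a per-tile embedding selected
/// by aspect-ratio id, both blended into the hidden state through a `tanh` gate.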
struct MLlamaPrecomputedPositionEmbedding {
    gate: Tensor,
    embedding: Tensor,
    tile_embedding: Embedding,
    num_patches: usize,
    hidden_size: usize,
    max_num_tiles: usize,
}

impl MLlamaPrecomputedPositionEmbedding {
    fn new(cfg: &MLlamaVisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        let num_patches = (cfg.image_size / cfg.patch_size).pow(2) + 1;
        Ok(Self {
            gate: vb.get((1,), "gate")?,
            embedding: vb.get((num_patches, cfg.hidden_size), "embedding")?,
            tile_embedding: embedding(
                cfg.max_aspect_ratio_id() + 1,
                cfg.max_num_tiles * num_patches * cfg.hidden_size,
                vb.pp("tile_embedding"),
                &None,
            )?,
            num_patches,
            hidden_size: cfg.hidden_size,
            max_num_tiles: cfg.max_num_tiles,
        })
    }

    fn forward(&self, hidden_state: &Tensor, aspect_ratio_ids: &Tensor) -> Result<Tensor> {
        // Gated position embedding shared by every tile.
        let mut gated_pos_embed = (1. - &self.gate.tanh()?)?.broadcast_mul(&self.embedding)?;
        let hidden_state = hidden_state.broadcast_add(&gated_pos_embed.reshape((
            1,
            1,
            self.num_patches,
            self.hidden_size,
        ))?)?;

        // Precomputed tile position embeddings, selected per sample by aspect-ratio id.
        let mut tile_position_embedding = self.tile_embedding.forward(aspect_ratio_ids)?;
        let bs = hidden_state.dim(0)?;
        tile_position_embedding = tile_position_embedding.reshape((
            bs,
            self.max_num_tiles,
            self.num_patches,
            self.hidden_size,
        ))?;
        gated_pos_embed = self.gate.tanh()?.broadcast_mul(&tile_position_embedding)?;

        hidden_state.broadcast_add(&gated_pos_embed)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb_gpe = UnVarBuilder::new();

        uvb_gpe.add_tensor("gate", self.gate.clone());
        uvb_gpe.add_tensor("embedding", self.embedding.clone());
        uvb_gpe.pp("tile_embedding").add(&self.tile_embedding);

        uvb_gpe.to_safetensors()
    }
}

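/// Tile-position embedding looked up by aspect-ratio id; when `GATED`, it is scaled by
/// `tanh(gate)` before being broadcast-added to the hidden state.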
struct MLlamaPrecomputedAspectRatioEmbedding {
    embedding: Embedding,
    gate: Option<Tensor>,
    max_num_tiles: usize,
    hidden_size: usize,
}

impl MLlamaPrecomputedAspectRatioEmbedding {
    fn new<const GATED: bool>(cfg: &MLlamaVisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        Ok(Self {
            embedding: embedding(
                cfg.max_aspect_ratio_id() + 1,
                cfg.max_num_tiles * cfg.hidden_size,
                vb.pp("embedding"),
                &None,
            )?,
            gate: if GATED {
                Some(vb.get((1,), "gate")?)
            } else {
                None
            },
            max_num_tiles: cfg.max_num_tiles,
            hidden_size: cfg.hidden_size,
        })
    }

    fn forward(&self, hidden_state: &Tensor, aspect_ratio_ids: &Tensor) -> Result<Tensor> {
        let mut embeddings = self.embedding.forward(aspect_ratio_ids)?;
        embeddings = embeddings.reshape(((), self.max_num_tiles, 1, self.hidden_size))?;

        if let Some(gate) = &self.gate {
            embeddings = embeddings.broadcast_mul(&gate.tanh()?)?;
        }

        hidden_state.broadcast_add(&embeddings)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb_ptpe = UnVarBuilder::new();

        if let Some(gate) = self.gate.clone() {
            uvb_ptpe.add_tensor("gate", gate);
        }
        uvb_ptpe.pp("embedding").add(&self.embedding);

        uvb_ptpe.to_safetensors()
    }
}

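/// Multi-head self-attention over patch tokens, with column-parallel Q/K/V projections, a
/// row-parallel output projection, and `Sdpa` for the attention computation itself.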
struct MLlamaVisionAttention {
    q_proj: Arc<dyn QuantMethod>,
    k_proj: Arc<dyn QuantMethod>,
    v_proj: Arc<dyn QuantMethod>,
    o_proj: Arc<dyn QuantMethod>,
    sdpa_params: SdpaParams,
    num_heads: usize,
    head_dim: usize,
}

impl MLlamaVisionAttention {
    fn new(
        cfg: &MLlamaVisionConfig,
        vb: ShardedVarBuilder,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let head_dim = cfg.hidden_size / cfg.num_attention_heads;
        Ok(Self {
            q_proj: ColumnParallelLayer::new(
                cfg.hidden_size,
                cfg.num_attention_heads * head_dim,
                &None,
                false,
                comm,
                vb.pp("q_proj"),
            )?,
            k_proj: ColumnParallelLayer::new(
                cfg.hidden_size,
                cfg.num_attention_heads * head_dim,
                &None,
                false,
                comm,
                vb.pp("k_proj"),
            )?,
            v_proj: ColumnParallelLayer::new(
                cfg.hidden_size,
                cfg.num_attention_heads * head_dim,
                &None,
                false,
                comm,
                vb.pp("v_proj"),
            )?,
            o_proj: RowParallelLayer::new(
                cfg.hidden_size,
                cfg.num_attention_heads * head_dim,
                &None,
                false,
                comm,
                vb.pp("o_proj"),
            )?,
            sdpa_params: SdpaParams {
                n_kv_groups: 1,
                softcap: None,
                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                sliding_window: None,
            },
            num_heads: cfg.num_attention_heads,
            head_dim,
        })
    }

    fn forward(&self, hidden_state: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
        // Cast activations to the quantized activation dtype (if any) for the projections.
        let mut hidden_state = hidden_state.clone();
        let original_dtype = hidden_state.dtype();
        if let Some(t) = self.q_proj.quantized_act_type() {
            hidden_state = hidden_state.to_dtype(t)?;
        }
        let mut q = self.q_proj.forward(&hidden_state)?;
        let mut k = self.k_proj.forward(&hidden_state)?;
        let mut v = self.v_proj.forward(&hidden_state)?;
        if self.q_proj.quantized_act_type().is_some() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

        // Split the hidden dimension into heads: (bs, seq, hidden) -> (bs, heads, seq, head_dim).
        let (bs, q_sq, _) = q.dims3()?;
        let (_, k_sq, _) = k.dims3()?;

        q = q
            .reshape((bs, q_sq, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;
        k = k
            .reshape((bs, k_sq, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;
        v = v
            .reshape((bs, k_sq, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;

        let mut attn_output = Sdpa
            .run_attention(
                &q.contiguous()?,
                &k.contiguous()?,
                &v.contiguous()?,
                attention_mask,
                None,
                &self.sdpa_params,
            )?
            .transpose(1, 2)?
            .contiguous()?
            .reshape((bs, q_sq, ()))?
            .to_dtype(q.dtype())?;

        if let Some(t) = self.q_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        let mut res = self.o_proj.forward(&attn_output)?;
        if self.q_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

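/// Two-layer feed-forward block (`fc1` -> activation -> `fc2`) with tensor-parallel projections.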
struct MLlamaMlp {
    act: VisionActivation,
    fc1: Arc<dyn QuantMethod>,
    fc2: Arc<dyn QuantMethod>,
}

impl MLlamaMlp {
    fn new(
        cfg: &MLlamaVisionConfig,
        vb: ShardedVarBuilder,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        Ok(Self {
            act: cfg.hidden_act,
            fc1: ColumnParallelLayer::new(
                cfg.hidden_size,
                cfg.intermediate_size,
                &None,
                true,
                comm,
                vb.pp("fc1"),
            )?,
            fc2: RowParallelLayer::new(
                cfg.intermediate_size,
                cfg.hidden_size,
                &None,
                true,
                comm,
                vb.pp("fc2"),
            )?,
        })
    }

    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
        // Cast to the quantized activation dtype (if any) around the two projections.
        let original_dtype = hidden_states.dtype();
        let mut hidden_states = hidden_states.clone();
        if let Some(t) = self.fc1.quantized_act_type() {
            hidden_states = hidden_states.to_dtype(t)?;
        }
        hidden_states = self
            .fc2
            .forward(&self.act.forward(&self.fc1.forward(&hidden_states)?)?)?;
        if self.fc1.quantized_act_type().is_some() {
            hidden_states = hidden_states.to_dtype(original_dtype)?;
        }
        Ok(hidden_states)
    }
}

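/// Pre-norm transformer block. The gated variant (used by the global encoder) additionally
/// applies the learned `gate_attn`/`gate_ffn` scalars to each branch before the residual add.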
struct MLlamaVisionEncoderLayer {
    self_attn: MLlamaVisionAttention,
    mlp: MLlamaMlp,
    input_layernorm: LayerNorm,
    post_attention_layernorm: LayerNorm,
    gate_attn: Option<Tensor>,
    gate_ffn: Option<Tensor>,
}

impl MLlamaVisionEncoderLayer {
    fn new<const GATED: bool>(
        cfg: &MLlamaVisionConfig,
        vb: ShardedVarBuilder,
        real_dev: &Device,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let self_attn = MLlamaVisionAttention::new(cfg, vb.pp("self_attn"), comm)?;
        let mlp = MLlamaMlp::new(cfg, vb.pp("mlp"), comm)?;

        let input_layernorm = layer_norm(
            cfg.hidden_size,
            cfg.norm_eps,
            vb.pp("input_layernorm").set_device(real_dev.clone()),
        )?;
        let post_attention_layernorm = layer_norm(
            cfg.hidden_size,
            cfg.norm_eps,
            vb.pp("post_attention_layernorm")
                .set_device(real_dev.clone()),
        )?;

        if GATED {
            Ok(Self {
                self_attn,
                mlp,
                input_layernorm,
                post_attention_layernorm,
                gate_attn: Some(vb.get((1,), "gate_attn")?.to_device(real_dev)?),
                gate_ffn: Some(vb.get((1,), "gate_ffn")?.to_device(real_dev)?),
            })
        } else {
            Ok(Self {
                self_attn,
                mlp,
                input_layernorm,
                post_attention_layernorm,
                gate_attn: None,
                gate_ffn: None,
            })
        }
    }

    fn forward(&self, hidden_state: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
        // Self-attention branch.
        let residual = hidden_state;
        let mut hidden_state = self.input_layernorm.forward(hidden_state)?;

        hidden_state = self.self_attn.forward(&hidden_state, attention_mask)?;

        if let Some(gate) = &self.gate_attn {
            hidden_state = gate.broadcast_mul(&hidden_state.tanh()?)?;
        }
        hidden_state = (residual + hidden_state)?;

        // Feed-forward branch.
        let residual = hidden_state.clone();
        hidden_state = self.post_attention_layernorm.forward(&hidden_state)?;

        hidden_state = self.mlp.forward(&hidden_state)?;

        if let Some(gate) = &self.gate_ffn {
            hidden_state = gate.broadcast_mul(&hidden_state.tanh()?)?;
        }
        residual + hidden_state
    }
}

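/// Stack of encoder layers; `forward_with_states` can also collect the hidden states that enter
/// a chosen subset of layers.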
struct MLlamaVisionEncoder {
    layers: Vec<MLlamaVisionEncoderLayer>,
}

impl MLlamaVisionEncoder {
    fn new<const GATED: bool>(
        cfg: &MLlamaVisionConfig,
        num_layers: usize,
        vb: ShardedVarBuilder,
        real_dev: &Device,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let mut layers = Vec::with_capacity(num_layers);
        let layers_vb = vb.pp("layers");
        for i in 0..num_layers {
            layers.push(MLlamaVisionEncoderLayer::new::<GATED>(
                cfg,
                layers_vb.pp(i),
                real_dev,
                comm,
            )?);
        }
        Ok(Self { layers })
    }

    /// Run all layers, additionally collecting the hidden state that enters each layer whose
    /// index appears in `intermediate_layers_indices`.
    fn forward_with_states(
        &self,
        hidden_state: &Tensor,
        attention_mask: Option<&Tensor>,
        intermediate_layers_indices: Option<&[usize]>,
    ) -> Result<(Tensor, Vec<Tensor>)> {
        let mut hidden_state = hidden_state.clone();
        let mut hidden_states = Vec::new();
        for (i, layer) in self.layers.iter().enumerate() {
            if intermediate_layers_indices.is_some_and(|indices: &[usize]| indices.contains(&i)) {
                hidden_states.push(hidden_state.clone());
            }
            hidden_state = layer.forward(&hidden_state, attention_mask)?;
        }
        Ok((hidden_state, hidden_states))
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb_t = UnVarBuilder::new();

        for (i, layer) in self.layers.iter().enumerate() {
            let uvb_l = uvb_t.pp("layers").pp(i);
            uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
            uvb_l
                .pp("post_attention_layernorm")
                .add(&layer.post_attention_layernorm);
            if let Some(gate) = layer.gate_attn.clone() {
                uvb_l.add_tensor("gate_attn", gate);
            }
            if let Some(gate) = layer.gate_ffn.clone() {
                uvb_l.add_tensor("gate_ffn", gate);
            }
        }

        uvb_t.to_safetensors()
    }
}

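/// Build the additive attention bias from the per-tile `aspect_ratio_mask`: disabled tiles and
/// the zero-padded patch slots are driven to the dtype's most negative value so softmax ignores
/// them. This appears to mirror the corresponding mask-preparation helper in the HF `mllama`
/// reference implementation.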
fn _prepare_aspect_ratio_attention_mask(
    aspect_ratio_mask: &Tensor,
    num_patches: usize,
    target_length: usize,
    dtype: DType,
    _num_attn_heads: usize,
) -> Result<Tensor> {
    // Expand the per-tile mask over every patch position of each tile.
    let (bs, max_num_tiles) = aspect_ratio_mask.dims2()?;
    let mut attention_mask = aspect_ratio_mask
        .reshape((bs, max_num_tiles, 1, 1))?
        .repeat((1, 1, target_length, 1))?;

    // Zero out the positions that were appended when padding the patch dimension.
    let pad_patches = target_length - num_patches;
    let (bs, d1, d2, d3) = attention_mask.dims4()?;
    attention_mask = attention_mask.slice_assign(
        &[&.., &.., &(d2 - pad_patches..), &..],
        &Tensor::zeros(
            (bs, d1, pad_patches, d3),
            attention_mask.dtype(),
            attention_mask.device(),
        )?,
    )?;

    // Invert the mask (1 marks positions to ignore) ...
    attention_mask = (1. - attention_mask.to_dtype(DType::F32)?.to_dtype(dtype)?)?;

    // ... and turn it into an additive bias over (tile * patch) positions, filled with the
    // dtype's most negative value where both query and key positions are masked.
    let neg_inf_value = dtype.finfo()?.min;
    attention_mask = attention_mask.reshape((bs, max_num_tiles * target_length, 1))?;
    attention_mask.matmul(
        &attention_mask
            .transpose(D::Minus1, D::Minus2)?
            .mul(neg_inf_value)?,
    )
}

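/// The MLlama (Llama 3.2 Vision) vision tower: tiled patch embedding, class token, gated
/// positional embeddings, a local transformer, and a gated global transformer whose output is
/// concatenated with selected intermediate hidden states.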
pub(super) struct MLlamaVisionModel {
    patch_embedding: Conv2d,
    class_embedding: Tensor,
    gated_positional_embedding: MLlamaPrecomputedPositionEmbedding,
    pre_tile_positional_embedding: MLlamaPrecomputedAspectRatioEmbedding,
    post_tile_positional_embedding: MLlamaPrecomputedAspectRatioEmbedding,
    layernorm_pre: LayerNorm,
    layernorm_post: LayerNorm,
    transformer: MLlamaVisionEncoder,
    global_transformer: MLlamaVisionEncoder,
    pub(super) num_patches: usize,
    intermediate_layers_indices: Vec<usize>,
    num_attn_heads: usize,
}

impl MLlamaVisionModel {
    pub(super) fn new(
        cfg: &MLlamaVisionConfig,
        vb: ShardedVarBuilder,
        real_dev: &Device,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let patch_embedding = conv2d_no_bias(
            cfg.num_channels,
            cfg.hidden_size,
            cfg.patch_size,
            Conv2dConfig {
                stride: cfg.patch_size,
                ..Default::default()
            },
            vb.pp("patch_embedding").set_device(real_dev.clone()),
        )?;

        let class_embedding = vb
            .get((cfg.hidden_size,), "class_embedding")?
            .to_device(real_dev)?;
        let gated_positional_embedding = MLlamaPrecomputedPositionEmbedding::new(
            cfg,
            vb.pp("gated_positional_embedding")
                .set_device(real_dev.clone()),
        )?;

        let pre_tile_positional_embedding = MLlamaPrecomputedAspectRatioEmbedding::new::<true>(
            cfg,
            vb.pp("pre_tile_positional_embedding")
                .set_device(real_dev.clone()),
        )?;
        let post_tile_positional_embedding = MLlamaPrecomputedAspectRatioEmbedding::new::<true>(
            cfg,
            vb.pp("post_tile_positional_embedding")
                .set_device(real_dev.clone()),
        )?;

        let layernorm_pre = layer_norm(
            cfg.hidden_size,
            LayerNormConfig::default(),
            vb.pp("layernorm_pre").set_device(real_dev.clone()),
        )?;
        let layernorm_post = layer_norm(
            cfg.hidden_size,
            LayerNormConfig::default(),
            vb.pp("layernorm_post").set_device(real_dev.clone()),
        )?;

        let transformer = MLlamaVisionEncoder::new::<false>(
            cfg,
            cfg.num_hidden_layers,
            vb.pp("transformer"),
            real_dev,
            comm,
        )?;
        let global_transformer = MLlamaVisionEncoder::new::<true>(
            cfg,
            cfg.num_global_layers,
            vb.pp("global_transformer"),
            real_dev,
            comm,
        )?;

        Ok(Self {
            patch_embedding,
            class_embedding,
            gated_positional_embedding,
            pre_tile_positional_embedding,
            post_tile_positional_embedding,
            layernorm_post,
            layernorm_pre,
            transformer,
            global_transformer,
            num_patches: (cfg.image_size / cfg.patch_size).pow(2) + 1,
            intermediate_layers_indices: cfg.intermediate_layers_indices.clone(),
            num_attn_heads: cfg.num_attention_heads,
        })
    }

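    /// Run the vision tower.
    ///
    /// `pixel_values` is expected as `(bs, num_concurrent_media, num_tiles, channels, height,
    /// width)`, with `aspect_ratio_ids` and `aspect_ratio_mask` carrying per-media tile metadata.
    /// The output concatenates the final hidden state with the selected intermediate hidden
    /// states along the last dimension.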
    pub(super) fn forward(
        &self,
        pixel_values: &Tensor,
        aspect_ratio_ids: &Tensor,
        aspect_ratio_mask: &Tensor,
    ) -> Result<Tensor> {
        let pixel_values = pixel_values.to_dtype(self.class_embedding.dtype())?;

        let bs = pixel_values.dim(0)?;
        let num_concurrent_media = pixel_values.dim(1)?;
        let num_tiles = pixel_values.dim(2)?;
        let num_channels = pixel_values.dim(3)?;
        let height = pixel_values.dim(4)?;
        let width = pixel_values.dim(5)?;

        let pixel_values = pixel_values.reshape((
            bs * num_concurrent_media * num_tiles,
            num_channels,
            height,
            width,
        ))?;
        let aspect_ratio_ids = aspect_ratio_ids.reshape((bs * num_concurrent_media, ()))?;

        // Patch embedding: each tile becomes a sequence of patch tokens.
        let patch_embeds = self.patch_embedding.forward(&pixel_values)?;
        let mut hidden_state = patch_embeds.flatten_from(2)?.transpose(1, 2)?;

        // Tile position embeddings (pre-transformer).
        let (_, mut num_patches, dim) = hidden_state.dims3()?;
        hidden_state = hidden_state.reshape((bs * num_concurrent_media, num_tiles, (), dim))?;
        hidden_state = self
            .pre_tile_positional_embedding
            .forward(&hidden_state, &aspect_ratio_ids)?;

        // Prepend the class token to every tile.
        hidden_state =
            hidden_state.reshape((bs * num_concurrent_media * num_tiles, num_patches, dim))?;
        hidden_state = self.apply_class_embedding(&hidden_state)?;
        num_patches += 1;

        // Gated positional embeddings.
        hidden_state =
            hidden_state.reshape((bs * num_concurrent_media, num_tiles, num_patches, dim))?;
        hidden_state = self
            .gated_positional_embedding
            .forward(&hidden_state, &aspect_ratio_ids)?;

        hidden_state = self.layernorm_pre.forward(&hidden_state)?;

        // Pad the patch dimension up to a multiple of 8.
        let num_padding_patches = (8 - (hidden_state.dim(D::Minus2)? as isize % 8)) % 8;
        let _padding = (0usize, 0usize, 0usize, num_padding_patches);
        if num_padding_patches >= 0 {
            hidden_state =
                hidden_state.pad_with_zeros(D::Minus2, 0, num_padding_patches as usize)?;
        } else {
            hidden_state = hidden_state.narrow(
                D::Minus2,
                0,
                (hidden_state.dim(2)? as isize + num_padding_patches) as usize,
            )?;
        }

        // Attention mask over the (tile, patch) positions.
        let mut attention_mask = aspect_ratio_mask.reshape((bs * num_concurrent_media, ()))?;
        attention_mask = _prepare_aspect_ratio_attention_mask(
            &attention_mask,
            self.num_patches,
            hidden_state.dim(2)?,
            hidden_state.dtype(),
            self.num_attn_heads,
        )?;
        if attention_mask.dim(0)? != 1 {
            attention_mask = attention_mask.unsqueeze(1)?;
        }

        // Local transformer, collecting the requested intermediate hidden states.
        hidden_state = hidden_state.reshape((bs * num_concurrent_media, (), dim))?;
        let (mut hidden_state, all_intermediate_hidden_states) =
            self.transformer.forward_with_states(
                &hidden_state,
                Some(&attention_mask),
                Some(&self.intermediate_layers_indices),
            )?;

        let mut intermediate_hidden_states =
            Tensor::stack(&all_intermediate_hidden_states, D::Minus1)?;
        drop(all_intermediate_hidden_states);

        hidden_state = self.layernorm_post.forward(&hidden_state)?;

        // Tile position embeddings (post-transformer), then the global transformer.
        hidden_state = hidden_state.reshape((
            bs * num_concurrent_media,
            num_tiles,
            (num_patches as isize + num_padding_patches) as usize,
            dim,
        ))?;
        hidden_state = self
            .post_tile_positional_embedding
            .forward(&hidden_state, &aspect_ratio_ids)?;
        hidden_state = hidden_state.reshape((
            bs * num_concurrent_media,
            num_tiles * (num_patches as isize + num_padding_patches) as usize,
            dim,
        ))?;
        (hidden_state, _) = self.global_transformer.forward_with_states(
            &hidden_state,
            Some(&attention_mask),
            None,
        )?;

        // Remove the padding patches again.
        hidden_state = hidden_state.reshape((
            bs * num_concurrent_media,
            num_tiles,
            (num_patches as isize + num_padding_patches) as usize,
            dim,
        ))?;
        hidden_state = hidden_state.narrow(
            2,
            0,
            (hidden_state.dims()[2] as isize - num_padding_patches) as usize,
        )?;
        hidden_state =
            hidden_state.reshape((bs, num_concurrent_media, num_tiles, num_patches, dim))?;

        // Do the same for the stacked intermediate hidden states.
        intermediate_hidden_states = intermediate_hidden_states.reshape((
            bs * num_concurrent_media,
            num_tiles,
            (num_patches as isize + num_padding_patches) as usize,
            (),
        ))?;
        intermediate_hidden_states = intermediate_hidden_states.narrow(
            2,
            0,
            (intermediate_hidden_states.dims()[2] as isize - num_padding_patches) as usize,
        )?;
        intermediate_hidden_states = intermediate_hidden_states.reshape((
            bs,
            num_concurrent_media,
            num_tiles,
            num_patches,
            (),
        ))?;

        // Concatenate the final hidden state with the intermediate features on the last dim.
        Tensor::cat(&[hidden_state, intermediate_hidden_states], D::Minus1)
    }

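    /// Prepend the learned class (CLS) embedding to each tile's patch sequence.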
    fn apply_class_embedding(&self, hidden_state: &Tensor) -> Result<Tensor> {
        let (bs, _, hidden_size) = hidden_state.dims3()?;
        let class_embedding = self.class_embedding.expand((bs, 1, hidden_size))?;
        Tensor::cat(&[class_embedding, hidden_state.clone()], 1)
    }

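    /// Mutable handles to every attention/MLP projection in both encoders, e.g. for ISQ.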
    pub fn get_isq_layers(&mut self) -> Vec<&mut std::sync::Arc<dyn mistralrs_quant::QuantMethod>> {
        let mut layers = Vec::new();
        for layer in &mut self.global_transformer.layers {
            layers.push(&mut layer.self_attn.q_proj);
            layers.push(&mut layer.self_attn.k_proj);
            layers.push(&mut layer.self_attn.v_proj);
            layers.push(&mut layer.self_attn.o_proj);

            layers.push(&mut layer.mlp.fc1);
            layers.push(&mut layer.mlp.fc2);
        }
        for layer in &mut self.transformer.layers {
            layers.push(&mut layer.self_attn.q_proj);
            layers.push(&mut layer.self_attn.k_proj);
            layers.push(&mut layer.self_attn.v_proj);
            layers.push(&mut layer.self_attn.o_proj);

            layers.push(&mut layer.mlp.fc1);
            layers.push(&mut layer.mlp.fc2);
        }
        layers
    }
}

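// ISQ plumbing: the vision tower is never quantized in place, so only its non-quantized
// residual tensors are exposed for serialization.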
impl IsqModel for MLlamaVisionModel {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(
            &mut std::sync::Arc<dyn mistralrs_quant::QuantMethod>,
            Option<usize>,
        )>,
        &dyn crate::device_map::DeviceMapper,
    ) {
        unreachable!("MLlamaVision model cannot be quantized.");
    }
    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("patch_embedding").add(&self.patch_embedding);
        uvb.add_tensor("class_embedding", self.class_embedding.clone());

        uvb.pp("gated_positional_embedding")
            .extend(self.gated_positional_embedding.residual_tensors());

        uvb.pp("pre_tile_positional_embedding")
            .extend(self.pre_tile_positional_embedding.residual_tensors());

        uvb.pp("post_tile_positional_embedding")
            .extend(self.post_tile_positional_embedding.residual_tensors());

        uvb.pp("layernorm_pre").add(&self.layernorm_pre);
        uvb.pp("layernorm_post").add(&self.layernorm_post);

        uvb.pp("transformer")
            .extend(self.transformer.residual_tensors());

        uvb.pp("global_transformer")
            .extend(self.global_transformer.residual_tensors());

        uvb.to_safetensors()
    }
}