#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::{ops::Mul, sync::Arc};

use candle_core::{DType, Device, Result, Tensor, D};
use candle_nn::{Conv2d, Conv2dConfig, Embedding, LayerNorm, LayerNormConfig, Module};
use mistralrs_quant::{ColumnParallelLayer, QuantMethod, RowParallelLayer, ShardedVarBuilder};

use crate::{
    attention::SdpaParams,
    layers::{conv2d_no_bias, embedding, layer_norm, GetFloatInfo, Sdpa},
    pipeline::IsqModel,
    utils::unvarbuilder::UnVarBuilder,
};

use super::{MLlamaVisionConfig, VisionActivation};

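/// Gated positional embedding: a learned per-patch embedding plus a per-tile positional
/// embedding looked up by aspect-ratio id, blended via `tanh(gate)`.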
struct MLlamaPrecomputedPositionEmbedding {
    gate: Tensor,
    embedding: Tensor,
    tile_embedding: Embedding,
    num_patches: usize,
    hidden_size: usize,
    max_num_tiles: usize,
}

impl MLlamaPrecomputedPositionEmbedding {
    fn new(cfg: &MLlamaVisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        let num_patches = (cfg.image_size / cfg.patch_size).pow(2) + 1;
        Ok(Self {
            gate: vb.get((1,), "gate")?,
            embedding: vb.get((num_patches, cfg.hidden_size), "embedding")?,
            tile_embedding: embedding(
                cfg.max_aspect_ratio_id() + 1,
                cfg.max_num_tiles * num_patches * cfg.hidden_size,
                vb.pp("tile_embedding"),
                &None,
            )?,
            num_patches,
            hidden_size: cfg.hidden_size,
            max_num_tiles: cfg.max_num_tiles,
        })
    }

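    /// Adds `(1 - tanh(gate)) * embedding` to every tile, then adds the aspect-ratio-specific
    /// tile embedding scaled by `tanh(gate)`, reshaped to
    /// `(bs, max_num_tiles, num_patches, hidden_size)`.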
    fn forward(&self, hidden_state: &Tensor, aspect_ratio_ids: &Tensor) -> Result<Tensor> {
        let mut gated_pos_embed = (1. - &self.gate.tanh()?)?.broadcast_mul(&self.embedding)?;
        let hidden_state = hidden_state.broadcast_add(&gated_pos_embed.reshape((
            1,
            1,
            self.num_patches,
            self.hidden_size,
        ))?)?;

        let mut tile_position_embedding = self.tile_embedding.forward(aspect_ratio_ids)?;
        let bs = hidden_state.dim(0)?;
        tile_position_embedding = tile_position_embedding.reshape((
            bs,
            self.max_num_tiles,
            self.num_patches,
            self.hidden_size,
        ))?;
        gated_pos_embed = self.gate.tanh()?.broadcast_mul(&tile_position_embedding)?;

        hidden_state.broadcast_add(&gated_pos_embed)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb_gpe = UnVarBuilder::new();

        uvb_gpe.add_tensor("gate", self.gate.clone());
        uvb_gpe.add_tensor("embedding", self.embedding.clone());
        uvb_gpe.pp("tile_embedding").add(&self.tile_embedding);

        uvb_gpe.to_safetensors()
    }
}

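/// Per-tile positional embedding selected by aspect-ratio id. When `GATED`, it is scaled by
/// `tanh(gate)` before being added to the hidden state.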
struct MLlamaPrecomputedAspectRatioEmbedding {
    embedding: Embedding,
    gate: Option<Tensor>,
    max_num_tiles: usize,
    hidden_size: usize,
}

impl MLlamaPrecomputedAspectRatioEmbedding {
    fn new<const GATED: bool>(cfg: &MLlamaVisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        Ok(Self {
            embedding: embedding(
                cfg.max_aspect_ratio_id() + 1,
                cfg.max_num_tiles * cfg.hidden_size,
                vb.pp("embedding"),
                &None,
            )?,
            gate: if GATED {
                Some(vb.get((1,), "gate")?)
            } else {
                None
            },
            max_num_tiles: cfg.max_num_tiles,
            hidden_size: cfg.hidden_size,
        })
    }

    fn forward(&self, hidden_state: &Tensor, aspect_ratio_ids: &Tensor) -> Result<Tensor> {
        let mut embeddings = self.embedding.forward(aspect_ratio_ids)?;
        embeddings = embeddings.reshape(((), self.max_num_tiles, 1, self.hidden_size))?;

        if let Some(gate) = &self.gate {
            embeddings = embeddings.broadcast_mul(&gate.tanh()?)?;
        }

        hidden_state.broadcast_add(&embeddings)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb_ptpe = UnVarBuilder::new();

        if let Some(gate) = self.gate.clone() {
            uvb_ptpe.add_tensor("gate", gate);
        }
        uvb_ptpe.pp("embedding").add(&self.embedding);

        uvb_ptpe.to_safetensors()
    }
}

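/// Multi-head self-attention over patch tokens (no KV grouping, no sliding window), using
/// tensor-parallel projections.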
struct MLlamaVisionAttention {
    q_proj: Arc<dyn QuantMethod>,
    k_proj: Arc<dyn QuantMethod>,
    v_proj: Arc<dyn QuantMethod>,
    o_proj: Arc<dyn QuantMethod>,
    sdpa_params: SdpaParams,
    num_heads: usize,
    head_dim: usize,
}

impl MLlamaVisionAttention {
    fn new(
        cfg: &MLlamaVisionConfig,
        vb: ShardedVarBuilder,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let head_dim = cfg.hidden_size / cfg.num_attention_heads;
        Ok(Self {
            q_proj: ColumnParallelLayer::new(
                cfg.hidden_size,
                cfg.num_attention_heads * head_dim,
                &None,
                false,
                comm,
                vb.pp("q_proj"),
            )?,
            k_proj: ColumnParallelLayer::new(
                cfg.hidden_size,
                cfg.num_attention_heads * head_dim,
                &None,
                false,
                comm,
                vb.pp("k_proj"),
            )?,
            v_proj: ColumnParallelLayer::new(
                cfg.hidden_size,
                cfg.num_attention_heads * head_dim,
                &None,
                false,
                comm,
                vb.pp("v_proj"),
            )?,
            o_proj: RowParallelLayer::new(
                cfg.hidden_size,
                cfg.num_attention_heads * head_dim,
                &None,
                false,
                comm,
                vb.pp("o_proj"),
            )?,
            sdpa_params: SdpaParams {
                n_kv_groups: 1,
                use_flash_attn: false,
                softcap: None,
                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                sliding_window: None,
            },
            num_heads: cfg.num_attention_heads,
            head_dim,
        })
    }

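    /// Expects `hidden_state` of shape `(bs, seq_len, hidden_size)`. Projections may run in a
    /// quantized activation dtype; the output is cast back to the input dtype.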
    fn forward(&self, hidden_state: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
        let mut hidden_state = hidden_state.clone();
        let original_dtype = hidden_state.dtype();
        if let Some(t) = self.q_proj.quantized_act_type() {
            hidden_state = hidden_state.to_dtype(t)?;
        }
        let mut q = self.q_proj.forward(&hidden_state)?;
        let mut k = self.k_proj.forward(&hidden_state)?;
        let mut v = self.v_proj.forward(&hidden_state)?;
        if self.q_proj.quantized_act_type().is_some() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

        let (bs, q_sq, _) = q.dims3()?;
        let (_, k_sq, _) = k.dims3()?;

        q = q
            .reshape((bs, q_sq, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;
        k = k
            .reshape((bs, k_sq, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;
        v = v
            .reshape((bs, k_sq, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;

        let mut attn_output = Sdpa
            .run_attention(
                &q.contiguous()?,
                &k.contiguous()?,
                &v.contiguous()?,
                attention_mask,
                None,
                &self.sdpa_params,
            )?
            .transpose(1, 2)?
            .contiguous()?
            .reshape((bs, q_sq, ()))?
            .to_dtype(q.dtype())?;

        if let Some(t) = self.q_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        let mut res = self.o_proj.forward(&attn_output)?;
        if self.q_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

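/// Feed-forward block: `fc2(act(fc1(x)))`, with the same quantized-activation dtype handling as
/// the attention block.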
struct MLlamaMlp {
    act: VisionActivation,
    fc1: Arc<dyn QuantMethod>,
    fc2: Arc<dyn QuantMethod>,
}

impl MLlamaMlp {
    fn new(
        cfg: &MLlamaVisionConfig,
        vb: ShardedVarBuilder,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        Ok(Self {
            act: cfg.hidden_act,
            fc1: ColumnParallelLayer::new(
                cfg.hidden_size,
                cfg.intermediate_size,
                &None,
                true,
                comm,
                vb.pp("fc1"),
            )?,
            fc2: RowParallelLayer::new(
                cfg.intermediate_size,
                cfg.hidden_size,
                &None,
                true,
                comm,
                vb.pp("fc2"),
            )?,
        })
    }

    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
        let original_dtype = hidden_states.dtype();
        let mut hidden_states = hidden_states.clone();
        if let Some(t) = self.fc1.quantized_act_type() {
            hidden_states = hidden_states.to_dtype(t)?;
        }
        hidden_states = self
            .fc2
            .forward(&self.act.forward(&self.fc1.forward(&hidden_states)?)?)?;
        if self.fc1.quantized_act_type().is_some() {
            hidden_states = hidden_states.to_dtype(original_dtype)?;
        }
        Ok(hidden_states)
    }
}

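/// Pre-norm transformer block. When `GATED` (used for the global encoder), the attention and MLP
/// branches are additionally scaled by the learned `gate_attn` / `gate_ffn` scalars.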
struct MLlamaVisionEncoderLayer {
    self_attn: MLlamaVisionAttention,
    mlp: MLlamaMlp,
    input_layernorm: LayerNorm,
    post_attention_layernorm: LayerNorm,
    gate_attn: Option<Tensor>,
    gate_ffn: Option<Tensor>,
}

impl MLlamaVisionEncoderLayer {
    fn new<const GATED: bool>(
        cfg: &MLlamaVisionConfig,
        vb: ShardedVarBuilder,
        real_dev: &Device,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let self_attn = MLlamaVisionAttention::new(cfg, vb.pp("self_attn"), comm)?;
        let mlp = MLlamaMlp::new(cfg, vb.pp("mlp"), comm)?;

        let input_layernorm = layer_norm(
            cfg.hidden_size,
            cfg.norm_eps,
            vb.pp("input_layernorm").set_device(real_dev.clone()),
        )?;
        let post_attention_layernorm = layer_norm(
            cfg.hidden_size,
            cfg.norm_eps,
            vb.pp("post_attention_layernorm")
                .set_device(real_dev.clone()),
        )?;

        if GATED {
            Ok(Self {
                self_attn,
                mlp,
                input_layernorm,
                post_attention_layernorm,
                gate_attn: Some(vb.get((1,), "gate_attn")?.to_device(real_dev)?),
                gate_ffn: Some(vb.get((1,), "gate_ffn")?.to_device(real_dev)?),
            })
        } else {
            Ok(Self {
                self_attn,
                mlp,
                input_layernorm,
                post_attention_layernorm,
                gate_attn: None,
                gate_ffn: None,
            })
        }
    }

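    /// Residual pre-norm forward: `x + [tanh(gate_attn) *] attn(ln1(x))`, then
    /// `x + [tanh(gate_ffn) *] mlp(ln2(x))`.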
    fn forward(&self, hidden_state: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
        let residual = hidden_state;
        let mut hidden_state = self.input_layernorm.forward(hidden_state)?;

        hidden_state = self.self_attn.forward(&hidden_state, attention_mask)?;

        if let Some(gate) = &self.gate_attn {
            // tanh is applied to the learned gate scalar, not to the hidden state
            hidden_state = gate.tanh()?.broadcast_mul(&hidden_state)?;
        }
        hidden_state = (residual + hidden_state)?;

        let residual = hidden_state.clone();
        hidden_state = self.post_attention_layernorm.forward(&hidden_state)?;

        hidden_state = self.mlp.forward(&hidden_state)?;

        if let Some(gate) = &self.gate_ffn {
            hidden_state = gate.tanh()?.broadcast_mul(&hidden_state)?;
        }
        residual + hidden_state
    }
}

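/// A stack of encoder layers; `GATED` selects gated (global) vs. ungated (local) layers.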
struct MLlamaVisionEncoder {
    layers: Vec<MLlamaVisionEncoderLayer>,
}

impl MLlamaVisionEncoder {
    fn new<const GATED: bool>(
        cfg: &MLlamaVisionConfig,
        num_layers: usize,
        vb: ShardedVarBuilder,
        real_dev: &Device,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let mut layers = Vec::with_capacity(num_layers);
        let layers_vb = vb.pp("layers");
        for i in 0..num_layers {
            layers.push(MLlamaVisionEncoderLayer::new::<GATED>(
                cfg,
                layers_vb.pp(i),
                real_dev,
                comm,
            )?);
        }
        Ok(Self { layers })
    }

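    /// Runs every layer, additionally collecting the hidden state *entering* each layer whose
    /// index appears in `intermediate_layers_indices`.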
    fn forward_with_states(
        &self,
        hidden_state: &Tensor,
        attention_mask: Option<&Tensor>,
        intermediate_layers_indices: Option<&[usize]>,
    ) -> Result<(Tensor, Vec<Tensor>)> {
        let mut hidden_state = hidden_state.clone();
        let mut hidden_states = Vec::new();
        for (i, layer) in self.layers.iter().enumerate() {
            if intermediate_layers_indices.is_some_and(|indices: &[usize]| indices.contains(&i)) {
                hidden_states.push(hidden_state.clone());
            }
            hidden_state = layer.forward(&hidden_state, attention_mask)?;
        }
        Ok((hidden_state, hidden_states))
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb_t = UnVarBuilder::new();

        for (i, layer) in self.layers.iter().enumerate() {
            let uvb_l = uvb_t.pp("layers").pp(i);
            uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
            uvb_l
                .pp("post_attention_layernorm")
                .add(&layer.post_attention_layernorm);
            if let Some(gate) = layer.gate_attn.clone() {
                uvb_l.add_tensor("gate_attn", gate);
            }
            if let Some(gate) = layer.gate_ffn.clone() {
                uvb_l.add_tensor("gate_ffn", gate);
            }
        }

        uvb_t.to_safetensors()
    }
}

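/// Expands the per-tile aspect-ratio mask over `target_length` patches, zeroes the trailing
/// padded patch slots, inverts it, and forms an additive attention mask via an outer product
/// scaled by the dtype's most negative finite value.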
fn _prepare_aspect_ratio_attention_mask(
    aspect_ratio_mask: &Tensor,
    num_patches: usize,
    target_length: usize,
    dtype: DType,
    _num_attn_heads: usize,
) -> Result<Tensor> {
    let (bs, max_num_tiles) = aspect_ratio_mask.dims2()?;
    let mut attention_mask = aspect_ratio_mask
        .reshape((bs, max_num_tiles, 1, 1))?
        .repeat((1, 1, target_length, 1))?;

    let pad_patches = target_length - num_patches;
    let (bs, d1, d2, d3) = attention_mask.dims4()?;
    attention_mask = attention_mask.slice_assign(
        &[&.., &.., &(d2 - pad_patches..), &..],
        &Tensor::zeros(
            (bs, d1, pad_patches, d3),
            attention_mask.dtype(),
            attention_mask.device(),
        )?,
    )?;

    attention_mask = (1. - attention_mask.to_dtype(DType::F32)?.to_dtype(dtype)?)?;

    let neg_inf_value = dtype.finfo()?.min;
    attention_mask = attention_mask.reshape((bs, max_num_tiles * target_length, 1))?;
    attention_mask.matmul(
        &attention_mask
            .transpose(D::Minus1, D::Minus2)?
            .mul(neg_inf_value)?,
    )
}

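/// The MLlama vision tower: convolutional patch embedding, tile/positional embeddings, a local
/// transformer whose selected intermediate hidden states are kept, and a gated global
/// transformer. The forward output concatenates the final hidden state with those intermediate
/// states along the last dimension.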
pub(super) struct MLlamaVisionModel {
    patch_embedding: Conv2d,
    class_embedding: Tensor,
    gated_positional_embedding: MLlamaPrecomputedPositionEmbedding,
    pre_tile_positional_embedding: MLlamaPrecomputedAspectRatioEmbedding,
    post_tile_positional_embedding: MLlamaPrecomputedAspectRatioEmbedding,
    layernorm_pre: LayerNorm,
    layernorm_post: LayerNorm,
    transformer: MLlamaVisionEncoder,
    global_transformer: MLlamaVisionEncoder,
    pub(super) num_patches: usize,
    intermediate_layers_indices: Vec<usize>,
    num_attn_heads: usize,
}

impl MLlamaVisionModel {
    pub(super) fn new(
        cfg: &MLlamaVisionConfig,
        vb: ShardedVarBuilder,
        real_dev: &Device,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let patch_embedding = conv2d_no_bias(
            cfg.num_channels,
            cfg.hidden_size,
            cfg.patch_size,
            Conv2dConfig {
                stride: cfg.patch_size,
                ..Default::default()
            },
            vb.pp("patch_embedding").set_device(real_dev.clone()),
        )?;

        let class_embedding = vb
            .get((cfg.hidden_size,), "class_embedding")?
            .to_device(real_dev)?;
        let gated_positional_embedding = MLlamaPrecomputedPositionEmbedding::new(
            cfg,
            vb.pp("gated_positional_embedding")
                .set_device(real_dev.clone()),
        )?;

        let pre_tile_positional_embedding = MLlamaPrecomputedAspectRatioEmbedding::new::<true>(
            cfg,
            vb.pp("pre_tile_positional_embedding")
                .set_device(real_dev.clone()),
        )?;
        let post_tile_positional_embedding = MLlamaPrecomputedAspectRatioEmbedding::new::<true>(
            cfg,
            vb.pp("post_tile_positional_embedding")
                .set_device(real_dev.clone()),
        )?;

        let layernorm_pre = layer_norm(
            cfg.hidden_size,
            LayerNormConfig::default(),
            vb.pp("layernorm_pre").set_device(real_dev.clone()),
        )?;
        let layernorm_post = layer_norm(
            cfg.hidden_size,
            LayerNormConfig::default(),
            vb.pp("layernorm_post").set_device(real_dev.clone()),
        )?;

        let transformer = MLlamaVisionEncoder::new::<false>(
            cfg,
            cfg.num_hidden_layers,
            vb.pp("transformer"),
            real_dev,
            comm,
        )?;
        let global_transformer = MLlamaVisionEncoder::new::<true>(
            cfg,
            cfg.num_global_layers,
            vb.pp("global_transformer"),
            real_dev,
            comm,
        )?;

        Ok(Self {
            patch_embedding,
            class_embedding,
            gated_positional_embedding,
            pre_tile_positional_embedding,
            post_tile_positional_embedding,
            layernorm_post,
            layernorm_pre,
            transformer,
            global_transformer,
            num_patches: (cfg.image_size / cfg.patch_size).pow(2) + 1,
            intermediate_layers_indices: cfg.intermediate_layers_indices.clone(),
            num_attn_heads: cfg.num_attention_heads,
        })
    }

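    /// `pixel_values` is expected as a 6-D tensor of shape
    /// `(bs, num_concurrent_media, num_tiles, channels, height, width)`; `aspect_ratio_ids` and
    /// `aspect_ratio_mask` are flattened to a `(bs * num_concurrent_media, ...)` batch internally.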
    pub(super) fn forward(
        &self,
        pixel_values: &Tensor,
        aspect_ratio_ids: &Tensor,
        aspect_ratio_mask: &Tensor,
    ) -> Result<Tensor> {
        let pixel_values = pixel_values.to_dtype(self.class_embedding.dtype())?;

        let bs = pixel_values.dim(0)?;
        let num_concurrent_media = pixel_values.dim(1)?;
        let num_tiles = pixel_values.dim(2)?;
        let num_channels = pixel_values.dim(3)?;
        let height = pixel_values.dim(4)?;
        let width = pixel_values.dim(5)?;

        let pixel_values = pixel_values.reshape((
            bs * num_concurrent_media * num_tiles,
            num_channels,
            height,
            width,
        ))?;
        let aspect_ratio_ids = aspect_ratio_ids.reshape((bs * num_concurrent_media, ()))?;

        let patch_embeds = self.patch_embedding.forward(&pixel_values)?;
        let mut hidden_state = patch_embeds.flatten_from(2)?.transpose(1, 2)?;

        let (_, mut num_patches, dim) = hidden_state.dims3()?;
        hidden_state = hidden_state.reshape((bs * num_concurrent_media, num_tiles, (), dim))?;
        hidden_state = self
            .pre_tile_positional_embedding
            .forward(&hidden_state, &aspect_ratio_ids)?;

        hidden_state =
            hidden_state.reshape((bs * num_concurrent_media * num_tiles, num_patches, dim))?;
        hidden_state = self.apply_class_embedding(&hidden_state)?;
        num_patches += 1;

        hidden_state =
            hidden_state.reshape((bs * num_concurrent_media, num_tiles, num_patches, dim))?;
        hidden_state = self
            .gated_positional_embedding
            .forward(&hidden_state, &aspect_ratio_ids)?;

        hidden_state = self.layernorm_pre.forward(&hidden_state)?;

        let num_padding_patches = (8 - (hidden_state.dim(D::Minus2)? as isize % 8)) % 8;
        let _padding = (0usize, 0usize, 0usize, num_padding_patches);
        if num_padding_patches >= 0 {
            hidden_state =
                hidden_state.pad_with_zeros(D::Minus2, 0, num_padding_patches as usize)?;
        } else {
            hidden_state = hidden_state.narrow(
                D::Minus2,
                0,
                (hidden_state.dim(2)? as isize + num_padding_patches) as usize,
            )?;
        }

        let mut attention_mask = aspect_ratio_mask.reshape((bs * num_concurrent_media, ()))?;
        attention_mask = _prepare_aspect_ratio_attention_mask(
            &attention_mask,
            self.num_patches,
            hidden_state.dim(2)?,
            hidden_state.dtype(),
            self.num_attn_heads,
        )?;
        if attention_mask.dim(0)? != 1 {
            attention_mask = attention_mask.unsqueeze(1)?;
        }

        hidden_state = hidden_state.reshape((bs * num_concurrent_media, (), dim))?;
        let (mut hidden_state, all_intermediate_hidden_states) =
            self.transformer.forward_with_states(
                &hidden_state,
                Some(&attention_mask),
                Some(&self.intermediate_layers_indices),
            )?;

        let mut intermediate_hidden_states =
            Tensor::stack(&all_intermediate_hidden_states, D::Minus1)?;
        drop(all_intermediate_hidden_states);

        hidden_state = self.layernorm_post.forward(&hidden_state)?;

        hidden_state = hidden_state.reshape((
            bs * num_concurrent_media,
            num_tiles,
            (num_patches as isize + num_padding_patches) as usize,
            dim,
        ))?;
        hidden_state = self
            .post_tile_positional_embedding
            .forward(&hidden_state, &aspect_ratio_ids)?;
        hidden_state = hidden_state.reshape((
            bs * num_concurrent_media,
            num_tiles * (num_patches as isize + num_padding_patches) as usize,
            dim,
        ))?;
        (hidden_state, _) = self.global_transformer.forward_with_states(
            &hidden_state,
            Some(&attention_mask),
            None,
        )?;

        hidden_state = hidden_state.reshape((
            bs * num_concurrent_media,
            num_tiles,
            (num_patches as isize + num_padding_patches) as usize,
            dim,
        ))?;
        hidden_state = hidden_state.narrow(
            2,
            0,
            (hidden_state.dims()[2] as isize - num_padding_patches) as usize,
        )?;
        hidden_state =
            hidden_state.reshape((bs, num_concurrent_media, num_tiles, num_patches, dim))?;

        intermediate_hidden_states = intermediate_hidden_states.reshape((
            bs * num_concurrent_media,
            num_tiles,
            (num_patches as isize + num_padding_patches) as usize,
            (),
        ))?;
        intermediate_hidden_states = intermediate_hidden_states.narrow(
            2,
            0,
            (intermediate_hidden_states.dims()[2] as isize - num_padding_patches) as usize,
        )?;
        intermediate_hidden_states = intermediate_hidden_states.reshape((
            bs,
            num_concurrent_media,
            num_tiles,
            num_patches,
            (),
        ))?;

        Tensor::cat(&[hidden_state, intermediate_hidden_states], D::Minus1)
    }

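    /// Prepends the learned CLS embedding to each tile's patch sequence.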
    fn apply_class_embedding(&self, hidden_state: &Tensor) -> Result<Tensor> {
        let (bs, _, hidden_size) = hidden_state.dims3()?;
        let class_embedding = self.class_embedding.expand((bs, 1, hidden_size))?;
        Tensor::cat(&[class_embedding, hidden_state.clone()], 1)
    }

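    /// Collects mutable references to every quantizable projection (attention and MLP) in both
    /// encoders.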
    pub fn get_isq_layers(&mut self) -> Vec<&mut std::sync::Arc<dyn mistralrs_quant::QuantMethod>> {
        let mut layers = Vec::new();
        for layer in &mut self.global_transformer.layers {
            layers.push(&mut layer.self_attn.q_proj);
            layers.push(&mut layer.self_attn.k_proj);
            layers.push(&mut layer.self_attn.v_proj);
            layers.push(&mut layer.self_attn.o_proj);

            layers.push(&mut layer.mlp.fc1);
            layers.push(&mut layer.mlp.fc2);
        }
        for layer in &mut self.transformer.layers {
            layers.push(&mut layer.self_attn.q_proj);
            layers.push(&mut layer.self_attn.k_proj);
            layers.push(&mut layer.self_attn.v_proj);
            layers.push(&mut layer.self_attn.o_proj);

            layers.push(&mut layer.mlp.fc1);
            layers.push(&mut layer.mlp.fc2);
        }
        layers
    }
}

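// `get_layers` is unreachable because the vision tower itself is never quantized;
// `residual_tensors` exposes its non-quantized weights for serialization.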
impl IsqModel for MLlamaVisionModel {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(
            &mut std::sync::Arc<dyn mistralrs_quant::QuantMethod>,
            Option<usize>,
        )>,
        &dyn crate::device_map::DeviceMapper,
    ) {
        unreachable!("MLlamaVision model cannot be quantized.");
    }
    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("patch_embedding").add(&self.patch_embedding);
        uvb.add_tensor("class_embedding", self.class_embedding.clone());

        uvb.pp("gated_positional_embedding")
            .extend(self.gated_positional_embedding.residual_tensors());

        uvb.pp("pre_tile_positional_embedding")
            .extend(self.pre_tile_positional_embedding.residual_tensors());

        uvb.pp("post_tile_positional_embedding")
            .extend(self.post_tile_positional_embedding.residual_tensors());

        uvb.pp("layernorm_pre").add(&self.layernorm_pre);
        uvb.pp("layernorm_post").add(&self.layernorm_post);

        uvb.pp("transformer")
            .extend(self.transformer.residual_tensors());

        uvb.pp("global_transformer")
            .extend(self.global_transformer.residual_tensors());

        uvb.to_safetensors()
    }
}