1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use std::{ops::Mul, sync::Arc};
4
5use candle_core::{DType, Device, Result, Tensor, D};
6use candle_nn::{Conv2d, Conv2dConfig, Embedding, LayerNorm, LayerNormConfig, Module};
7use mistralrs_quant::{ColumnParallelLayer, QuantMethod, RowParallelLayer, ShardedVarBuilder};
8
9use crate::{
10 attention::SdpaParams,
11 layers::{conv2d_no_bias, embedding, layer_norm, GetFloatInfo, Sdpa},
12 pipeline::IsqModel,
13 utils::unvarbuilder::UnVarBuilder,
14};
15
16use super::{MLlamaVisionConfig, VisionActivation};
17
/// Gated positional embedding for the vision tower: a learned per-patch
/// embedding blended, via a tanh gate, with a per-aspect-ratio tile embedding
/// (see `forward`).
struct MLlamaPrecomputedPositionEmbedding {
    /// Scalar gate of shape `(1,)`; `tanh(gate)` weights the tile embedding,
    /// `1 - tanh(gate)` weights the plain positional embedding.
    gate: Tensor,
    /// Per-patch positional embedding, shape `(num_patches, hidden_size)`.
    embedding: Tensor,
    /// Lookup keyed by aspect-ratio id; each row reshapes to
    /// `(max_num_tiles, num_patches, hidden_size)` in `forward`.
    tile_embedding: Embedding,
    /// Patches per tile including the class token: `(image/patch)^2 + 1`.
    num_patches: usize,
    hidden_size: usize,
    max_num_tiles: usize,
}
26
27impl MLlamaPrecomputedPositionEmbedding {
28 fn new(cfg: &MLlamaVisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
29 let num_patches = (cfg.image_size / cfg.patch_size).pow(2) + 1;
30 Ok(Self {
31 gate: vb.get((1,), "gate")?,
32 embedding: vb.get((num_patches, cfg.hidden_size), "embedding")?,
33 tile_embedding: embedding(
34 cfg.max_aspect_ratio_id() + 1,
35 cfg.max_num_tiles * num_patches * cfg.hidden_size,
36 vb.pp("tile_embedding"),
37 )?,
38 num_patches,
39 hidden_size: cfg.hidden_size,
40 max_num_tiles: cfg.max_num_tiles,
41 })
42 }
43
44 fn forward(&self, hidden_state: &Tensor, aspect_ratio_ids: &Tensor) -> Result<Tensor> {
46 let mut gated_pos_embed = (1. - &self.gate.tanh()?)?.broadcast_mul(&self.embedding)?;
48 let hidden_state = hidden_state.broadcast_add(&gated_pos_embed.reshape((
49 1,
50 1,
51 self.num_patches,
52 self.hidden_size,
53 ))?)?;
54
55 let mut tile_position_embedding = self.tile_embedding.forward(aspect_ratio_ids)?;
57 let bs = hidden_state.dim(0)?;
58 tile_position_embedding = tile_position_embedding.reshape((
59 bs,
60 self.max_num_tiles,
61 self.num_patches,
62 self.hidden_size,
63 ))?;
64 gated_pos_embed = self.gate.tanh()?.broadcast_mul(&tile_position_embedding)?;
65
66 hidden_state.broadcast_add(&gated_pos_embed)
67 }
68
69 fn residual_tensors(&self) -> Vec<(String, Tensor)> {
70 let uvb_gpe = UnVarBuilder::new();
71
72 uvb_gpe.add_tensor("gate", self.gate.clone());
73 uvb_gpe.add_tensor("embedding", self.embedding.clone());
74 uvb_gpe.pp("tile_embedding").add(&self.tile_embedding);
75
76 uvb_gpe.to_safetensors()
77 }
78}
79
/// Per-tile positional embedding selected by aspect-ratio id, optionally
/// scaled by `tanh` of a learned gate (enabled via the `GATED` const in `new`).
struct MLlamaPrecomputedAspectRatioEmbedding {
    /// Rows indexed by aspect-ratio id; each reshapes to
    /// `(max_num_tiles, 1, hidden_size)` in `forward`.
    embedding: Embedding,
    /// `Some` only when constructed with `GATED = true`; shape `(1,)`.
    gate: Option<Tensor>,
    max_num_tiles: usize,
    hidden_size: usize,
}
86
87impl MLlamaPrecomputedAspectRatioEmbedding {
88 fn new<const GATED: bool>(cfg: &MLlamaVisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
89 Ok(Self {
90 embedding: embedding(
91 cfg.max_aspect_ratio_id() + 1,
92 cfg.max_num_tiles * cfg.hidden_size,
93 vb.pp("embedding"),
94 )?,
95 gate: if GATED {
96 Some(vb.get((1,), "gate")?)
97 } else {
98 None
99 },
100 max_num_tiles: cfg.max_num_tiles,
101 hidden_size: cfg.hidden_size,
102 })
103 }
104
105 fn forward(&self, hidden_state: &Tensor, aspect_ratio_ids: &Tensor) -> Result<Tensor> {
106 let mut embeddings = self.embedding.forward(aspect_ratio_ids)?;
107 embeddings = embeddings.reshape(((), self.max_num_tiles, 1, self.hidden_size))?;
108
109 if let Some(gate) = &self.gate {
110 embeddings = embeddings.broadcast_mul(&gate.tanh()?)?;
111 }
112
113 hidden_state.broadcast_add(&embeddings)
114 }
115
116 fn residual_tensors(&self) -> Vec<(String, Tensor)> {
117 let uvb_ptpe = UnVarBuilder::new();
118
119 if let Some(gate) = self.gate.clone() {
120 uvb_ptpe.add_tensor("gate", gate);
121 }
122 uvb_ptpe.pp("embedding").add(&self.embedding);
123
124 uvb_ptpe.to_safetensors()
125 }
126}
127
/// Multi-head self-attention for the vision encoder (no KV grouping, no
/// sliding window, no softcap — see the `SdpaParams` built in `new`).
struct MLlamaVisionAttention {
    q_proj: Arc<dyn QuantMethod>,
    k_proj: Arc<dyn QuantMethod>,
    v_proj: Arc<dyn QuantMethod>,
    o_proj: Arc<dyn QuantMethod>,
    /// Scale and masking configuration handed to the SDPA kernel.
    sdpa_params: SdpaParams,
    num_heads: usize,
    /// `hidden_size / num_heads`.
    head_dim: usize,
}
137
138impl MLlamaVisionAttention {
139 fn new(
140 cfg: &MLlamaVisionConfig,
141 vb: ShardedVarBuilder,
142 comm: &Arc<mistralrs_quant::Comm>,
143 ) -> Result<Self> {
144 let head_dim = cfg.hidden_size / cfg.num_attention_heads;
145 Ok(Self {
146 q_proj: ColumnParallelLayer::new(
147 cfg.hidden_size,
148 cfg.num_attention_heads * head_dim,
149 &None,
150 false,
151 comm,
152 vb.pp("q_proj"),
153 )?,
154 k_proj: ColumnParallelLayer::new(
155 cfg.hidden_size,
156 cfg.num_attention_heads * head_dim,
157 &None,
158 false,
159 comm,
160 vb.pp("k_proj"),
161 )?,
162 v_proj: ColumnParallelLayer::new(
163 cfg.hidden_size,
164 cfg.num_attention_heads * head_dim,
165 &None,
166 false,
167 comm,
168 vb.pp("v_proj"),
169 )?,
170 o_proj: RowParallelLayer::new(
171 cfg.hidden_size,
172 cfg.num_attention_heads * head_dim,
173 &None,
174 false,
175 comm,
176 vb.pp("o_proj"),
177 )?,
178 sdpa_params: SdpaParams {
179 n_kv_groups: 1,
180 use_flash_attn: false,
181 softcap: None,
182 softmax_scale: 1.0 / (head_dim as f32).sqrt(),
183 sliding_window: None,
184 },
185 num_heads: cfg.num_attention_heads,
186 head_dim,
187 })
188 }
189
190 fn forward(&self, hidden_state: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
192 let mut hidden_state = hidden_state.clone();
193 let original_dtype = hidden_state.dtype();
194 if let Some(t) = self.q_proj.quantized_act_type() {
195 hidden_state = hidden_state.to_dtype(t)?;
196 }
197 let mut q = self.q_proj.forward(&hidden_state)?;
198 let mut k = self.k_proj.forward(&hidden_state)?;
199 let mut v = self.v_proj.forward(&hidden_state)?;
200 if self.q_proj.quantized_act_type().is_some() {
201 q = q.to_dtype(original_dtype)?;
202 k = k.to_dtype(original_dtype)?;
203 v = v.to_dtype(original_dtype)?;
204 }
205
206 let (bs, q_sq, _) = q.dims3()?;
208 let (_, k_sq, _) = k.dims3()?;
209
210 q = q
211 .reshape((bs, q_sq, self.num_heads, self.head_dim))?
212 .transpose(1, 2)?;
213 k = k
214 .reshape((bs, k_sq, self.num_heads, self.head_dim))?
215 .transpose(1, 2)?;
216 v = v
217 .reshape((bs, k_sq, self.num_heads, self.head_dim))?
218 .transpose(1, 2)?;
219
220 let mut attn_output = Sdpa
221 .run_attention(
222 &q.contiguous()?,
223 &k.contiguous()?,
224 &v.contiguous()?,
225 attention_mask,
226 None,
227 &self.sdpa_params,
228 )?
229 .transpose(1, 2)?
230 .contiguous()?
231 .reshape((bs, q_sq, ()))?
232 .to_dtype(q.dtype())?;
233
234 if let Some(t) = self.q_proj.quantized_act_type() {
235 attn_output = attn_output.to_dtype(t)?;
236 }
237 let mut res = self.o_proj.forward(&attn_output)?;
238 if self.q_proj.quantized_act_type().is_some() {
239 res = res.to_dtype(original_dtype)?;
240 }
241 Ok(res)
242 }
243}
244
/// Two-layer feed-forward block (`fc1 -> act -> fc2`), both layers with bias.
struct MLlamaMlp {
    act: VisionActivation,
    fc1: Arc<dyn QuantMethod>,
    fc2: Arc<dyn QuantMethod>,
}
250
251impl MLlamaMlp {
252 fn new(
253 cfg: &MLlamaVisionConfig,
254 vb: ShardedVarBuilder,
255 comm: &Arc<mistralrs_quant::Comm>,
256 ) -> Result<Self> {
257 Ok(Self {
258 act: cfg.hidden_act,
259 fc1: ColumnParallelLayer::new(
260 cfg.hidden_size,
261 cfg.intermediate_size,
262 &None,
263 true,
264 comm,
265 vb.pp("fc1"),
266 )?,
267 fc2: RowParallelLayer::new(
268 cfg.intermediate_size,
269 cfg.hidden_size,
270 &None,
271 true,
272 comm,
273 vb.pp("fc2"),
274 )?,
275 })
276 }
277
278 fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
280 let original_dtype = hidden_states.dtype();
281 let mut hidden_states = hidden_states.clone();
282 if let Some(t) = self.fc1.quantized_act_type() {
283 hidden_states = hidden_states.to_dtype(t)?;
284 }
285 hidden_states = self
286 .fc2
287 .forward(&self.act.forward(&self.fc1.forward(&hidden_states)?)?)?;
288 if self.fc1.quantized_act_type().is_some() {
289 hidden_states = hidden_states.to_dtype(original_dtype)?;
290 }
291 Ok(hidden_states)
292 }
293}
294
/// Pre-norm transformer encoder layer. The gated variant (used by the global
/// encoder) scales each residual branch by `tanh` of a learned scalar gate.
struct MLlamaVisionEncoderLayer {
    self_attn: MLlamaVisionAttention,
    mlp: MLlamaMlp,
    input_layernorm: LayerNorm,
    post_attention_layernorm: LayerNorm,
    /// `Some` only when built with `new::<true>` (gated layers).
    gate_attn: Option<Tensor>,
    gate_ffn: Option<Tensor>,
}
303
304impl MLlamaVisionEncoderLayer {
305 fn new<const GATED: bool>(
306 cfg: &MLlamaVisionConfig,
307 vb: ShardedVarBuilder,
308 real_dev: &Device,
309 comm: &Arc<mistralrs_quant::Comm>,
310 ) -> Result<Self> {
311 let self_attn = MLlamaVisionAttention::new(cfg, vb.pp("self_attn"), comm)?;
312 let mlp = MLlamaMlp::new(cfg, vb.pp("mlp"), comm)?;
313
314 let input_layernorm = layer_norm(
315 cfg.hidden_size,
316 cfg.norm_eps,
317 vb.pp("input_layernorm").set_device(real_dev.clone()),
318 )?;
319 let post_attention_layernorm = layer_norm(
320 cfg.hidden_size,
321 cfg.norm_eps,
322 vb.pp("post_attention_layernorm")
323 .set_device(real_dev.clone()),
324 )?;
325
326 if GATED {
327 Ok(Self {
328 self_attn,
329 mlp,
330 input_layernorm,
331 post_attention_layernorm,
332 gate_attn: Some(vb.get((1,), "gate_attn")?.to_device(real_dev)?),
333 gate_ffn: Some(vb.get((1,), "gate_ffn")?.to_device(real_dev)?),
334 })
335 } else {
336 Ok(Self {
337 self_attn,
338 mlp,
339 input_layernorm,
340 post_attention_layernorm,
341 gate_attn: None,
342 gate_ffn: None,
343 })
344 }
345 }
346
347 fn forward(&self, hidden_state: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
349 let residual = hidden_state;
351 let mut hidden_state = self.input_layernorm.forward(hidden_state)?;
352
353 hidden_state = self.self_attn.forward(&hidden_state, attention_mask)?;
354
355 if let Some(gate) = &self.gate_attn {
356 hidden_state = gate.broadcast_mul(&hidden_state.tanh()?)?;
357 }
358 hidden_state = (residual + hidden_state)?;
359
360 let residual = hidden_state.clone();
362 hidden_state = self.post_attention_layernorm.forward(&hidden_state)?;
363
364 hidden_state = self.mlp.forward(&hidden_state)?;
365
366 if let Some(gate) = &self.gate_ffn {
367 hidden_state = gate.broadcast_mul(&hidden_state.tanh()?)?;
368 }
369 residual + hidden_state
370 }
371}
372
/// Stack of encoder layers with optional collection of intermediate states.
struct MLlamaVisionEncoder {
    layers: Vec<MLlamaVisionEncoderLayer>,
}
376
377impl MLlamaVisionEncoder {
378 fn new<const GATED: bool>(
379 cfg: &MLlamaVisionConfig,
380 num_layers: usize,
381 vb: ShardedVarBuilder,
382 real_dev: &Device,
383 comm: &Arc<mistralrs_quant::Comm>,
384 ) -> Result<Self> {
385 let mut layers = Vec::with_capacity(num_layers);
386 let layers_vb = vb.pp("layers");
387 for i in 0..num_layers {
388 layers.push(MLlamaVisionEncoderLayer::new::<GATED>(
389 cfg,
390 layers_vb.pp(i),
391 real_dev,
392 comm,
393 )?);
394 }
395 Ok(Self { layers })
396 }
397
398 fn forward_with_states(
401 &self,
402 hidden_state: &Tensor,
403 attention_mask: Option<&Tensor>,
404 intermediate_layers_indices: Option<&[usize]>,
405 ) -> Result<(Tensor, Vec<Tensor>)> {
406 let mut hidden_state = hidden_state.clone();
407 let mut hidden_states = Vec::new();
408 for (i, layer) in self.layers.iter().enumerate() {
409 if intermediate_layers_indices.is_some_and(|indices: &[usize]| indices.contains(&i)) {
410 hidden_states.push(hidden_state.clone());
411 }
412 hidden_state = layer.forward(&hidden_state, attention_mask)?;
413 }
414 Ok((hidden_state, hidden_states))
415 }
416
417 fn residual_tensors(&self) -> Vec<(String, Tensor)> {
418 let uvb_t = UnVarBuilder::new();
419
420 for (i, layer) in self.layers.iter().enumerate() {
421 let uvb_l = uvb_t.pp("layers").pp(i);
422 uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
423 uvb_l
424 .pp("post_attention_layernorm")
425 .add(&layer.post_attention_layernorm);
426 if let Some(gate) = layer.gate_attn.clone() {
427 uvb_l.add_tensor("gate_attn", gate);
428 }
429 if let Some(gate) = layer.gate_ffn.clone() {
430 uvb_l.add_tensor("gate_ffn", gate);
431 }
432 }
433
434 uvb_t.to_safetensors()
435 }
436}
437
/// Build an additive attention bias over the flattened
/// `max_num_tiles * target_length` sequence from the per-tile validity mask.
///
/// `aspect_ratio_mask` is `(bs, max_num_tiles)` — presumably 1 for real tiles
/// and 0 for padding tiles; confirm against the image processor.
fn _prepare_aspect_ratio_attention_mask(
    aspect_ratio_mask: &Tensor,
    num_patches: usize,
    target_length: usize,
    dtype: DType,
    _num_attn_heads: usize, // Unused; kept for signature parity with the reference impl.
) -> Result<Tensor> {
    let (bs, max_num_tiles) = aspect_ratio_mask.dims2()?;
    // Broadcast tile validity to every position of the (padded) patch
    // sequence: (bs, tiles) -> (bs, tiles, target_length, 1).
    let mut attention_mask = aspect_ratio_mask
        .reshape((bs, max_num_tiles, 1, 1))?
        .repeat((1, 1, target_length, 1))?;

    // Zero the trailing `pad_patches` positions appended to reach
    // `target_length` (assumes target_length >= num_patches).
    let pad_patches = target_length - num_patches;
    let (bs, d1, d2, d3) = attention_mask.dims4()?;
    attention_mask = attention_mask.slice_assign(
        &[&.., &.., &(d2 - pad_patches..), &..],
        &Tensor::zeros(
            (bs, d1, pad_patches, d3),
            attention_mask.dtype(),
            attention_mask.device(),
        )?,
    )?;

    // Invert: valid positions -> 0, masked positions -> 1. The F32 round-trip
    // presumably matches the reference numerics for low-precision dtypes —
    // NOTE(review): confirm before simplifying.
    attention_mask = (1. - attention_mask.to_dtype(DType::F32)?.to_dtype(dtype)?)?;

    // Outer product of the flattened mask with itself, scaled by the dtype's
    // most-negative finite value: the bias is strongly negative wherever the
    // query or key position is masked, and 0 where both are valid.
    let neg_inf_value = dtype.finfo()?.min;
    attention_mask = attention_mask.reshape((bs, max_num_tiles * target_length, 1))?;
    attention_mask.matmul(
        &attention_mask
            .transpose(D::Minus1, D::Minus2)?
            .mul(neg_inf_value)?,
    )
}
475
/// MLlama vision tower: convolutional patch embedding, tile / positional
/// embeddings, a local transformer, and a gated global transformer.
pub(super) struct MLlamaVisionModel {
    patch_embedding: Conv2d,
    /// Learned CLS embedding prepended to each tile's patch sequence.
    class_embedding: Tensor,
    gated_positional_embedding: MLlamaPrecomputedPositionEmbedding,
    pre_tile_positional_embedding: MLlamaPrecomputedAspectRatioEmbedding,
    post_tile_positional_embedding: MLlamaPrecomputedAspectRatioEmbedding,
    layernorm_pre: LayerNorm,
    layernorm_post: LayerNorm,
    transformer: MLlamaVisionEncoder,
    /// Gated encoder applied after `layernorm_post`.
    global_transformer: MLlamaVisionEncoder,
    /// Patches per tile including the class token.
    pub(super) num_patches: usize,
    /// Indices of `transformer` layers whose input states are concatenated
    /// into the output features.
    intermediate_layers_indices: Vec<usize>,
    num_attn_heads: usize,
}
490
491impl MLlamaVisionModel {
492 pub(super) fn new(
493 cfg: &MLlamaVisionConfig,
494 vb: ShardedVarBuilder,
495 real_dev: &Device,
496 comm: &Arc<mistralrs_quant::Comm>,
497 ) -> Result<Self> {
498 let patch_embedding = conv2d_no_bias(
499 cfg.num_channels,
500 cfg.hidden_size,
501 cfg.patch_size,
502 Conv2dConfig {
503 stride: cfg.patch_size,
504 ..Default::default()
505 },
506 vb.pp("patch_embedding").set_device(real_dev.clone()),
507 )?;
508
509 let class_embedding = vb
510 .get((cfg.hidden_size,), "class_embedding")?
511 .to_device(real_dev)?;
512 let gated_positional_embedding = MLlamaPrecomputedPositionEmbedding::new(
513 cfg,
514 vb.pp("gated_positional_embedding")
515 .set_device(real_dev.clone()),
516 )?;
517
518 let pre_tile_positional_embedding = MLlamaPrecomputedAspectRatioEmbedding::new::<true>(
519 cfg,
520 vb.pp("pre_tile_positional_embedding")
521 .set_device(real_dev.clone()),
522 )?;
523 let post_tile_positional_embedding = MLlamaPrecomputedAspectRatioEmbedding::new::<true>(
524 cfg,
525 vb.pp("post_tile_positional_embedding")
526 .set_device(real_dev.clone()),
527 )?;
528
529 let layernorm_pre = layer_norm(
531 cfg.hidden_size,
532 LayerNormConfig::default(),
533 vb.pp("layernorm_pre").set_device(real_dev.clone()),
534 )?;
535 let layernorm_post = layer_norm(
536 cfg.hidden_size,
537 LayerNormConfig::default(),
538 vb.pp("layernorm_post").set_device(real_dev.clone()),
539 )?;
540
541 let transformer = MLlamaVisionEncoder::new::<false>(
543 cfg,
544 cfg.num_hidden_layers,
545 vb.pp("transformer"),
546 real_dev,
547 comm,
548 )?;
549 let global_transformer = MLlamaVisionEncoder::new::<true>(
550 cfg,
551 cfg.num_global_layers,
552 vb.pp("global_transformer"),
553 real_dev,
554 comm,
555 )?;
556
557 Ok(Self {
558 patch_embedding,
559 class_embedding,
560 gated_positional_embedding,
561 pre_tile_positional_embedding,
562 post_tile_positional_embedding,
563 layernorm_post,
564 layernorm_pre,
565 transformer,
566 global_transformer,
567 num_patches: (cfg.image_size / cfg.patch_size).pow(2) + 1,
568 intermediate_layers_indices: cfg.intermediate_layers_indices.clone(),
569 num_attn_heads: cfg.num_attention_heads,
570 })
571 }
572
573 pub(super) fn forward(
575 &self,
576 pixel_values: &Tensor,
577 aspect_ratio_ids: &Tensor,
578 aspect_ratio_mask: &Tensor,
579 ) -> Result<Tensor> {
580 let pixel_values = pixel_values.to_dtype(self.class_embedding.dtype())?;
581
582 let bs = pixel_values.dim(0)?;
583 let num_concurrent_media = pixel_values.dim(1)?;
584 let num_tiles = pixel_values.dim(2)?;
585 let num_channels = pixel_values.dim(3)?;
586 let height = pixel_values.dim(4)?;
587 let width = pixel_values.dim(5)?;
588
589 let pixel_values = pixel_values.reshape((
590 bs * num_concurrent_media * num_tiles,
591 num_channels,
592 height,
593 width,
594 ))?;
595 let aspect_ratio_ids = aspect_ratio_ids.reshape((bs * num_concurrent_media, ()))?;
596
597 let patch_embeds = self.patch_embedding.forward(&pixel_values)?;
599 let mut hidden_state = patch_embeds.flatten_from(2)?.transpose(1, 2)?;
600
601 let (_, mut num_patches, dim) = hidden_state.dims3()?;
603 hidden_state = hidden_state.reshape((bs * num_concurrent_media, num_tiles, (), dim))?;
604 hidden_state = self
605 .pre_tile_positional_embedding
606 .forward(&hidden_state, &aspect_ratio_ids)?;
607
608 hidden_state =
610 hidden_state.reshape((bs * num_concurrent_media * num_tiles, num_patches, dim))?;
611 hidden_state = self.apply_class_embedding(&hidden_state)?;
612 num_patches += 1;
613
614 hidden_state =
616 hidden_state.reshape((bs * num_concurrent_media, num_tiles, num_patches, dim))?;
617 hidden_state = self
618 .gated_positional_embedding
619 .forward(&hidden_state, &aspect_ratio_ids)?;
620
621 hidden_state = self.layernorm_pre.forward(&hidden_state)?;
622
623 let num_padding_patches = (8 - (hidden_state.dim(D::Minus2)? as isize % 8)) % 8;
625 let _padding = (0usize, 0usize, 0usize, num_padding_patches);
628 if num_padding_patches >= 0 {
629 hidden_state =
630 hidden_state.pad_with_zeros(D::Minus2, 0, num_padding_patches as usize)?;
631 } else {
632 hidden_state = hidden_state.narrow(
633 D::Minus2,
634 0,
635 (hidden_state.dim(2)? as isize + num_padding_patches) as usize,
636 )?;
637 }
638
639 let mut attention_mask = aspect_ratio_mask.reshape((bs * num_concurrent_media, ()))?;
641 attention_mask = _prepare_aspect_ratio_attention_mask(
642 &attention_mask,
643 self.num_patches,
644 hidden_state.dim(2)?,
645 hidden_state.dtype(),
646 self.num_attn_heads,
647 )?;
648 if attention_mask.dim(0)? != 1 {
649 attention_mask = attention_mask.unsqueeze(1)?;
650 }
651
652 hidden_state = hidden_state.reshape((bs * num_concurrent_media, (), dim))?;
654 let (mut hidden_state, all_intermediate_hidden_states) =
655 self.transformer.forward_with_states(
656 &hidden_state,
657 Some(&attention_mask),
658 Some(&self.intermediate_layers_indices),
659 )?;
660
661 let mut intermediate_hidden_states =
663 Tensor::stack(&all_intermediate_hidden_states, D::Minus1)?;
664 drop(all_intermediate_hidden_states);
665
666 hidden_state = self.layernorm_post.forward(&hidden_state)?;
667
668 hidden_state = hidden_state.reshape((
670 bs * num_concurrent_media,
671 num_tiles,
672 (num_patches as isize + num_padding_patches) as usize,
673 dim,
674 ))?;
675 hidden_state = self
676 .post_tile_positional_embedding
677 .forward(&hidden_state, &aspect_ratio_ids)?;
678 hidden_state = hidden_state.reshape((
679 bs * num_concurrent_media,
680 num_tiles * (num_patches as isize + num_padding_patches) as usize,
681 dim,
682 ))?;
683 (hidden_state, _) = self.global_transformer.forward_with_states(
684 &hidden_state,
685 Some(&attention_mask),
686 None,
687 )?;
688
689 hidden_state = hidden_state.reshape((
691 bs * num_concurrent_media,
692 num_tiles,
693 (num_patches as isize + num_padding_patches) as usize,
694 dim,
695 ))?;
696 hidden_state = hidden_state.narrow(
697 2,
698 0,
699 (hidden_state.dims()[2] as isize - num_padding_patches) as usize,
700 )?;
701 hidden_state =
702 hidden_state.reshape((bs, num_concurrent_media, num_tiles, num_patches, dim))?;
703
704 intermediate_hidden_states = intermediate_hidden_states.reshape((
706 bs * num_concurrent_media,
707 num_tiles,
708 (num_patches as isize + num_padding_patches) as usize,
709 (),
710 ))?;
711 intermediate_hidden_states = intermediate_hidden_states.narrow(
712 2,
713 0,
714 (intermediate_hidden_states.dims()[2] as isize - num_padding_patches) as usize,
715 )?;
716 intermediate_hidden_states = intermediate_hidden_states.reshape((
717 bs,
718 num_concurrent_media,
719 num_tiles,
720 num_patches,
721 (),
722 ))?;
723
724 Tensor::cat(&[hidden_state, intermediate_hidden_states], D::Minus1)
726 }
727
728 fn apply_class_embedding(&self, hidden_state: &Tensor) -> Result<Tensor> {
729 let (bs, _, hidden_size) = hidden_state.dims3()?;
730 let class_embedding = self.class_embedding.expand((bs, 1, hidden_size))?;
731 Tensor::cat(&[class_embedding, hidden_state.clone()], 1)
732 }
733
734 pub fn get_isq_layers(&mut self) -> Vec<&mut std::sync::Arc<dyn mistralrs_quant::QuantMethod>> {
735 let mut layers = Vec::new();
736 for layer in &mut self.global_transformer.layers {
737 layers.push(&mut layer.self_attn.q_proj);
738 layers.push(&mut layer.self_attn.k_proj);
739 layers.push(&mut layer.self_attn.v_proj);
740 layers.push(&mut layer.self_attn.o_proj);
741
742 layers.push(&mut layer.mlp.fc1);
743 layers.push(&mut layer.mlp.fc2);
744 }
745 for layer in &mut self.transformer.layers {
746 layers.push(&mut layer.self_attn.q_proj);
747 layers.push(&mut layer.self_attn.k_proj);
748 layers.push(&mut layer.self_attn.v_proj);
749 layers.push(&mut layer.self_attn.o_proj);
750
751 layers.push(&mut layer.mlp.fc1);
752 layers.push(&mut layer.mlp.fc2);
753 }
754 layers
755 }
756}
757
758impl IsqModel for MLlamaVisionModel {
759 fn get_layers(
760 &mut self,
761 ) -> (
762 Vec<(
763 &mut std::sync::Arc<dyn mistralrs_quant::QuantMethod>,
764 Option<usize>,
765 )>,
766 &dyn crate::device_map::DeviceMapper,
767 ) {
768 unreachable!("MLlamaVision model cannot be quantized.");
769 }
770 fn residual_tensors(&self) -> Vec<(String, Tensor)> {
771 let uvb = UnVarBuilder::new();
772
773 uvb.pp("patch_embedding").add(&self.patch_embedding);
774 uvb.add_tensor("class_embedding", self.class_embedding.clone());
775
776 uvb.pp("gated_positional_embedding")
778 .extend(self.gated_positional_embedding.residual_tensors());
779
780 uvb.pp("pre_tile_positional_embedding")
782 .extend(self.pre_tile_positional_embedding.residual_tensors());
783
784 uvb.pp("post_tile_positional_embedding")
786 .extend(self.post_tile_positional_embedding.residual_tensors());
787
788 uvb.pp("layernorm_pre").add(&self.layernorm_pre);
789 uvb.pp("layernorm_post").add(&self.layernorm_post);
790
791 uvb.pp("transformer")
793 .extend(self.transformer.residual_tensors());
794
795 uvb.pp("global_transformer")
797 .extend(self.global_transformer.residual_tensors());
798
799 uvb.to_safetensors()
800 }
801}