#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::{ops::Mul, sync::Arc};

use candle_core::{DType, Device, Result, Tensor, D};
use candle_nn::{Conv2d, Conv2dConfig, Embedding, LayerNorm, LayerNormConfig, Module};
use mistralrs_quant::{
    ColumnParallelLayer, Convolution, QuantMethod, RowParallelLayer, ShardedVarBuilder,
};

use crate::{
    attention::SdpaParams,
    layers::{conv2d_no_bias, embedding, layer_norm, GetFloatInfo, Sdpa},
    pipeline::IsqModel,
    utils::unvarbuilder::UnVarBuilder,
};

use super::{MLlamaVisionConfig, VisionActivation};

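/// Gated positional embedding: a learned per-patch embedding plus precomputed per-tile
/// embeddings (selected by aspect-ratio id), both blended via a learned scalar `gate`.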
struct MLlamaPrecomputedPositionEmbedding {
    gate: Tensor,
    embedding: Tensor,
    tile_embedding: Embedding,
    num_patches: usize,
    hidden_size: usize,
    max_num_tiles: usize,
}

impl MLlamaPrecomputedPositionEmbedding {
    fn new(cfg: &MLlamaVisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        let num_patches = (cfg.image_size / cfg.patch_size).pow(2) + 1;
        Ok(Self {
            gate: vb.get((1,), "gate")?,
            embedding: vb.get((num_patches, cfg.hidden_size), "embedding")?,
            tile_embedding: embedding(
                cfg.max_aspect_ratio_id() + 1,
                cfg.max_num_tiles * num_patches * cfg.hidden_size,
                vb.pp("tile_embedding"),
                &None,
            )?,
            num_patches,
            hidden_size: cfg.hidden_size,
            max_num_tiles: cfg.max_num_tiles,
        })
    }

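    // Adds (1 - tanh(gate)) * embedding to every tile, then tanh(gate) * the per-tile
    // position embedding selected by `aspect_ratio_ids`.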
    fn forward(&self, hidden_state: &Tensor, aspect_ratio_ids: &Tensor) -> Result<Tensor> {
        let mut gated_pos_embed = (1. - &self.gate.tanh()?)?.broadcast_mul(&self.embedding)?;
        let hidden_state = hidden_state.broadcast_add(&gated_pos_embed.reshape((
            1,
            1,
            self.num_patches,
            self.hidden_size,
        ))?)?;

        let mut tile_position_embedding = self.tile_embedding.forward(aspect_ratio_ids)?;
        let bs = hidden_state.dim(0)?;
        tile_position_embedding = tile_position_embedding.reshape((
            bs,
            self.max_num_tiles,
            self.num_patches,
            self.hidden_size,
        ))?;
        gated_pos_embed = self.gate.tanh()?.broadcast_mul(&tile_position_embedding)?;

        hidden_state.broadcast_add(&gated_pos_embed)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb_gpe = UnVarBuilder::new();

        uvb_gpe.add_tensor("gate", self.gate.clone());
        uvb_gpe.add_tensor("embedding", self.embedding.clone());
        uvb_gpe.pp("tile_embedding").add(&self.tile_embedding);

        uvb_gpe.to_safetensors()
    }
}

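/// Tile (aspect-ratio) embedding applied before/after the local transformer. When `GATED`,
/// the embedding is additionally scaled by tanh of a learned gate.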
struct MLlamaPrecomputedAspectRatioEmbedding {
    embedding: Embedding,
    gate: Option<Tensor>,
    max_num_tiles: usize,
    hidden_size: usize,
}

impl MLlamaPrecomputedAspectRatioEmbedding {
    fn new<const GATED: bool>(cfg: &MLlamaVisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        Ok(Self {
            embedding: embedding(
                cfg.max_aspect_ratio_id() + 1,
                cfg.max_num_tiles * cfg.hidden_size,
                vb.pp("embedding"),
                &None,
            )?,
            gate: if GATED {
                Some(vb.get((1,), "gate")?)
            } else {
                None
            },
            max_num_tiles: cfg.max_num_tiles,
            hidden_size: cfg.hidden_size,
        })
    }

    fn forward(&self, hidden_state: &Tensor, aspect_ratio_ids: &Tensor) -> Result<Tensor> {
        let mut embeddings = self.embedding.forward(aspect_ratio_ids)?;
        embeddings = embeddings.reshape(((), self.max_num_tiles, 1, self.hidden_size))?;

        if let Some(gate) = &self.gate {
            embeddings = embeddings.broadcast_mul(&gate.tanh()?)?;
        }

        hidden_state.broadcast_add(&embeddings)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb_ptpe = UnVarBuilder::new();

        if let Some(gate) = self.gate.clone() {
            uvb_ptpe.add_tensor("gate", gate);
        }
        uvb_ptpe.pp("embedding").add(&self.embedding);

        uvb_ptpe.to_safetensors()
    }
}

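/// Bidirectional multi-head self-attention for the vision encoder. Q/K/V use column-parallel
/// projections and the output projection is row-parallel, so the layer can be sharded over `comm`.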
struct MLlamaVisionAttention {
    q_proj: Arc<dyn QuantMethod>,
    k_proj: Arc<dyn QuantMethod>,
    v_proj: Arc<dyn QuantMethod>,
    o_proj: Arc<dyn QuantMethod>,
    sdpa_params: SdpaParams,
    num_heads: usize,
    head_dim: usize,
}

impl MLlamaVisionAttention {
    fn new(
        cfg: &MLlamaVisionConfig,
        vb: ShardedVarBuilder,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let head_dim = cfg.hidden_size / cfg.num_attention_heads;
        Ok(Self {
            q_proj: ColumnParallelLayer::new(
                cfg.hidden_size,
                cfg.num_attention_heads * head_dim,
                &None,
                false,
                comm,
                vb.pp("q_proj"),
            )?,
            k_proj: ColumnParallelLayer::new(
                cfg.hidden_size,
                cfg.num_attention_heads * head_dim,
                &None,
                false,
                comm,
                vb.pp("k_proj"),
            )?,
            v_proj: ColumnParallelLayer::new(
                cfg.hidden_size,
                cfg.num_attention_heads * head_dim,
                &None,
                false,
                comm,
                vb.pp("v_proj"),
            )?,
            o_proj: RowParallelLayer::new(
                cfg.hidden_size,
                cfg.num_attention_heads * head_dim,
                &None,
                false,
                comm,
                vb.pp("o_proj"),
            )?,
            sdpa_params: SdpaParams {
                n_kv_groups: 1,
                softcap: None,
                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                sliding_window: None,
            },
            num_heads: cfg.num_attention_heads,
            head_dim,
        })
    }

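    // Projects to Q/K/V (upcasting if the projections are quantized), reshapes to
    // (bs, heads, seq, head_dim), runs SDPA, then applies the output projection.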
    fn forward(&self, hidden_state: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
        let mut hidden_state = hidden_state.clone();
        let original_dtype = hidden_state.dtype();
        if let Some(t) = self.q_proj.quantized_act_type() {
            hidden_state = hidden_state.to_dtype(t)?;
        }
        let mut q = self.q_proj.forward(&hidden_state)?;
        let mut k = self.k_proj.forward(&hidden_state)?;
        let mut v = self.v_proj.forward(&hidden_state)?;
        if self.q_proj.quantized_act_type().is_some() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

        let (bs, q_sq, _) = q.dims3()?;
        let (_, k_sq, _) = k.dims3()?;

        q = q
            .reshape((bs, q_sq, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;
        k = k
            .reshape((bs, k_sq, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;
        v = v
            .reshape((bs, k_sq, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;

        let mut attn_output = Sdpa
            .run_attention(
                &q.contiguous()?,
                &k.contiguous()?,
                &v.contiguous()?,
                attention_mask,
                None,
                &self.sdpa_params,
            )?
            .transpose(1, 2)?
            .contiguous()?
            .reshape((bs, q_sq, ()))?
            .to_dtype(q.dtype())?;

        if let Some(t) = self.q_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        let mut res = self.o_proj.forward(&attn_output)?;
        if self.q_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

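/// Two-layer feed-forward block (fc1 -> activation -> fc2) built from tensor-parallel linears.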
struct MLlamaMlp {
    act: VisionActivation,
    fc1: Arc<dyn QuantMethod>,
    fc2: Arc<dyn QuantMethod>,
}

impl MLlamaMlp {
    fn new(
        cfg: &MLlamaVisionConfig,
        vb: ShardedVarBuilder,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        Ok(Self {
            act: cfg.hidden_act,
            fc1: ColumnParallelLayer::new(
                cfg.hidden_size,
                cfg.intermediate_size,
                &None,
                true,
                comm,
                vb.pp("fc1"),
            )?,
            fc2: RowParallelLayer::new(
                cfg.intermediate_size,
                cfg.hidden_size,
                &None,
                true,
                comm,
                vb.pp("fc2"),
            )?,
        })
    }

    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
        let original_dtype = hidden_states.dtype();
        let mut hidden_states = hidden_states.clone();
        if let Some(t) = self.fc1.quantized_act_type() {
            hidden_states = hidden_states.to_dtype(t)?;
        }
        hidden_states = self
            .fc2
            .forward(&self.act.forward(&self.fc1.forward(&hidden_states)?)?)?;
        if self.fc1.quantized_act_type().is_some() {
            hidden_states = hidden_states.to_dtype(original_dtype)?;
        }
        Ok(hidden_states)
    }
}

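/// Pre-norm transformer block. When `GATED`, the attention and MLP branch outputs are scaled
/// by tanh of the learned `gate_attn`/`gate_ffn` scalars (used by the global encoder here).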
struct MLlamaVisionEncoderLayer {
    self_attn: MLlamaVisionAttention,
    mlp: MLlamaMlp,
    input_layernorm: LayerNorm,
    post_attention_layernorm: LayerNorm,
    gate_attn: Option<Tensor>,
    gate_ffn: Option<Tensor>,
}

impl MLlamaVisionEncoderLayer {
    fn new<const GATED: bool>(
        cfg: &MLlamaVisionConfig,
        vb: ShardedVarBuilder,
        real_dev: &Device,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let self_attn = MLlamaVisionAttention::new(cfg, vb.pp("self_attn"), comm)?;
        let mlp = MLlamaMlp::new(cfg, vb.pp("mlp"), comm)?;

        let input_layernorm = layer_norm(
            cfg.hidden_size,
            cfg.norm_eps,
            vb.pp("input_layernorm").set_device(real_dev.clone()),
        )?;
        let post_attention_layernorm = layer_norm(
            cfg.hidden_size,
            cfg.norm_eps,
            vb.pp("post_attention_layernorm")
                .set_device(real_dev.clone()),
        )?;

        if GATED {
            Ok(Self {
                self_attn,
                mlp,
                input_layernorm,
                post_attention_layernorm,
                gate_attn: Some(vb.get((1,), "gate_attn")?.to_device(real_dev)?),
                gate_ffn: Some(vb.get((1,), "gate_ffn")?.to_device(real_dev)?),
            })
        } else {
            Ok(Self {
                self_attn,
                mlp,
                input_layernorm,
                post_attention_layernorm,
                gate_attn: None,
                gate_ffn: None,
            })
        }
    }

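    // Attention block followed by the MLP block, each with a residual connection and an
    // optional tanh gate on the branch output.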
    fn forward(&self, hidden_state: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
        let residual = hidden_state;
        let mut hidden_state = self.input_layernorm.forward(hidden_state)?;

        hidden_state = self.self_attn.forward(&hidden_state, attention_mask)?;

        if let Some(gate) = &self.gate_attn {
            hidden_state = gate.broadcast_mul(&hidden_state.tanh()?)?;
        }
        hidden_state = (residual + hidden_state)?;

        let residual = hidden_state.clone();
        hidden_state = self.post_attention_layernorm.forward(&hidden_state)?;

        hidden_state = self.mlp.forward(&hidden_state)?;

        if let Some(gate) = &self.gate_ffn {
            hidden_state = gate.broadcast_mul(&hidden_state.tanh()?)?;
        }
        residual + hidden_state
    }
}

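/// A stack of `MLlamaVisionEncoderLayer`s sharing the `GATED` setting.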
struct MLlamaVisionEncoder {
    layers: Vec<MLlamaVisionEncoderLayer>,
}

impl MLlamaVisionEncoder {
    fn new<const GATED: bool>(
        cfg: &MLlamaVisionConfig,
        num_layers: usize,
        vb: ShardedVarBuilder,
        real_dev: &Device,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let mut layers = Vec::with_capacity(num_layers);
        let layers_vb = vb.pp("layers");
        for i in 0..num_layers {
            layers.push(MLlamaVisionEncoderLayer::new::<GATED>(
                cfg,
                layers_vb.pp(i),
                real_dev,
                comm,
            )?);
        }
        Ok(Self { layers })
    }

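    // Runs all layers and additionally returns the hidden state entering each layer whose
    // index is listed in `intermediate_layers_indices` (captured before that layer runs).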
    fn forward_with_states(
        &self,
        hidden_state: &Tensor,
        attention_mask: Option<&Tensor>,
        intermediate_layers_indices: Option<&[usize]>,
    ) -> Result<(Tensor, Vec<Tensor>)> {
        let mut hidden_state = hidden_state.clone();
        let mut hidden_states = Vec::new();
        for (i, layer) in self.layers.iter().enumerate() {
            if intermediate_layers_indices.is_some_and(|indices: &[usize]| indices.contains(&i)) {
                hidden_states.push(hidden_state.clone());
            }
            hidden_state = layer.forward(&hidden_state, attention_mask)?;
        }
        Ok((hidden_state, hidden_states))
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb_t = UnVarBuilder::new();

        for (i, layer) in self.layers.iter().enumerate() {
            let uvb_l = uvb_t.pp("layers").pp(i);
            uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
            uvb_l
                .pp("post_attention_layernorm")
                .add(&layer.post_attention_layernorm);
            if let Some(gate) = layer.gate_attn.clone() {
                uvb_l.add_tensor("gate_attn", gate);
            }
            if let Some(gate) = layer.gate_ffn.clone() {
                uvb_l.add_tensor("gate_ffn", gate);
            }
        }

        uvb_t.to_safetensors()
    }
}

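// Builds an additive attention bias over the flattened (tile, patch) sequence from the
// aspect-ratio mask: padding tiles and padded patch positions are flagged, then the flags are
// combined via an outer product scaled by the dtype's most negative finite value.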
fn _prepare_aspect_ratio_attention_mask(
    aspect_ratio_mask: &Tensor,
    num_patches: usize,
    target_length: usize,
    dtype: DType,
    _num_attn_heads: usize,
) -> Result<Tensor> {
    let (bs, max_num_tiles) = aspect_ratio_mask.dims2()?;
    let mut attention_mask = aspect_ratio_mask
        .reshape((bs, max_num_tiles, 1, 1))?
        .repeat((1, 1, target_length, 1))?;

    let pad_patches = target_length - num_patches;
    let (bs, d1, d2, d3) = attention_mask.dims4()?;
    attention_mask = attention_mask.slice_assign(
        &[0..bs, 0..d1, (d2 - pad_patches)..d2, 0..d3],
        &Tensor::zeros(
            (bs, d1, pad_patches, d3),
            attention_mask.dtype(),
            attention_mask.device(),
        )?,
    )?;

    attention_mask = (1. - attention_mask.to_dtype(DType::F32)?.to_dtype(dtype)?)?;

    let neg_inf_value = dtype.finfo()?.min;
    attention_mask = attention_mask.reshape((bs, max_num_tiles * target_length, 1))?;
    attention_mask.matmul(
        &attention_mask
            .transpose(D::Minus1, D::Minus2)?
            .mul(neg_inf_value)?,
    )
}

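/// MLlama vision tower: conv patch embedding, tile/position embeddings, a local transformer
/// (from which selected intermediate hidden states are collected) and a gated global
/// transformer. The output concatenates the final hidden state with those intermediate states.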
pub(super) struct MLlamaVisionModel {
    patch_embedding: Conv2d,
    class_embedding: Tensor,
    gated_positional_embedding: MLlamaPrecomputedPositionEmbedding,
    pre_tile_positional_embedding: MLlamaPrecomputedAspectRatioEmbedding,
    post_tile_positional_embedding: MLlamaPrecomputedAspectRatioEmbedding,
    layernorm_pre: LayerNorm,
    layernorm_post: LayerNorm,
    transformer: MLlamaVisionEncoder,
    global_transformer: MLlamaVisionEncoder,
    pub(super) num_patches: usize,
    intermediate_layers_indices: Vec<usize>,
    num_attn_heads: usize,
}

impl MLlamaVisionModel {
    pub(super) fn new(
        cfg: &MLlamaVisionConfig,
        vb: ShardedVarBuilder,
        real_dev: &Device,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let patch_embedding = conv2d_no_bias(
            cfg.num_channels,
            cfg.hidden_size,
            cfg.patch_size,
            Conv2dConfig {
                stride: cfg.patch_size,
                ..Default::default()
            },
            vb.pp("patch_embedding").set_device(real_dev.clone()),
        )?;

        let class_embedding = vb
            .get((cfg.hidden_size,), "class_embedding")?
            .to_device(real_dev)?;
        let gated_positional_embedding = MLlamaPrecomputedPositionEmbedding::new(
            cfg,
            vb.pp("gated_positional_embedding")
                .set_device(real_dev.clone()),
        )?;

        let pre_tile_positional_embedding = MLlamaPrecomputedAspectRatioEmbedding::new::<true>(
            cfg,
            vb.pp("pre_tile_positional_embedding")
                .set_device(real_dev.clone()),
        )?;
        let post_tile_positional_embedding = MLlamaPrecomputedAspectRatioEmbedding::new::<true>(
            cfg,
            vb.pp("post_tile_positional_embedding")
                .set_device(real_dev.clone()),
        )?;

        let layernorm_pre = layer_norm(
            cfg.hidden_size,
            LayerNormConfig::default(),
            vb.pp("layernorm_pre").set_device(real_dev.clone()),
        )?;
        let layernorm_post = layer_norm(
            cfg.hidden_size,
            LayerNormConfig::default(),
            vb.pp("layernorm_post").set_device(real_dev.clone()),
        )?;

        let transformer = MLlamaVisionEncoder::new::<false>(
            cfg,
            cfg.num_hidden_layers,
            vb.pp("transformer"),
            real_dev,
            comm,
        )?;
        let global_transformer = MLlamaVisionEncoder::new::<true>(
            cfg,
            cfg.num_global_layers,
            vb.pp("global_transformer"),
            real_dev,
            comm,
        )?;

        Ok(Self {
            patch_embedding,
            class_embedding,
            gated_positional_embedding,
            pre_tile_positional_embedding,
            post_tile_positional_embedding,
            layernorm_post,
            layernorm_pre,
            transformer,
            global_transformer,
            num_patches: (cfg.image_size / cfg.patch_size).pow(2) + 1,
            intermediate_layers_indices: cfg.intermediate_layers_indices.clone(),
            num_attn_heads: cfg.num_attention_heads,
        })
    }

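    // pixel_values: (bs, num_concurrent_media, num_tiles, channels, height, width).
    // Returns (bs, num_concurrent_media, num_tiles, num_patches, dim * (1 + #intermediate layers)).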
    pub(super) fn forward(
        &self,
        pixel_values: &Tensor,
        aspect_ratio_ids: &Tensor,
        aspect_ratio_mask: &Tensor,
    ) -> Result<Tensor> {
        let pixel_values = pixel_values.to_dtype(self.class_embedding.dtype())?;

        let bs = pixel_values.dim(0)?;
        let num_concurrent_media = pixel_values.dim(1)?;
        let num_tiles = pixel_values.dim(2)?;
        let num_channels = pixel_values.dim(3)?;
        let height = pixel_values.dim(4)?;
        let width = pixel_values.dim(5)?;

        let pixel_values = pixel_values.reshape((
            bs * num_concurrent_media * num_tiles,
            num_channels,
            height,
            width,
        ))?;
        let aspect_ratio_ids = aspect_ratio_ids.reshape((bs * num_concurrent_media, ()))?;

        let patch_embeds = Convolution.forward_2d(&self.patch_embedding, &pixel_values)?;
        let mut hidden_state = patch_embeds.flatten_from(2)?.transpose(1, 2)?;

        let (_, mut num_patches, dim) = hidden_state.dims3()?;
        hidden_state = hidden_state.reshape((bs * num_concurrent_media, num_tiles, (), dim))?;
        hidden_state = self
            .pre_tile_positional_embedding
            .forward(&hidden_state, &aspect_ratio_ids)?;

        hidden_state =
            hidden_state.reshape((bs * num_concurrent_media * num_tiles, num_patches, dim))?;
        hidden_state = self.apply_class_embedding(&hidden_state)?;
        num_patches += 1;

        hidden_state =
            hidden_state.reshape((bs * num_concurrent_media, num_tiles, num_patches, dim))?;
        hidden_state = self
            .gated_positional_embedding
            .forward(&hidden_state, &aspect_ratio_ids)?;

        hidden_state = self.layernorm_pre.forward(&hidden_state)?;

        let num_padding_patches = (8 - (hidden_state.dim(D::Minus2)? as isize % 8)) % 8;
        let _padding = (0usize, 0usize, 0usize, num_padding_patches);
        if num_padding_patches >= 0 {
            hidden_state =
                hidden_state.pad_with_zeros(D::Minus2, 0, num_padding_patches as usize)?;
        } else {
            hidden_state = hidden_state.narrow(
                D::Minus2,
                0,
                (hidden_state.dim(2)? as isize + num_padding_patches) as usize,
            )?;
        }

        let mut attention_mask = aspect_ratio_mask.reshape((bs * num_concurrent_media, ()))?;
        attention_mask = _prepare_aspect_ratio_attention_mask(
            &attention_mask,
            self.num_patches,
            hidden_state.dim(2)?,
            hidden_state.dtype(),
            self.num_attn_heads,
        )?;
        if attention_mask.dim(0)? != 1 {
            attention_mask = attention_mask.unsqueeze(1)?;
        }

        hidden_state = hidden_state.reshape((bs * num_concurrent_media, (), dim))?;
        let (mut hidden_state, all_intermediate_hidden_states) =
            self.transformer.forward_with_states(
                &hidden_state,
                Some(&attention_mask),
                Some(&self.intermediate_layers_indices),
            )?;

        let mut intermediate_hidden_states =
            Tensor::stack(&all_intermediate_hidden_states, D::Minus1)?;
        drop(all_intermediate_hidden_states);

        hidden_state = self.layernorm_post.forward(&hidden_state)?;

        hidden_state = hidden_state.reshape((
            bs * num_concurrent_media,
            num_tiles,
            (num_patches as isize + num_padding_patches) as usize,
            dim,
        ))?;
        hidden_state = self
            .post_tile_positional_embedding
            .forward(&hidden_state, &aspect_ratio_ids)?;
        hidden_state = hidden_state.reshape((
            bs * num_concurrent_media,
            num_tiles * (num_patches as isize + num_padding_patches) as usize,
            dim,
        ))?;
        (hidden_state, _) = self.global_transformer.forward_with_states(
            &hidden_state,
            Some(&attention_mask),
            None,
        )?;

        hidden_state = hidden_state.reshape((
            bs * num_concurrent_media,
            num_tiles,
            (num_patches as isize + num_padding_patches) as usize,
            dim,
        ))?;
        hidden_state = hidden_state.narrow(
            2,
            0,
            (hidden_state.dims()[2] as isize - num_padding_patches) as usize,
        )?;
        hidden_state =
            hidden_state.reshape((bs, num_concurrent_media, num_tiles, num_patches, dim))?;

        intermediate_hidden_states = intermediate_hidden_states.reshape((
            bs * num_concurrent_media,
            num_tiles,
            (num_patches as isize + num_padding_patches) as usize,
            (),
        ))?;
        intermediate_hidden_states = intermediate_hidden_states.narrow(
            2,
            0,
            (intermediate_hidden_states.dims()[2] as isize - num_padding_patches) as usize,
        )?;
        intermediate_hidden_states = intermediate_hidden_states.reshape((
            bs,
            num_concurrent_media,
            num_tiles,
            num_patches,
            (),
        ))?;

        Tensor::cat(&[hidden_state, intermediate_hidden_states], D::Minus1)
    }

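    // Prepends the learned class embedding as the first "patch" of every sequence.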
    fn apply_class_embedding(&self, hidden_state: &Tensor) -> Result<Tensor> {
        let (bs, _, hidden_size) = hidden_state.dims3()?;
        let class_embedding = self.class_embedding.expand((bs, 1, hidden_size))?;
        Tensor::cat(&[class_embedding, hidden_state.clone()], 1)
    }

    pub fn get_isq_layers(&mut self) -> Vec<&mut std::sync::Arc<dyn mistralrs_quant::QuantMethod>> {
        let mut layers = Vec::new();
        for layer in &mut self.global_transformer.layers {
            layers.push(&mut layer.self_attn.q_proj);
            layers.push(&mut layer.self_attn.k_proj);
            layers.push(&mut layer.self_attn.v_proj);
            layers.push(&mut layer.self_attn.o_proj);

            layers.push(&mut layer.mlp.fc1);
            layers.push(&mut layer.mlp.fc2);
        }
        for layer in &mut self.transformer.layers {
            layers.push(&mut layer.self_attn.q_proj);
            layers.push(&mut layer.self_attn.k_proj);
            layers.push(&mut layer.self_attn.v_proj);
            layers.push(&mut layer.self_attn.o_proj);

            layers.push(&mut layer.mlp.fc1);
            layers.push(&mut layer.mlp.fc2);
        }
        layers
    }
}

impl IsqModel for MLlamaVisionModel {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(
            &mut std::sync::Arc<dyn mistralrs_quant::QuantMethod>,
            Option<usize>,
        )>,
        &dyn crate::device_map::DeviceMapper,
    ) {
        unreachable!("MLlamaVision model cannot be quantized.");
    }
    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("patch_embedding").add(&self.patch_embedding);
        uvb.add_tensor("class_embedding", self.class_embedding.clone());

        uvb.pp("gated_positional_embedding")
            .extend(self.gated_positional_embedding.residual_tensors());

        uvb.pp("pre_tile_positional_embedding")
            .extend(self.pre_tile_positional_embedding.residual_tensors());

        uvb.pp("post_tile_positional_embedding")
            .extend(self.post_tile_positional_embedding.residual_tensors());

        uvb.pp("layernorm_pre").add(&self.layernorm_pre);
        uvb.pp("layernorm_post").add(&self.layernorm_post);

        uvb.pp("transformer")
            .extend(self.transformer.residual_tensors());

        uvb.pp("global_transformer")
            .extend(self.global_transformer.residual_tensors());

        uvb.to_safetensors()
    }
}