#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

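//! Idefics2 multimodal model: a vision transformer encoder, a perceiver
//! resampler connector, and a Mistral text backbone.
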
pub(crate) mod idefics2_input_processor;

use candle_core::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{Conv2d, Conv2dConfig, Embedding, LayerNorm, Module};
use mistralrs_quant::ShardedVarBuilder;
use serde::Deserialize;
use std::{any::Any, ops::Mul};

use crate::{
    amoe::{AnyMoeBaseModelMixin, MlpLayer},
    device_map::DeviceMapper,
    layers::{
        conv2d, embedding, layer_norm, linear, linear_no_bias, repeat_kv, Activation, CausalMasker,
        MatMul, QLinear, RmsNorm,
    },
    models::mistral::Model as Mistral,
    paged_attention::{AttentionImplementation, ModelConfigMetadata},
    pipeline::{
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel, NormalLoadingMetadata, NormalModel, VisionModel,
    },
    utils::unvarbuilder::UnVarBuilder,
    AnyMoeConfig, AnyMoeExpertType,
};

use crate::models::mistral;
fn default_32000() -> usize {
    32000
}
fn default_32001() -> usize {
    32001
}
fn default_4096() -> usize {
    4096
}
fn default_14336() -> usize {
    14336
}
fn default_32() -> usize {
    32
}
fn default_8() -> usize {
    8
}
fn default_act() -> Activation {
    Activation::Silu
}
fn default_131072() -> usize {
    131072
}
fn default_eps() -> f64 {
    1e-6
}
fn default_rope() -> f64 {
    10000.0
}
fn default_false() -> bool {
    false
}
fn default_sliding() -> Option<usize> {
    Some(4096)
}
fn default_gelu() -> Activation {
    Activation::GeluPytorchTanh
}
fn default_64() -> usize {
    64
}
fn default_3() -> usize {
    3
}
fn default_16() -> usize {
    16
}
fn default_96() -> usize {
    96
}
fn default_4() -> usize {
    4
}
fn default_0_0() -> f32 {
    0.0
}
fn default_0_02() -> f32 {
    0.02
}
fn default_768() -> usize {
    768
}
fn default_3072() -> usize {
    3072
}
fn default_12() -> usize {
    12
}
fn default_224() -> usize {
    224
}

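/// Configuration for the perceiver resampler that pools the image patch
/// features into a fixed number of latent vectors.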
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct PerceiverConfig {
    #[serde(default = "default_act")]
    pub hidden_act: Activation,
    #[serde(default = "default_64")]
    pub resampler_n_latents: usize,
    #[serde(default = "default_3")]
    pub resampler_depth: usize,
    #[serde(default = "default_16")]
    pub resampler_n_heads: usize,
    #[serde(default = "default_96")]
    pub resampler_head_dim: usize,
    #[serde(default = "default_4")]
    pub num_key_value_heads: usize,
    #[serde(default = "default_0_0")]
    pub attention_dropout: f32,
}

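/// Configuration for the vision transformer encoder.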
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct VisionConfig {
    #[serde(default = "default_768")]
    pub hidden_size: usize,
    #[serde(default = "default_3072")]
    pub intermediate_size: usize,
    #[serde(default = "default_12")]
    pub num_hidden_layers: usize,
    #[serde(default = "default_12")]
    pub num_attention_heads: usize,
    #[serde(default = "default_3")]
    pub num_channels: usize,
    #[serde(default = "default_224")]
    pub image_size: usize,
    #[serde(default = "default_32")]
    pub patch_size: usize,
    #[serde(default = "default_gelu")]
    pub hidden_act: Activation,
    #[serde(default = "default_eps")]
    pub layer_norm_eps: f64,
    #[serde(default = "default_0_0")]
    pub attn_dropout: f32,
    #[serde(default = "default_0_02")]
    pub initializer_range: f32,
}

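/// Configuration for the Mistral text backbone; converted to `mistral::Config`
/// via the `From` impl below.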
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub(crate) struct TextConfig {
    #[serde(default = "default_32000")]
    pub(crate) vocab_size: usize,
    #[serde(default = "default_4096")]
    pub(crate) hidden_size: usize,
    #[serde(default = "default_14336")]
    pub(crate) intermediate_size: usize,
    #[serde(default = "default_32")]
    pub(crate) num_hidden_layers: usize,
    #[serde(default = "default_32")]
    pub(crate) num_attention_heads: usize,
    #[serde(default = "default_8")]
    pub(crate) num_key_value_heads: usize,
    #[serde(default = "default_act")]
    pub(crate) hidden_act: Activation,
    #[serde(default = "default_131072")]
    pub(crate) max_position_embeddings: usize,
    #[serde(default = "default_eps")]
    pub(crate) rms_norm_eps: f64,
    #[serde(default = "default_rope")]
    pub(crate) rope_theta: f64,
    #[serde(default = "default_sliding")]
    pub(crate) sliding_window: Option<usize>,

    #[serde(default = "default_false")]
    pub(crate) use_flash_attn: bool,
    model_type: String,
}

impl From<TextConfig> for mistral::Config {
    fn from(val: TextConfig) -> Self {
        mistral::Config {
            vocab_size: val.vocab_size,
            hidden_act: val.hidden_act,
            hidden_size: val.hidden_size,
            intermediate_size: val.intermediate_size,
            num_hidden_layers: val.num_hidden_layers,
            num_attention_heads: val.num_attention_heads,
            num_key_value_heads: val.num_key_value_heads,
            max_position_embeddings: val.max_position_embeddings,
            rms_norm_eps: val.rms_norm_eps,
            rope_theta: val.rope_theta,
            sliding_window: val.sliding_window,
            use_flash_attn: val.use_flash_attn,
            head_dim: None,
            quantization_config: None,
            tie_word_embeddings: false,
        }
    }
}

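/// Top-level Idefics2 configuration, combining the perceiver, vision, and text
/// sub-configurations.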
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub(crate) struct Config {
    pub perceiver_config: PerceiverConfig,
    pub vision_config: VisionConfig,
    pub(crate) text_config: TextConfig,
    #[serde(default = "default_32001")]
    pub image_token_id: usize,
    #[serde(default = "default_false")]
    pub tie_word_embeddings: bool,
}

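/// Convolutional patch embedding plus learned position embeddings. Position
/// ids are derived per image from the patch attention mask, so padded patches
/// keep the zero-initialized position id.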
struct VisionEmbeddings {
    patch_size: usize,
    patch_embedding: Conv2d,
    num_patches_per_side: usize,
    position_embedding: Embedding,
}

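/// Equivalent of `torch.bucketize(xs, boundaries, right=True)`: for each `x`,
/// returns the index of the first boundary strictly greater than `x`.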
fn bucketize_right(xs: &[f32], boundaries: &[f32], device: &Device) -> Result<Tensor> {
    use std::cmp::Ordering;

    let mut result = Vec::with_capacity(xs.len());

    for &x in xs {
        let idx = match boundaries.binary_search_by(|&val| {
            // The boundaries are well-formed fractions, so `partial_cmp` can
            // only fail on NaN.
            val.partial_cmp(&x).unwrap_or(Ordering::Less)
        }) {
            // Right bucketization: a value equal to a boundary belongs to the
            // bucket on its right, so step past the matching boundary.
            Ok(i) => i + 1,
            Err(i) => i,
        };

        result.push(idx as u32);
    }

    Tensor::from_vec(result, (xs.len(),), device)
}

impl VisionEmbeddings {
    fn new(config: &VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        let conv_config = Conv2dConfig {
            stride: config.patch_size,
            ..Default::default()
        };
        let patch_embedding = conv2d(
            config.num_channels,
            config.hidden_size,
            config.patch_size,
            conv_config,
            vb.pp("patch_embedding"),
        )?;
        let num_patches_per_side = config.image_size / config.patch_size;
        let num_patches = num_patches_per_side.pow(2);
        Ok(Self {
            patch_size: config.patch_size,
            patch_embedding,
            num_patches_per_side,
            position_embedding: embedding(
                num_patches,
                config.hidden_size,
                vb.pp("position_embedding"),
                &None,
            )?,
        })
    }

    fn forward(&self, pixel_values: &Tensor, patch_attention_mask: &Tensor) -> Result<Tensor> {
        let (bs, _, max_im_h, max_im_w) = pixel_values.dims4()?;

        let patch_embeds = self.patch_embedding.forward(pixel_values)?;

        let embeddings = patch_embeds.flatten(2, D::Minus1)?.transpose(1, 2)?;

        let (max_nb_patches_h, max_nb_patches_w) =
            (max_im_h / self.patch_size, max_im_w / self.patch_size);
        let boundaries = Tensor::arange_step(
            1.0 / self.num_patches_per_side as f32,
            1.0,
            1.0 / self.num_patches_per_side as f32,
            pixel_values.device(),
        )?
        .to_vec1::<f32>()?;
        let position_ids = Tensor::full(
            0u32,
            (bs, max_nb_patches_h * max_nb_patches_w),
            pixel_values.device(),
        )?;

        let mut new_position_ids = Vec::new();
        for (b_idx, p_attn_mask) in patch_attention_mask.chunk(bs, 0)?.iter().enumerate() {
            let p_attn_mask = p_attn_mask.squeeze(0)?;
            let nb_patches_h = p_attn_mask.i((.., 0))?.sum_all()?;
            let nb_patches_w = p_attn_mask.i((0,))?.sum_all()?;

            let fractional_coords_h = Tensor::arange_step(
                0.0,
                1.0 - 1e-6,
                1.0 / nb_patches_h.to_dtype(DType::F32)?.to_scalar::<f32>()?,
                pixel_values.device(),
            )?
            .to_vec1::<f32>()?;
            let fractional_coords_w = Tensor::arange_step(
                0.0,
                1.0 - 1e-6,
                1.0 / nb_patches_w.to_dtype(DType::F32)?.to_scalar::<f32>()?,
                pixel_values.device(),
            )?
            .to_vec1::<f32>()?;

            let bucket_coords_h =
                bucketize_right(&fractional_coords_h, &boundaries, pixel_values.device())?;
            let bucket_coords_w =
                bucketize_right(&fractional_coords_w, &boundaries, pixel_values.device())?;

            let pos_ids = bucket_coords_h
                .unsqueeze(D::Minus1)?
                .mul(self.num_patches_per_side as f64)?
                .broadcast_add(&bucket_coords_w)?
                .flatten_all()?
                .to_vec1::<u32>()?;

            // Scatter the computed ids into the zero-initialized position id row,
            // but only at positions where the patch attention mask is set.
            let true_indices = p_attn_mask
                .flatten_all()?
                .to_vec1::<u8>()?
                .iter()
                .enumerate()
                .filter_map(|(i, x)| if *x != 0 { Some(i) } else { None })
                .collect::<Vec<_>>();
            let position_ids_b = position_ids.i(b_idx)?;

            let mut new_position_ids_b = position_ids_b.to_vec1::<u32>()?;
            let new_position_ids_b_len = new_position_ids_b.len();
            for (i, true_idx) in true_indices.into_iter().enumerate() {
                new_position_ids_b[true_idx] = pos_ids[i];
            }

            new_position_ids.push(Tensor::from_vec(
                new_position_ids_b,
                new_position_ids_b_len,
                pixel_values.device(),
            )?);
        }
        let position_ids = Tensor::stack(&new_position_ids, 0)?;
        let position_ids = position_ids.to_device(self.position_embedding.embeddings().device())?;
        embeddings.broadcast_add(&self.position_embedding.forward(&position_ids)?)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("patch_embedding").add(&self.patch_embedding);
        uvb.pp("position_embedding").add(&self.position_embedding);

        uvb.to_safetensors()
    }
}

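/// Multi-head self-attention over the patch sequence (bidirectional, with an
/// optional padding mask).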
struct Attention {
    embed_dim: usize,
    num_heads: usize,
    head_dim: usize,
    scale: f64,
    q_proj: QLinear,
    k_proj: QLinear,
    v_proj: QLinear,
    o_proj: QLinear,
    neg_inf: Tensor,
}

impl Attention {
    fn new(config: VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        let embed_dim = config.hidden_size;
        let num_heads = config.num_attention_heads;
        let head_dim = embed_dim / num_heads;
        let scale = 1.0 / (head_dim as f64).sqrt();

        let q_proj = linear(embed_dim, embed_dim, vb.pp("q_proj"))?;
        let k_proj = linear(embed_dim, embed_dim, vb.pp("k_proj"))?;
        let v_proj = linear(embed_dim, embed_dim, vb.pp("v_proj"))?;
        let o_proj = linear(embed_dim, embed_dim, vb.pp("out_proj"))?;

        Ok(Self {
            embed_dim,
            num_heads,
            head_dim,
            scale,
            q_proj: QLinear::from_linear(q_proj),
            k_proj: QLinear::from_linear(k_proj),
            v_proj: QLinear::from_linear(v_proj),
            o_proj: QLinear::from_linear(o_proj),
            neg_inf: Tensor::new(f32::NEG_INFINITY, vb.device())?.to_dtype(vb.dtype())?,
        })
    }

    fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        // Quantized projections run in f32; cast in and restore the original
        // dtype afterwards.
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if self.q_proj.is_quant() {
            xs = xs.to_dtype(DType::F32)?;
        }
        let mut q = self.q_proj.forward(&xs)?;
        let mut k = self.k_proj.forward(&xs)?;
        let mut v = self.v_proj.forward(&xs)?;
        if self.q_proj.is_quant() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

        let q = q
            .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;
        let k = k
            .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;
        let v = v
            .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;

        let attn_weights =
            (MatMul.matmul(&q.contiguous()?, &k.transpose(2, 3)?.contiguous()?)? * self.scale)?;

        let attn_weights = CausalMasker.apply_mask_one_and_zero(
            &attention_mask.map(|x| x.to_dtype(DType::U8).unwrap()),
            attn_weights,
            &self.neg_inf,
        )?;
        let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
        let mut attn_output = MatMul.matmul(&attn_weights, &v.contiguous()?)?;

        if self.q_proj.is_quant() {
            attn_output = attn_output.to_dtype(DType::F32)?;
        }
        let mut res = attn_output
            .transpose(1, 2)?
            .reshape((b_sz, q_len, self.embed_dim))?
            .apply(&self.o_proj)?;
        if self.q_proj.is_quant() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("q_proj").add(&self.q_proj);
        uvb.pp("k_proj").add(&self.k_proj);
        uvb.pp("v_proj").add(&self.v_proj);
        uvb.pp("out_proj").add(&self.o_proj);

        uvb.to_safetensors()
    }
}

struct VisionMLP {
    activation: Activation,
    fc1: QLinear,
    fc2: QLinear,
}

impl VisionMLP {
    fn new(config: VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        let fc1 = linear(config.hidden_size, config.intermediate_size, vb.pp("fc1"))?;
        let fc2 = linear(config.intermediate_size, config.hidden_size, vb.pp("fc2"))?;
        Ok(Self {
            activation: config.hidden_act,
            fc1: QLinear::from_linear(fc1),
            fc2: QLinear::from_linear(fc2),
        })
    }

    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let mut x = x.clone();
        let original_dtype = x.dtype();
        if self.fc1.is_quant() {
            x = x.to_dtype(DType::F32)?;
        }
        let x = self.fc1.forward(&x)?;
        let x = self.activation.forward(&x)?;
        let mut res = self.fc2.forward(&x)?;
        if self.fc1.is_quant() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("fc1").add(&self.fc1);
        uvb.pp("fc2").add(&self.fc2);

        uvb.to_safetensors()
    }
}

struct EncoderLayer {
    mlp: VisionMLP,
    attn: Attention,
    layer_norm_1: LayerNorm,
    layer_norm_2: LayerNorm,
}

impl EncoderLayer {
    fn new(config: VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        let mlp = VisionMLP::new(config.clone(), vb.pp("mlp"))?;
        let attn = Attention::new(config.clone(), vb.pp("self_attn"))?;
        let layer_norm_1 = layer_norm(
            config.hidden_size,
            config.layer_norm_eps,
            vb.pp("layer_norm1"),
        )?;
        let layer_norm_2 = layer_norm(
            config.hidden_size,
            config.layer_norm_eps,
            vb.pp("layer_norm2"),
        )?;
        Ok(Self {
            mlp,
            attn,
            layer_norm_1,
            layer_norm_2,
        })
    }

    fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
        let residual = xs.clone();

        let hidden_states = self.layer_norm_1.forward(xs)?;
        let hidden_states = self.attn.forward(&hidden_states, attention_mask)?;
        let hidden_states = (hidden_states + residual)?;

        let residual = &hidden_states;
        let hidden_states = self.layer_norm_2.forward(&hidden_states)?;
        let hidden_states = self.mlp.forward(&hidden_states)?;
        hidden_states + residual
    }
}

struct Encoder {
    layers: Vec<EncoderLayer>,
}

impl Encoder {
    fn new(config: &VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        let mut layers = Vec::new();
        let vb_l = vb.pp("layers");
        for i in 0..config.num_hidden_layers {
            layers.push(EncoderLayer::new(config.clone(), vb_l.pp(i))?);
        }
        Ok(Self { layers })
    }

    fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
        let mut hidden_states = xs.clone();
        for layer in &self.layers {
            hidden_states = layer.forward(&hidden_states, attention_mask)?;
        }
        Ok(hidden_states)
    }
}

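/// The full vision encoder: patch embeddings, transformer encoder, and a final
/// layer norm.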
struct VisionTransformer {
    embeddings: VisionEmbeddings,
    encoder: Encoder,
    post_layernorm: LayerNorm,
    config: VisionConfig,
}

impl VisionTransformer {
    fn new(config: &VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        let embeddings = VisionEmbeddings::new(config, vb.pp("embeddings"))?;
        let post_layernorm = layer_norm(
            config.hidden_size,
            config.layer_norm_eps,
            vb.pp("post_layernorm"),
        )?;
        let encoder = Encoder::new(config, vb.pp("encoder"))?;
        Ok(Self {
            embeddings,
            encoder,
            post_layernorm,
            config: config.clone(),
        })
    }

    fn forward(&self, pixel_values: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
        let bs = pixel_values.dim(0)?;
        // Default to attending to every patch if no mask was provided.
        let patch_attention_mask = if let Some(attn_mask) = attention_mask {
            attn_mask.clone()
        } else {
            let patch_size = self.config.patch_size;
            Tensor::ones(
                (
                    bs,
                    pixel_values.dim(2)? / patch_size,
                    pixel_values.dim(3)? / patch_size,
                ),
                DType::U8,
                pixel_values.device(),
            )?
        };

        let hidden_states = self
            .embeddings
            .forward(pixel_values, &patch_attention_mask)?;

        let attention_mask = if attention_mask.is_none() {
            None
        } else {
            let mask = patch_attention_mask
                .reshape((patch_attention_mask.dim(0)?, ()))?
                .to_dtype(hidden_states.dtype())?;
            Some(CausalMasker.expand_mask(&mask, hidden_states.dtype(), None)?)
        };
        let hidden_states = self
            .encoder
            .forward(&hidden_states, attention_mask.as_ref())?;
        hidden_states.apply(&self.post_layernorm)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("post_layernorm").add(&self.post_layernorm);
        uvb.pp("embeddings")
            .extend(self.embeddings.residual_tensors());

        let uvb_enc = uvb.pp("encoder");
        for (i, layer) in self.encoder.layers.iter().enumerate() {
            let uvb_l = uvb_enc.pp("layers").pp(i);

            uvb_l.pp("layer_norm1").add(&layer.layer_norm_1);
            uvb_l.pp("layer_norm2").add(&layer.layer_norm_2);
            uvb_l.pp("mlp").extend(layer.mlp.residual_tensors());
            uvb_l.pp("self_attn").extend(layer.attn.residual_tensors());
        }

        uvb.to_safetensors()
    }
}

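/// Gated MLP, `down(act(gate(x)) * up(x))`, used for both the modality
/// projection and the perceiver layers.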
struct Mlp {
    gate_proj: QLinear,
    up_proj: QLinear,
    down_proj: QLinear,
    activation: Activation,
}

impl Mlp {
    fn new(
        hidden_size: usize,
        intermediate_size: usize,
        output_size: usize,
        activation: Activation,
        vb: ShardedVarBuilder,
    ) -> Result<Self> {
        let gate_proj = linear_no_bias(hidden_size, intermediate_size, vb.pp("gate_proj"))?;
        let up_proj = linear_no_bias(hidden_size, intermediate_size, vb.pp("up_proj"))?;
        let down_proj = linear_no_bias(intermediate_size, output_size, vb.pp("down_proj"))?;
        Ok(Self {
            gate_proj: QLinear::from_linear(gate_proj),
            up_proj: QLinear::from_linear(up_proj),
            down_proj: QLinear::from_linear(down_proj),
            activation,
        })
    }

    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let mut x = x.clone();
        let original_dtype = x.dtype();
        if self.gate_proj.is_quant() {
            x = x.to_dtype(DType::F32)?;
        }
        let mut res = self.down_proj.forward(
            &(self.activation.forward(&self.gate_proj.forward(&x)?)?
                * self.up_proj.forward(&x)?)?,
        )?;
        if self.gate_proj.is_quant() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("gate_proj").add(&self.gate_proj);
        uvb.pp("up_proj").add(&self.up_proj);
        uvb.pp("down_proj").add(&self.down_proj);

        uvb.to_safetensors()
    }
}

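/// Perceiver cross-attention: the latents are the queries, while keys and
/// values are computed over the concatenation of context and latents, with
/// grouped-query KV heads repeated via `repeat_kv`.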
struct PerceiverAttention {
    num_heads: usize,
    num_kv_heads: usize,
    num_kv_groups: usize,
    head_dim: usize,
    q_proj: QLinear,
    k_proj: QLinear,
    v_proj: QLinear,
    o_proj: QLinear,
    neg_inf: Tensor,
}

impl PerceiverAttention {
    fn new(config: &Config, vb: ShardedVarBuilder) -> Result<Self> {
        let hidden_size = config.text_config.hidden_size;
        let num_heads = config.perceiver_config.resampler_n_heads;
        let head_dim = config.perceiver_config.resampler_head_dim;
        let num_key_value_heads = config.perceiver_config.num_key_value_heads;
        let num_key_value_groups = num_heads / num_key_value_heads;

        let q_proj = linear_no_bias(hidden_size, num_heads * head_dim, vb.pp("q_proj"))?;
        let k_proj = linear_no_bias(hidden_size, num_key_value_heads * head_dim, vb.pp("k_proj"))?;
        let v_proj = linear_no_bias(hidden_size, num_key_value_heads * head_dim, vb.pp("v_proj"))?;
        let o_proj = linear_no_bias(num_heads * head_dim, hidden_size, vb.pp("o_proj"))?;

        Ok(Self {
            num_heads,
            head_dim,
            q_proj: QLinear::from_linear(q_proj),
            k_proj: QLinear::from_linear(k_proj),
            v_proj: QLinear::from_linear(v_proj),
            o_proj: QLinear::from_linear(o_proj),
            neg_inf: Tensor::new(f32::NEG_INFINITY, vb.device())?.to_dtype(vb.dtype())?,
            num_kv_heads: num_key_value_heads,
            num_kv_groups: num_key_value_groups,
        })
    }

    fn forward(
        &self,
        latents: &Tensor,
        context: &Tensor,
        attention_mask: &Tensor,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = latents.dims3()?;
        // Keys and values span the concatenation of context and latents.
        let kv_seq_len = q_len + context.dims()[1];

        let mut hidden_states = Tensor::cat(&[context, latents], D::Minus2)?;

        let original_dtype = latents.dtype();
        let mut latents = latents.clone();
        if self.q_proj.is_quant() {
            latents = latents.to_dtype(DType::F32)?;
            hidden_states = hidden_states.to_dtype(DType::F32)?;
        }
        let mut q = self.q_proj.forward(&latents)?;
        let mut k = self.k_proj.forward(&hidden_states)?;
        let mut v = self.v_proj.forward(&hidden_states)?;
        if self.q_proj.is_quant() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

        let q = q
            .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;
        let k = k
            .reshape((b_sz, kv_seq_len, self.num_kv_heads, self.head_dim))?
            .transpose(1, 2)?;
        let v = v
            .reshape((b_sz, kv_seq_len, self.num_kv_heads, self.head_dim))?
            .transpose(1, 2)?;

        let k = repeat_kv(k, self.num_kv_groups)?.contiguous()?;
        let v = repeat_kv(v, self.num_kv_groups)?.contiguous()?;

        let attn_weights = (MatMul.matmul(&q.contiguous()?, &k.transpose(2, 3)?.contiguous()?)?
            / (self.head_dim as f64).sqrt())?;

        let attn_weights = CausalMasker.apply_mask_one_and_zero(
            &Some(attention_mask.to_dtype(DType::U8)?),
            attn_weights,
            &self.neg_inf,
        )?;
        let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
        let mut attn_output = MatMul.matmul(&attn_weights, &v)?;

        if self.q_proj.is_quant() {
            attn_output = attn_output.to_dtype(DType::F32)?;
        }
        let mut res = attn_output
            .transpose(1, 2)?
            .reshape((b_sz, q_len, self.num_heads * self.head_dim))?
            .apply(&self.o_proj)?;
        if self.q_proj.is_quant() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("q_proj").add(&self.q_proj);
        uvb.pp("k_proj").add(&self.k_proj);
        uvb.pp("v_proj").add(&self.v_proj);
        uvb.pp("o_proj").add(&self.o_proj);

        uvb.to_safetensors()
    }
}

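/// One perceiver block: pre-norm cross-attention and a gated MLP, each wrapped
/// in a residual connection.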
struct PerceiverLayer {
    input_latents_norm: RmsNorm,
    input_context_norm: RmsNorm,
    self_attn: PerceiverAttention,
    post_attn_norm: RmsNorm,
    mlp: Mlp,
}

impl PerceiverLayer {
    fn new(config: &Config, vb: ShardedVarBuilder) -> Result<Self> {
        let hidden_size = config.text_config.hidden_size;
        let mlp_act = config.perceiver_config.hidden_act;
        let rms_eps = config.text_config.rms_norm_eps;

        Ok(Self {
            input_latents_norm: RmsNorm::new(hidden_size, rms_eps, vb.pp("input_latents_norm"))?,
            input_context_norm: RmsNorm::new(hidden_size, rms_eps, vb.pp("input_context_norm"))?,
            self_attn: PerceiverAttention::new(config, vb.pp("self_attn"))?,
            post_attn_norm: RmsNorm::new(hidden_size, rms_eps, vb.pp("post_attention_layernorm"))?,
            mlp: Mlp::new(
                hidden_size,
                hidden_size * 4,
                hidden_size,
                mlp_act,
                vb.pp("mlp"),
            )?,
        })
    }

    fn forward(
        &self,
        latents: &Tensor,
        context: &Tensor,
        attention_mask: &Tensor,
    ) -> Result<Tensor> {
        let residual = latents;

        let latents = self.input_latents_norm.forward(latents)?;
        let context = self.input_context_norm.forward(context)?;

        let latents = self.self_attn.forward(&latents, &context, attention_mask)?;
        let latents = (residual + latents)?;
        let residual = &latents;

        let latents = self.post_attn_norm.forward(&latents)?;
        let latents = self.mlp.forward(&latents)?;
        residual + latents
    }
}

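/// Compresses a variable-length sequence of image features into a fixed set of
/// `n_latents` learned latents.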
struct PerceiverResampler {
    latents: Tensor,
    layers: Vec<PerceiverLayer>,
    norm: RmsNorm,
    n_latents: usize,
}

impl PerceiverResampler {
    fn new(config: &Config, vb: ShardedVarBuilder) -> Result<Self> {
        let n_latents = config.perceiver_config.resampler_n_latents;
        let hidden_size = config.text_config.hidden_size;
        let depth = config.perceiver_config.resampler_depth;

        let latents = vb.get((n_latents, hidden_size), "latents")?;
        let mut layers = Vec::new();
        let vb_l = vb.pp("layers");
        for i in 0..depth {
            layers.push(PerceiverLayer::new(config, vb_l.pp(i))?);
        }
        let norm = RmsNorm::new(hidden_size, config.text_config.rms_norm_eps, vb.pp("norm"))?;
        Ok(Self {
            latents,
            layers,
            norm,
            n_latents,
        })
    }

    fn forward(&self, context: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
        // Broadcast the learned latents across the batch dimension.
        let mut s = vec![context.dim(0)?];
        s.extend(self.latents.dims());
        let latents = self.latents.unsqueeze(0)?.expand(s)?;

        let latent_attention_mask = Tensor::ones(
            (attention_mask.dim(0)?, latents.dim(1)?),
            attention_mask.dtype(),
            attention_mask.device(),
        )?;
        let attention_mask = Tensor::cat(&[attention_mask, &latent_attention_mask], D::Minus1)?;
        let attention_mask =
            CausalMasker.expand_mask(&attention_mask, latents.dtype(), Some(self.n_latents))?;

        let mut compressed_context = latents;
        for perceiver_layer in &self.layers {
            compressed_context =
                perceiver_layer.forward(&compressed_context, context, &attention_mask)?;
        }
        self.norm.forward(&compressed_context)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("norm").add(&self.norm);
        uvb.add_tensor("latents", self.latents.clone());

        for (i, layer) in self.layers.iter().enumerate() {
            let uvb_l = uvb.pp("layers").pp(i);

            uvb_l
                .pp("input_latents_norm")
                .add(&layer.input_latents_norm);
            uvb_l
                .pp("input_context_norm")
                .add(&layer.input_context_norm);
            uvb_l
                .pp("post_attention_layernorm")
                .add(&layer.post_attn_norm);
            uvb_l.pp("mlp").extend(layer.mlp.residual_tensors());
            uvb_l
                .pp("self_attn")
                .extend(layer.self_attn.residual_tensors());
        }

        uvb.to_safetensors()
    }
}

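/// Connects the vision encoder to the text model: projects vision hidden
/// states to the text hidden size, then resamples them with the perceiver.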
struct Connector {
    modality_projection: Mlp,
    perceiver_resampler: PerceiverResampler,
}

impl Connector {
    fn new(config: &Config, vb: ShardedVarBuilder) -> Result<Self> {
        let modality_projection = Mlp::new(
            config.vision_config.hidden_size,
            config.text_config.intermediate_size,
            config.text_config.hidden_size,
            config.text_config.hidden_act,
            vb.pp("modality_projection"),
        )?;
        let perceiver_resampler = PerceiverResampler::new(config, vb.pp("perceiver_resampler"))?;
        Ok(Self {
            modality_projection,
            perceiver_resampler,
        })
    }

    fn forward(&self, image_hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
        let image_hidden_states = self.modality_projection.forward(image_hidden_states)?;
        self.perceiver_resampler
            .forward(&image_hidden_states, attention_mask)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("modality_projection")
            .extend(self.modality_projection.residual_tensors());
        uvb.pp("perceiver_resampler")
            .extend(self.perceiver_resampler.residual_tensors());

        uvb.to_safetensors()
    }
}

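/// The Idefics2 vision-language model.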
pub struct Idefics2 {
    vision_model: VisionTransformer,
    connector: Connector,
    text_model: Mistral,
    dtype: DType,
    config: Config,
}

impl Idefics2 {
    pub fn new(
        config: &Config,
        vb: ShardedVarBuilder,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        let vb_m = vb.pp("model");
        let text_model = Mistral::new_inner(
            &config.text_config.clone().into(),
            vb_m.pp("text_model"),
            vb.pp("lm_head"),
            is_gptx,
            normal_loading_metadata,
            attention_mechanism,
        )?;
        let vision_model = VisionTransformer::new(
            &config.vision_config,
            vb_m.pp("vision_model")
                .set_device(text_model.device().clone()),
        )?;
        let connector = Connector::new(
            config,
            vb_m.pp("connector").set_device(text_model.device().clone()),
        )?;
        Ok(Self {
            vision_model,
            connector,
            text_model,
            dtype: vb.dtype(),
            config: config.clone(),
        })
    }

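    /// Splices the resampled image hidden states into the text embeddings at
    /// the positions of the image placeholder tokens. Assumes batch size 1, as
    /// enforced by the asserts below.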
    fn inputs_merger(
        &self,
        input_ids: &Tensor,
        input_embeds: &Tensor,
        image_hidden_states: &Tensor,
    ) -> Result<Tensor> {
        let (_, _, vision_hidden_size) = image_hidden_states.dims3()?;
        let bs = input_ids.dim(0)?;
        let special_image_token_mask = input_ids.eq(self.config.image_token_id as f64)?;
        let mut new_inputs_embeds = input_embeds.clone();
        let reshaped_image_hidden_states =
            image_hidden_states.reshape((bs, (), vision_hidden_size))?;
        assert_eq!(input_embeds.dim(0)?, 1);
        assert_eq!(reshaped_image_hidden_states.dim(0)?, 1);
        let special_image_token_mask = special_image_token_mask.i(0)?.to_vec1::<u8>()?;
        let mut image_hidden_state_i = 0;
        // Replace each image token embedding with the next image hidden state,
        // in order.
        for (i, v) in special_image_token_mask.iter().enumerate() {
            if *v != 0 {
                new_inputs_embeds = new_inputs_embeds.slice_assign(
                    &[&.., &i, &..],
                    &reshaped_image_hidden_states
                        .i((.., image_hidden_state_i, ..))?
                        .unsqueeze(1)?,
                )?;
                image_hidden_state_i += 1;
            }
        }
        Ok(new_inputs_embeds)
    }

    #[allow(clippy::too_many_arguments)]
    fn forward_inner(
        &self,
        input_ids: &Tensor,
        pixel_values: Option<Tensor>,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        pixel_attention_mask: Option<Tensor>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let input_embeds = if let Some(pixel_values) = pixel_values {
            // Flatten the (batch, num_images) dimensions into one.
            let (batch_size, num_images, _, _, _) = pixel_values.dims5()?;
            let mut s = vec![batch_size * num_images];
            s.extend(pixel_values.dims()[2..].to_vec());
            let pixel_values = pixel_values.reshape(s)?;

            // Drop padding images: an all-zero image is treated as padding.
            let nb_values_per_image = pixel_values.dims()[1..].iter().product::<usize>();
            let real_images_inds = pixel_values
                .eq(0.0f64)?
                .sum(vec![
                    pixel_values.dims().len() - 1,
                    pixel_values.dims().len() - 2,
                    pixel_values.dims().len() - 3,
                ])?
                .ne(nb_values_per_image as f64)?;
            let mut batches = Vec::new();
            for (batch, use_it) in pixel_values
                .chunk(pixel_values.dim(0)?, 0)?
                .iter()
                .zip(real_images_inds.chunk(real_images_inds.dim(0)?, 0)?)
            {
                let use_it = use_it.squeeze(0)?.to_scalar::<u8>()? != 0;
                if use_it {
                    batches.push(batch.clone());
                }
            }
            let pixel_values = Tensor::cat(&batches, 0)?;

            // Filter the pixel attention mask the same way, or attend everywhere.
            let pixel_attention_mask = if let Some(pixel_attention_mask) = pixel_attention_mask {
                let pixel_attention_mask = pixel_attention_mask.reshape((
                    batch_size * num_images,
                    pixel_attention_mask.dims()[2],
                    pixel_attention_mask.dims()[3],
                ))?;
                let mut batches = Vec::new();
                for (batch, use_it) in pixel_attention_mask
                    .chunk(pixel_attention_mask.dim(0)?, 0)?
                    .iter()
                    .zip(real_images_inds.chunk(real_images_inds.dim(0)?, 0)?)
                {
                    let use_it = use_it.squeeze(0)?.to_scalar::<u8>()? != 0;
                    if use_it {
                        batches.push(batch.clone());
                    }
                }
                Tensor::cat(&batches, 0)?
            } else {
                Tensor::ones(
                    (
                        pixel_values.dims()[0],
                        pixel_values.dims()[2],
                        pixel_values.dims()[3],
                    ),
                    DType::U8,
                    pixel_values.device(),
                )?
            };

            // A patch is valid only if every pixel in its patch_size x patch_size
            // window is unmasked.
            let patch_size = self.config.vision_config.patch_size;
            let patches_subgrid = pixel_attention_mask.unfold(1, patch_size, patch_size)?;
            let patches_subgrid = patches_subgrid.unfold(2, patch_size, patch_size)?;

            let patch_attention_mask = patches_subgrid
                .sum((D::Minus1, D::Minus2))?
                .eq((patch_size * patch_size) as f64)?
                .to_dtype(DType::U8)?;

            let pixel_values = pixel_values.to_dtype(self.dtype)?;

            // Encode the images, ...
            let image_hidden_states = self
                .vision_model
                .forward(&pixel_values, Some(&patch_attention_mask))?;

            // ... then project and resample them to the text hidden size.
            let image_hidden_states = self.connector.forward(
                &image_hidden_states,
                &patch_attention_mask.reshape((pixel_values.dim(0)?, ()))?,
            )?;

            // Image embeddings may only be merged during the prompt phase.
            if self.text_model.cache.normal().0[0].current_seq_len() == 0 {
                self.inputs_merger(
                    input_ids,
                    &self.text_model.get_input_embeddings(input_ids)?,
                    &image_hidden_states,
                )?
            } else {
                candle_core::bail!("Pixel values were specified for a non-prompt.")
            }
        } else {
            self.text_model.get_input_embeddings(input_ids)?
        };

        self.text_model.forward_embeds(
            input_ids,
            input_embeds,
            seqlen_offsets,
            context_lens,
            metadata,
            flash_params,
        )
    }
}

impl IsqModel for Idefics2 {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(
            &mut std::sync::Arc<dyn mistralrs_quant::QuantMethod>,
            Option<usize>,
        )>,
        &dyn DeviceMapper,
    ) {
        self.text_model.get_layers()
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        let uvb_m = uvb.pp("model");
        uvb_m
            .pp("text_model")
            .extend(self.text_model.residual_tensors());
        uvb_m
            .pp("vision_model")
            .extend(self.vision_model.residual_tensors());
        uvb_m
            .pp("connector")
            .extend(self.connector.residual_tensors());

        uvb.to_safetensors()
    }
}

impl AnyMoeBaseModelMixin for Idefics2 {
    fn get_mlps(&self) -> Vec<&dyn MlpLayer> {
        self.text_model.get_mlps()
    }
    fn get_mlps_mut(&mut self) -> Vec<&mut Box<dyn MlpLayer>> {
        self.text_model.get_mlps_mut()
    }
    fn create_anymoe_layers(
        &mut self,
        additional_vbs: Vec<ShardedVarBuilder>,
        config: AnyMoeConfig,
        (prefix, mlp): (String, String),
        layers: Vec<usize>,
        expert_type: AnyMoeExpertType,
        gate_vb: Option<ShardedVarBuilder>,
    ) -> Result<()> {
        self.text_model.create_anymoe_layers(
            additional_vbs,
            config,
            (prefix, mlp),
            layers,
            expert_type,
            gate_vb,
        )
    }
    fn amoe_supported(&self) -> bool {
        true
    }
}

impl VisionModel for Idefics2 {
    fn forward(
        &self,
        input_ids: &Tensor,
        pixel_values: Option<Tensor>,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        _: Vec<usize>,
        model_specific_args: Box<dyn Any>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> candle_core::Result<Tensor> {
        let pixel_attention_mask: Option<Tensor> = *model_specific_args
            .downcast()
            .expect("Cannot downcast into `Option<Tensor>`");
        self.forward_inner(
            input_ids,
            pixel_values,
            seqlen_offsets,
            context_lens,
            pixel_attention_mask,
            metadata,
            flash_params,
        )
    }
    fn cache(&self) -> &EitherCache {
        self.text_model.cache()
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        self.text_model.cache_mut()
    }
    fn device(&self) -> &Device {
        self.text_model.device()
    }
    fn max_seq_len(&self) -> usize {
        self.text_model.max_seq_len()
    }
    fn has_conv2d(&self) -> bool {
        true
    }
    fn config(&self) -> &ModelConfigMetadata {
        self.text_model.config()
    }
    fn default_model_specific_args(&self, _input_ids: &Tensor) -> Box<dyn Any> {
        let args: Option<Tensor> = None;
        Box::new(args)
    }
}