1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3pub(crate) mod idefics2_input_processor;
4
5use candle_core::{DType, Device, IndexOp, Result, Tensor, D};
6use candle_nn::{Conv2d, Conv2dConfig, Embedding, LayerNorm, Module};
7use mistralrs_quant::ShardedVarBuilder;
8use serde::Deserialize;
9use std::{any::Any, ops::Mul};
10
11use crate::{
12 amoe::{AnyMoeBaseModelMixin, MlpLayer},
13 device_map::DeviceMapper,
14 layers::{
15 conv2d, embedding, layer_norm, linear, linear_no_bias, repeat_kv, Activation, CausalMasker,
16 MatMul, QLinear, RmsNorm,
17 },
18 models::mistral::Model as Mistral,
19 paged_attention::{AttentionImplementation, ModelConfigMetadata},
20 pipeline::{
21 text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
22 EitherCache, IsqModel, NormalLoadingMetadata, NormalModel, VisionModel,
23 },
24 utils::unvarbuilder::UnVarBuilder,
25 AnyMoeConfig, AnyMoeExpertType,
26};
27
28use crate::models::mistral;
29
// Serde default-value providers for the config structs below.
// `#[serde(default = "...")]` requires a *named function* per default value,
// which is why each constant gets its own tiny helper. The values mirror the
// upstream Idefics2 / Mistral-7B configuration defaults.
fn default_32000() -> usize {
    32000
}
fn default_32001() -> usize {
    32001
}
fn default_4096() -> usize {
    4096
}
fn default_14336() -> usize {
    14336
}
fn default_32() -> usize {
    32
}
fn default_8() -> usize {
    8
}
fn default_act() -> Activation {
    Activation::Silu
}
fn default_131072() -> usize {
    131072
}
fn default_eps() -> f64 {
    1e-6
}
fn default_rope() -> f64 {
    10000.0
}
fn default_false() -> bool {
    false
}
fn default_sliding() -> Option<usize> {
    Some(4096)
}
fn default_gelu() -> Activation {
    Activation::GeluPytorchTanh
}
fn default_64() -> usize {
    64
}
fn default_3() -> usize {
    3
}
fn default_16() -> usize {
    16
}
fn default_96() -> usize {
    96
}
fn default_4() -> usize {
    4
}
fn default_0_0() -> f32 {
    0.0
}
fn default_0_02() -> f32 {
    0.02
}
fn default_768() -> usize {
    768
}
fn default_3072() -> usize {
    3072
}
fn default_12() -> usize {
    12
}
fn default_224() -> usize {
    224
}
104
/// Configuration for the perceiver resampler that compresses the vision
/// tower's patch sequence into a fixed number of latent tokens.
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct PerceiverConfig {
    /// Activation used inside the resampler's gated MLPs.
    #[serde(default = "default_act")]
    pub hidden_act: Activation,
    /// Number of learned latent vectors (the resampler's output length).
    #[serde(default = "default_64")]
    pub resampler_n_latents: usize,
    /// Number of stacked perceiver layers.
    #[serde(default = "default_3")]
    pub resampler_depth: usize,
    /// Number of query heads in the resampler cross-attention.
    #[serde(default = "default_16")]
    pub resampler_n_heads: usize,
    /// Per-head dimension in the resampler cross-attention.
    #[serde(default = "default_96")]
    pub resampler_head_dim: usize,
    /// Number of key/value heads (grouped-query attention).
    #[serde(default = "default_4")]
    pub num_key_value_heads: usize,
    // Parsed from checkpoints; not read anywhere in this file.
    #[serde(default = "default_0_0")]
    pub attention_dropout: f32,
}
122
123#[derive(Debug, Clone, PartialEq, Deserialize)]
124pub struct VisionConfig {
125 #[serde(default = "default_768")]
126 pub hidden_size: usize,
127 #[serde(default = "default_3072")]
128 pub intermediate_size: usize,
129 #[serde(default = "default_12")]
130 pub num_hidden_layers: usize,
131 #[serde(default = "default_12")]
132 pub num_attention_heads: usize,
133 #[serde(default = "default_3")]
134 pub num_channels: usize,
135 #[serde(default = "default_224")]
136 pub image_size: usize,
137 #[serde(default = "default_32")]
138 pub patch_size: usize,
139 #[serde(default = "default_gelu")]
140 pub hidden_act: Activation,
141 #[serde(default = "default_eps")]
142 pub layer_norm_eps: f64,
143 #[serde(default = "default_0_0")]
144 pub attn_dropout: f32,
145 #[serde(default = "default_0_02")]
146 pub initiailizer_range: f32,
147}
148
149#[derive(Debug, Clone, PartialEq, Deserialize)]
150pub(crate) struct TextConfig {
151 #[serde(default = "default_32000")]
152 pub(crate) vocab_size: usize,
153 #[serde(default = "default_4096")]
154 pub(crate) hidden_size: usize,
155 #[serde(default = "default_14336")]
156 pub(crate) intermediate_size: usize,
157 #[serde(default = "default_32")]
158 pub(crate) num_hidden_layers: usize,
159 #[serde(default = "default_32")]
160 pub(crate) num_attention_heads: usize,
161 #[serde(default = "default_8")]
162 pub(crate) num_key_value_heads: usize,
163 #[serde(default = "default_act")]
164 pub(crate) hidden_act: Activation,
165 #[serde(default = "default_131072")]
166 pub(crate) max_position_embeddings: usize,
167 #[serde(default = "default_eps")]
168 pub(crate) rms_norm_eps: f64,
169 #[serde(default = "default_rope")]
170 pub(crate) rope_theta: f64,
171 #[serde(default = "default_sliding")]
172 pub(crate) sliding_window: Option<usize>,
173
174 #[serde(default = "default_false")]
175 pub(crate) use_flash_attn: bool,
176 model_type: String, }
178
179impl From<TextConfig> for mistral::Config {
180 fn from(val: TextConfig) -> Self {
181 mistral::Config {
182 vocab_size: val.vocab_size,
183 hidden_act: val.hidden_act,
184 hidden_size: val.hidden_size,
185 intermediate_size: val.intermediate_size,
186 num_hidden_layers: val.num_hidden_layers,
187 num_attention_heads: val.num_attention_heads,
188 num_key_value_heads: val.num_key_value_heads,
189 max_position_embeddings: val.max_position_embeddings,
190 rms_norm_eps: val.rms_norm_eps,
191 rope_theta: val.rope_theta,
192 sliding_window: val.sliding_window,
193 use_flash_attn: val.use_flash_attn,
194 head_dim: None,
195 quantization_config: None,
196 tie_word_embeddings: false,
197 }
198 }
199}
200
/// Top-level Idefics2 configuration: vision tower + perceiver resampler
/// + Mistral text model.
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub(crate) struct Config {
    pub perceiver_config: PerceiverConfig,
    pub vision_config: VisionConfig,
    pub(crate) text_config: TextConfig,
    /// Token id that marks image-placeholder positions in the input ids
    /// (compared against in `inputs_merger`).
    #[serde(default = "default_32001")]
    pub image_token_id: usize,
    // NOTE(review): parsed but not visibly consumed in this file — the
    // text-model conversion hard-codes `tie_word_embeddings: false`.
    #[serde(default = "default_false")]
    pub tie_word_embeddings: bool,
}
211
/// Patch + position embeddings for the vision tower, with resolution-aware
/// position ids computed by bucketizing fractional patch coordinates.
struct VisionEmbeddings {
    patch_size: usize,
    /// Conv2d whose stride equals the patch size: one output per patch.
    patch_embedding: Conv2d,
    /// `image_size / patch_size`; the position grid has this many buckets
    /// along each axis.
    num_patches_per_side: usize,
    position_embedding: Embedding,
}
220
221fn bucketize_right(xs: &[f32], boundaries: &[f32], device: &Device) -> Result<Tensor> {
224 use std::cmp::Ordering;
225
226 let mut result = Vec::with_capacity(xs.len());
227
228 for &x in xs {
229 let idx = match boundaries.binary_search_by(|&val| {
238 val.partial_cmp(&x).unwrap_or(Ordering::Less)
241 }) {
242 Ok(i) => i,
243 Err(i) => i,
244 };
245
246 result.push(idx as u32);
247 }
248
249 Tensor::from_vec(result, (xs.len(),), device)
250}
251
impl VisionEmbeddings {
    /// Builds the patch conv embedding and the learned position table.
    ///
    /// The conv stride equals the patch size, so each output location is the
    /// embedding of one non-overlapping patch; the position table holds one
    /// row per cell of the full `num_patches_per_side^2` grid.
    fn new(config: &VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        let conv_config = Conv2dConfig {
            stride: config.patch_size,
            ..Default::default()
        };
        let patch_embedding = conv2d(
            config.num_channels,
            config.hidden_size,
            config.patch_size,
            conv_config,
            vb.pp("patch_embedding"),
        )?;
        let num_patches_per_side = config.image_size / config.patch_size;
        let num_patches = num_patches_per_side.pow(2);
        Ok(Self {
            patch_size: config.patch_size,
            patch_embedding,
            num_patches_per_side,
            position_embedding: embedding(
                num_patches,
                config.hidden_size,
                vb.pp("position_embedding"),
            )?,
        })
    }

    /// Embeds `pixel_values` (bs, channels, H, W) into a flattened patch
    /// sequence and adds position embeddings that account for per-image
    /// padding.
    ///
    /// `patch_attention_mask` has shape (bs, H/patch, W/patch); nonzero
    /// entries mark real (non-padded) patches. For each image, the fraction
    /// of patch rows/columns actually used is bucketized onto the full
    /// position grid, so images of different effective resolution share the
    /// same embedding table.
    fn forward(&self, pixel_values: &Tensor, patch_attention_mask: &Tensor) -> Result<Tensor> {
        let (bs, _, max_im_h, max_im_w) = pixel_values.dims4()?;

        let patch_embeds = self.patch_embedding.forward(pixel_values)?;

        // (bs, hidden, gh, gw) -> (bs, gh*gw, hidden)
        let embeddings = patch_embeds.flatten(2, D::Minus1)?.transpose(1, 2)?;

        let (max_nb_patches_h, max_nb_patches_w) =
            (max_im_h / self.patch_size, max_im_w / self.patch_size);
        // Grid-cell boundaries in fractional [0, 1) coordinates.
        let boundaries = Tensor::arange_step(
            1.0 / self.num_patches_per_side as f32,
            1.0,
            1.0 / self.num_patches_per_side as f32,
            pixel_values.device(),
        )?
        .to_vec1::<f32>()?;
        // Position ids default to 0 for padded patch slots.
        let position_ids = Tensor::full(
            0u32,
            (bs, max_nb_patches_h * max_nb_patches_w),
            pixel_values.device(),
        )?;

        let mut new_position_ids = Vec::new();
        for (b_idx, p_attn_mask) in patch_attention_mask.chunk(bs, 0)?.iter().enumerate() {
            let p_attn_mask = p_attn_mask.squeeze(0)?;
            // Effective patch rows = sum of the mask's first column,
            // effective patch cols = sum of its first row.
            let nb_patches_h = p_attn_mask.i((.., 0))?.sum_all()?;
            let nb_patches_w = p_attn_mask.i((0,))?.sum_all()?;

            // Fractional coordinates of this image's patches within [0, 1).
            let fractional_coords_h = Tensor::arange_step(
                0.0,
                1.0 - 1e-6,
                1.0 / nb_patches_h.to_dtype(DType::F32)?.to_scalar::<f32>()?,
                pixel_values.device(),
            )?
            .to_vec1::<f32>()?;
            let fractional_coords_w = Tensor::arange_step(
                0.0,
                1.0 - 1e-6,
                1.0 / nb_patches_w.to_dtype(DType::F32)?.to_scalar::<f32>()?,
                pixel_values.device(),
            )?
            .to_vec1::<f32>()?;

            let bucket_coords_h =
                bucketize_right(&fractional_coords_h, &boundaries, pixel_values.device())?;
            let bucket_coords_w =
                bucketize_right(&fractional_coords_w, &boundaries, pixel_values.device())?;

            // Row-major position id: row_bucket * side + col_bucket.
            let pos_ids = bucket_coords_h
                .unsqueeze(D::Minus1)?
                .mul(self.num_patches_per_side as f64)?
                .broadcast_add(&bucket_coords_w)?
                .flatten_all()?
                .to_vec1::<u32>()?;

            // Scatter pos_ids (in order) into the slots where the mask is set.
            let true_indices = p_attn_mask
                .flatten_all()?
                .to_vec1::<u8>()?
                .iter()
                .enumerate()
                .filter_map(|(i, x)| if *x != 0 { Some(i) } else { None })
                .collect::<Vec<_>>();
            let position_ids_b = position_ids.i(b_idx)?;

            let mut new_position_ids_b = position_ids_b.to_vec1::<u32>()?;
            let new_position_ids_b_len = new_position_ids_b.len();
            for (i, true_idx) in true_indices.into_iter().enumerate() {
                new_position_ids_b[true_idx] = pos_ids[i];
            }

            new_position_ids.push(Tensor::from_vec(
                new_position_ids_b,
                new_position_ids_b_len,
                pixel_values.device(),
            )?);
        }
        let position_ids = Tensor::stack(&new_position_ids, 0)?;
        let position_ids = position_ids.to_device(self.position_embedding.embeddings().device())?;
        embeddings.broadcast_add(&self.position_embedding.forward(&position_ids)?)
    }

    /// Exports this module's tensors for serialization.
    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("patch_embedding").add(&self.patch_embedding);
        uvb.pp("position_embedding").add(&self.position_embedding);

        uvb.to_safetensors()
    }
}
369
/// Multi-head self-attention for the vision encoder (no KV grouping:
/// queries, keys and values all use `num_heads`).
struct Attention {
    embed_dim: usize,
    num_heads: usize,
    head_dim: usize,
    /// 1 / sqrt(head_dim), applied to the attention logits.
    scale: f64,
    q_proj: QLinear,
    k_proj: QLinear,
    v_proj: QLinear,
    o_proj: QLinear,
    /// -inf constant in the model dtype, substituted for masked logits.
    neg_inf: Tensor,
}
381
impl Attention {
    /// Builds the q/k/v/out projections (all square, `embed_dim` ->
    /// `embed_dim`, with bias) and caches the 1/sqrt(head_dim) logit scale.
    /// Note the checkpoint names the output projection `out_proj`.
    fn new(config: VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        let embed_dim = config.hidden_size;
        let num_heads = config.num_attention_heads;
        let head_dim = embed_dim / num_heads;
        let scale = 1.0 / (head_dim as f64).sqrt();

        let q_proj = linear(embed_dim, embed_dim, vb.pp("q_proj"))?;
        let k_proj = linear(embed_dim, embed_dim, vb.pp("k_proj"))?;
        let v_proj = linear(embed_dim, embed_dim, vb.pp("v_proj"))?;
        let o_proj = linear(embed_dim, embed_dim, vb.pp("out_proj"))?;

        Ok(Self {
            embed_dim,
            num_heads,
            head_dim,
            scale,
            q_proj: QLinear::from_linear(q_proj),
            k_proj: QLinear::from_linear(k_proj),
            v_proj: QLinear::from_linear(v_proj),
            o_proj: QLinear::from_linear(o_proj),
            neg_inf: Tensor::new(f32::NEG_INFINITY, vb.device())?.to_dtype(vb.dtype())?,
        })
    }

    /// Scaled-dot-product self-attention over `xs` (b, seq, embed_dim).
    ///
    /// When the projections are quantized, they run in f32 and the results
    /// are cast back to the input dtype. `attention_mask`, if present, is
    /// used as a one/zero mask: zero positions receive -inf logits before
    /// the softmax.
    fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if self.q_proj.is_quant() {
            xs = xs.to_dtype(DType::F32)?;
        }
        let mut q = self.q_proj.forward(&xs)?;
        let mut k = self.k_proj.forward(&xs)?;
        let mut v = self.v_proj.forward(&xs)?;
        if self.q_proj.is_quant() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

        // (b, seq, embed) -> (b, heads, seq, head_dim)
        let q = q
            .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;
        let k = k
            .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;
        let v = v
            .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;

        let attn_weights =
            (MatMul.matmul(&q.contiguous()?, &k.transpose(2, 3)?.contiguous()?)? * self.scale)?;

        let attn_weights = CausalMasker.apply_mask_one_and_zero(
            &attention_mask.map(|x| x.to_dtype(DType::U8).unwrap()),
            attn_weights,
            &self.neg_inf,
        )?;
        let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
        let mut attn_output = MatMul.matmul(&attn_weights, &v.contiguous()?)?;

        if self.q_proj.is_quant() {
            attn_output = attn_output.to_dtype(DType::F32)?;
        }
        // (b, heads, seq, head_dim) -> (b, seq, embed), then output proj.
        let mut res = attn_output
            .transpose(1, 2)?
            .reshape((b_sz, q_len, self.embed_dim))?
            .apply(&self.o_proj)?;
        if self.q_proj.is_quant() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }

    /// Exports projections under the checkpoint's `out_proj` naming.
    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("q_proj").add(&self.q_proj);
        uvb.pp("k_proj").add(&self.k_proj);
        uvb.pp("v_proj").add(&self.v_proj);
        uvb.pp("out_proj").add(&self.o_proj);

        uvb.to_safetensors()
    }
}
469
/// Two-layer feed-forward block of the vision encoder: fc1 -> act -> fc2.
struct VisionMLP {
    activation: Activation,
    fc1: QLinear,
    fc2: QLinear,
}
475
476impl VisionMLP {
477 fn new(config: VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
478 let fc1 = linear(config.hidden_size, config.intermediate_size, vb.pp("fc1"))?;
479 let fc2 = linear(config.intermediate_size, config.hidden_size, vb.pp("fc2"))?;
480 Ok(Self {
481 activation: config.hidden_act,
482 fc1: QLinear::from_linear(fc1),
483 fc2: QLinear::from_linear(fc2),
484 })
485 }
486
487 fn forward(&self, x: &Tensor) -> Result<Tensor> {
488 let mut x = x.clone();
489 let original_dtype = x.dtype();
490 if self.fc1.is_quant() {
491 x = x.to_dtype(DType::F32)?;
492 }
493 let x = self.fc1.forward(&x)?;
494 let x = self.activation.forward(&x)?;
495 let mut res = self.fc2.forward(&x)?;
496 if self.fc1.is_quant() {
497 res = res.to_dtype(original_dtype)?;
498 }
499 Ok(res)
500 }
501
502 fn residual_tensors(&self) -> Vec<(String, Tensor)> {
503 let uvb = UnVarBuilder::new();
504
505 uvb.pp("fc1").add(&self.fc1);
506 uvb.pp("fc2").add(&self.fc2);
507
508 uvb.to_safetensors()
509 }
510}
511
/// One pre-norm transformer block of the vision encoder:
/// `x + attn(ln1(x))`, then `+ mlp(ln2(..))`.
struct EncoderLayer {
    mlp: VisionMLP,
    attn: Attention,
    layer_norm_1: LayerNorm,
    layer_norm_2: LayerNorm,
}
518
519impl EncoderLayer {
520 fn new(config: VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
521 let mlp = VisionMLP::new(config.clone(), vb.pp("mlp"))?;
522 let attn = Attention::new(config.clone(), vb.pp("self_attn"))?;
523 let layer_norm_1 = layer_norm(
524 config.hidden_size,
525 config.layer_norm_eps,
526 vb.pp("layer_norm1"),
527 )?;
528 let layer_norm_2 = layer_norm(
529 config.hidden_size,
530 config.layer_norm_eps,
531 vb.pp("layer_norm2"),
532 )?;
533 Ok(Self {
534 mlp,
535 attn,
536 layer_norm_1,
537 layer_norm_2,
538 })
539 }
540
541 fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
542 let residual = xs.clone();
543
544 let hidden_states = self.layer_norm_1.forward(xs)?;
545 let hidden_states = self.attn.forward(&hidden_states, attention_mask)?;
546 let hidden_states = (hidden_states + residual)?;
547
548 let residual = &hidden_states;
549 let hidden_states = self.layer_norm_2.forward(&hidden_states)?;
550 let hidden_states = self.mlp.forward(&hidden_states)?;
551 hidden_states + residual
552 }
553}
554
/// Stack of vision encoder layers, applied sequentially.
struct Encoder {
    layers: Vec<EncoderLayer>,
}
558
559impl Encoder {
560 fn new(config: &VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
561 let mut layers = Vec::new();
562 let vb_l = vb.pp("layers");
563 for i in 0..config.num_hidden_layers {
564 layers.push(EncoderLayer::new(config.clone(), vb_l.pp(i))?);
565 }
566 Ok(Self { layers })
567 }
568
569 fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
570 let mut hidden_states = xs.clone();
571 for layer in &self.layers {
572 hidden_states = layer.forward(&hidden_states, attention_mask)?;
573 }
574 Ok(hidden_states)
575 }
576}
577
/// Full vision tower: patch embeddings -> encoder stack -> final layer norm.
struct VisionTransformer {
    embeddings: VisionEmbeddings,
    encoder: Encoder,
    post_layernorm: LayerNorm,
    config: VisionConfig,
}
584
impl VisionTransformer {
    /// Builds embeddings, encoder stack and final layer norm from `vb`.
    fn new(config: &VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        let embeddings = VisionEmbeddings::new(config, vb.pp("embeddings"))?;
        let post_layernorm = layer_norm(
            config.hidden_size,
            config.layer_norm_eps,
            vb.pp("post_layernorm"),
        )?;
        let encoder = Encoder::new(config, vb.pp("encoder"))?;
        Ok(Self {
            embeddings,
            encoder,
            post_layernorm,
            config: config.clone(),
        })
    }

    /// Runs the tower on `pixel_values` (bs, c, h, w).
    ///
    /// When `attention_mask` is `None`, an all-ones patch mask is synthesized
    /// (every patch is real) and the encoder runs unmasked; otherwise the
    /// supplied patch mask is flattened and expanded into a per-head mask for
    /// the encoder's attention layers.
    fn forward(&self, pixel_values: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
        let bs = pixel_values.dim(0)?;
        let patch_attention_mask = if let Some(attn_mask) = attention_mask {
            attn_mask.clone()
        } else {
            let patch_size = self.config.patch_size;
            Tensor::ones(
                (
                    bs,
                    pixel_values.dim(2)? / patch_size,
                    pixel_values.dim(3)? / patch_size,
                ),
                DType::U8,
                pixel_values.device(),
            )?
        };

        let hidden_states = self
            .embeddings
            .forward(pixel_values, &patch_attention_mask)?;

        // Only build an encoder mask when the caller provided one; a
        // synthesized all-ones mask is equivalent to no mask at all.
        let attention_mask = if attention_mask.is_none() {
            None
        } else {
            let mask = patch_attention_mask
                .reshape((patch_attention_mask.dim(0)?, ()))?
                .to_dtype(hidden_states.dtype())?;
            Some(CausalMasker.expand_mask(&mask, hidden_states.dtype(), None)?)
        };
        let hidden_states = self
            .encoder
            .forward(&hidden_states, attention_mask.as_ref())?;
        hidden_states.apply(&self.post_layernorm)
    }

    /// Exports this module's tensors for serialization.
    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("post_layernorm").add(&self.post_layernorm);
        uvb.pp("embeddings")
            .extend(self.embeddings.residual_tensors());

        let uvb_enc = uvb.pp("encoder");
        for (i, layer) in self.encoder.layers.iter().enumerate() {
            let uvb_l = uvb_enc.pp("layers").pp(i);

            uvb_l.pp("layer_norm1").add(&layer.layer_norm_1);
            uvb_l.pp("layer_norm2").add(&layer.layer_norm_2);
            uvb_l.pp("mlp").extend(layer.mlp.residual_tensors());
            uvb_l.pp("self_attn").extend(layer.attn.residual_tensors());
        }

        uvb.to_safetensors()
    }
}
657
/// Gated MLP (SwiGLU-style): `down(act(gate(x)) * up(x))`, no biases.
struct Mlp {
    gate_proj: QLinear,
    up_proj: QLinear,
    down_proj: QLinear,
    activation: Activation,
}
667
668impl Mlp {
669 fn new(
670 hidden_size: usize,
671 intermediate_size: usize,
672 output_size: usize,
673 activation: Activation,
674 vb: ShardedVarBuilder,
675 ) -> Result<Self> {
676 let gate_proj = linear_no_bias(hidden_size, intermediate_size, vb.pp("gate_proj"))?;
677 let up_proj = linear_no_bias(hidden_size, intermediate_size, vb.pp("up_proj"))?;
678 let down_proj = linear_no_bias(intermediate_size, output_size, vb.pp("down_proj"))?;
679 Ok(Self {
680 gate_proj: QLinear::from_linear(gate_proj),
681 up_proj: QLinear::from_linear(up_proj),
682 down_proj: QLinear::from_linear(down_proj),
683 activation,
684 })
685 }
686
687 fn forward(&self, x: &Tensor) -> Result<Tensor> {
688 let mut x = x.clone();
689 let original_dtype = x.dtype();
690 if self.gate_proj.is_quant() {
691 x = x.to_dtype(DType::F32)?;
692 }
693 let mut res = self.down_proj.forward(
694 &(self.activation.forward(&self.gate_proj.forward(&x)?)?
695 * self.up_proj.forward(&x)?)?,
696 )?;
697 if self.gate_proj.is_quant() {
698 res = res.to_dtype(original_dtype)?;
699 }
700 Ok(res)
701 }
702
703 fn residual_tensors(&self) -> Vec<(String, Tensor)> {
704 let uvb = UnVarBuilder::new();
705
706 uvb.pp("gate_proj").add(&self.gate_proj);
707 uvb.pp("up_proj").add(&self.up_proj);
708 uvb.pp("down_proj").add(&self.down_proj);
709
710 uvb.to_safetensors()
711 }
712}
713
/// Grouped-query cross-attention used by the perceiver resampler: latent
/// queries attend over the concatenation of context and latents.
struct PerceiverAttention {
    num_heads: usize,
    num_kv_heads: usize,
    /// num_heads / num_kv_heads; KV tensors are repeated by this factor.
    num_kv_groups: usize,
    head_dim: usize,
    q_proj: QLinear,
    k_proj: QLinear,
    v_proj: QLinear,
    o_proj: QLinear,
    /// -inf constant in the model dtype, substituted for masked logits.
    neg_inf: Tensor,
}
725
impl PerceiverAttention {
    /// Builds the cross-attention projections: queries use
    /// `num_heads * head_dim`, keys/values use the smaller
    /// `num_key_value_heads * head_dim` (grouped-query attention).
    fn new(config: &Config, vb: ShardedVarBuilder) -> Result<Self> {
        let hidden_size = config.text_config.hidden_size;
        let num_heads = config.perceiver_config.resampler_n_heads;
        let head_dim = config.perceiver_config.resampler_head_dim;
        let num_key_value_heads = config.perceiver_config.num_key_value_heads;
        let num_key_value_groups = num_heads / num_key_value_heads;

        let q_proj = linear_no_bias(hidden_size, num_heads * head_dim, vb.pp("q_proj"))?;
        let k_proj = linear_no_bias(hidden_size, num_key_value_heads * head_dim, vb.pp("k_proj"))?;
        let v_proj = linear_no_bias(hidden_size, num_key_value_heads * head_dim, vb.pp("v_proj"))?;
        let o_proj = linear_no_bias(num_heads * head_dim, hidden_size, vb.pp("o_proj"))?;

        Ok(Self {
            num_heads,
            head_dim,
            q_proj: QLinear::from_linear(q_proj),
            k_proj: QLinear::from_linear(k_proj),
            v_proj: QLinear::from_linear(v_proj),
            o_proj: QLinear::from_linear(o_proj),
            neg_inf: Tensor::new(f32::NEG_INFINITY, vb.device())?.to_dtype(vb.dtype())?,
            num_kv_heads: num_key_value_heads,
            num_kv_groups: num_key_value_groups,
        })
    }

    /// Latents (b, n_latents, hidden) attend over `[context; latents]`.
    ///
    /// `attention_mask` is a one/zero mask applied to the logits (zeros
    /// become -inf). Quantized projections run in f32 with a cast back to the
    /// input dtype afterwards.
    fn forward(
        &self,
        latents: &Tensor,
        context: &Tensor,
        attention_mask: &Tensor,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = latents.dims3()?;
        // Keys/values cover the context followed by the latents themselves.
        let kv_seq_len = q_len + context.dims()[1];

        let mut hidden_states = Tensor::cat(&[context, latents], D::Minus2)?;

        let original_dtype = latents.dtype();
        let mut latents = latents.clone();
        if self.q_proj.is_quant() {
            latents = latents.to_dtype(DType::F32)?;
            hidden_states = hidden_states.to_dtype(DType::F32)?;
        }
        let mut q = self.q_proj.forward(&latents)?;
        let mut k = self.k_proj.forward(&hidden_states)?;
        let mut v = self.v_proj.forward(&hidden_states)?;
        if self.q_proj.is_quant() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

        // Split heads: queries get num_heads, keys/values get num_kv_heads.
        let q = q
            .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;
        let k = k
            .reshape((b_sz, kv_seq_len, self.num_kv_heads, self.head_dim))?
            .transpose(1, 2)?;
        let v = v
            .reshape((b_sz, kv_seq_len, self.num_kv_heads, self.head_dim))?
            .transpose(1, 2)?;

        // Repeat KV heads so each query-head group sees its own copy.
        let k = repeat_kv(k, self.num_kv_groups)?.contiguous()?;
        let v = repeat_kv(v, self.num_kv_groups)?.contiguous()?;

        let attn_weights = (MatMul.matmul(&q.contiguous()?, &k.transpose(2, 3)?.contiguous()?)?
            / (self.head_dim as f64).sqrt())?;

        let attn_weights = CausalMasker.apply_mask_one_and_zero(
            &Some(attention_mask.to_dtype(DType::U8)?),
            attn_weights,
            &self.neg_inf,
        )?;
        let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
        let mut attn_output = MatMul.matmul(&attn_weights, &v.contiguous()?)?;

        if self.q_proj.is_quant() {
            attn_output = attn_output.to_dtype(DType::F32)?;
        }
        // Merge heads back and apply the output projection.
        let mut res = attn_output
            .transpose(1, 2)?
            .reshape((b_sz, q_len, self.num_heads * self.head_dim))?
            .apply(&self.o_proj)?;
        if self.q_proj.is_quant() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }

    /// Exports this module's tensors for serialization.
    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("q_proj").add(&self.q_proj);
        uvb.pp("k_proj").add(&self.k_proj);
        uvb.pp("v_proj").add(&self.v_proj);
        uvb.pp("o_proj").add(&self.o_proj);

        uvb.to_safetensors()
    }
}
826
/// One perceiver resampler block: RMS-normed cross-attention residual
/// followed by an RMS-normed gated-MLP residual.
struct PerceiverLayer {
    input_latents_norm: RmsNorm,
    input_context_norm: RmsNorm,
    self_attn: PerceiverAttention,
    post_attn_norm: RmsNorm,
    mlp: Mlp,
}
834
835impl PerceiverLayer {
836 fn new(config: &Config, vb: ShardedVarBuilder) -> Result<Self> {
837 let hidden_size = config.text_config.hidden_size;
838 let mlp_act = config.perceiver_config.hidden_act;
839 let rms_eps = config.text_config.rms_norm_eps;
840
841 Ok(Self {
842 input_latents_norm: RmsNorm::new(hidden_size, rms_eps, vb.pp("input_latents_norm"))?,
843 input_context_norm: RmsNorm::new(hidden_size, rms_eps, vb.pp("input_context_norm"))?,
844 self_attn: PerceiverAttention::new(config, vb.pp("self_attn"))?,
845 post_attn_norm: RmsNorm::new(hidden_size, rms_eps, vb.pp("post_attention_layernorm"))?,
846 mlp: Mlp::new(
847 hidden_size,
848 hidden_size * 4,
849 hidden_size,
850 mlp_act,
851 vb.pp("mlp"),
852 )?,
853 })
854 }
855
856 fn forward(
857 &self,
858 latents: &Tensor,
859 context: &Tensor,
860 attention_mask: &Tensor,
861 ) -> Result<Tensor> {
862 let residual = latents;
863
864 let latents = self.input_latents_norm.forward(latents)?;
865 let context = self.input_context_norm.forward(context)?;
866
867 let latents = self.self_attn.forward(&latents, &context, attention_mask)?;
868 let latents = (residual + latents)?;
869 let residual = &latents;
870
871 let latents = self.post_attn_norm.forward(&latents)?;
872 let latents = self.mlp.forward(&latents)?;
873 residual + latents
874 }
875}
876
/// Perceiver resampler: a set of learned latents repeatedly cross-attends
/// to the vision context, compressing it to `n_latents` tokens.
struct PerceiverResampler {
    /// Learned latent queries, shape (n_latents, hidden).
    latents: Tensor,
    layers: Vec<PerceiverLayer>,
    norm: RmsNorm,
    n_latents: usize,
}
883
impl PerceiverResampler {
    /// Loads the learned latents, `resampler_depth` perceiver layers and the
    /// final RMS norm.
    fn new(config: &Config, vb: ShardedVarBuilder) -> Result<Self> {
        let n_latents = config.perceiver_config.resampler_n_latents;
        let hidden_size = config.text_config.hidden_size;
        let depth = config.perceiver_config.resampler_depth;

        let latents = vb.get((n_latents, hidden_size), "latents")?;
        let mut layers = Vec::new();
        let vb_l = vb.pp("layers");
        for i in 0..depth {
            layers.push(PerceiverLayer::new(config, vb_l.pp(i))?);
        }
        let norm = RmsNorm::new(hidden_size, config.text_config.rms_norm_eps, vb.pp("norm"))?;
        Ok(Self {
            latents,
            layers,
            norm,
            n_latents,
        })
    }

    /// Compresses `context` (b, seq, hidden) to (b, n_latents, hidden).
    ///
    /// The learned latents are broadcast across the batch; the context mask
    /// is extended with ones for the latent slots (latents always attend to
    /// each other) and then expanded for the attention layers.
    fn forward(&self, context: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
        // Broadcast latents from (n_latents, hidden) to (b, n_latents, hidden).
        let mut s = vec![context.dim(0)?];
        s.extend(self.latents.dims());
        let latents = self.latents.unsqueeze(0)?.expand(s)?;

        let latent_attention_mask = Tensor::ones(
            (attention_mask.dim(0)?, latents.dim(1)?),
            attention_mask.dtype(),
            attention_mask.device(),
        )?;
        let attention_mask = Tensor::cat(&[attention_mask, &latent_attention_mask], D::Minus1)?;
        let attention_mask =
            CausalMasker.expand_mask(&attention_mask, latents.dtype(), Some(self.n_latents))?;

        let mut compressed_context = latents;
        for perceiver_layer in &self.layers {
            compressed_context =
                perceiver_layer.forward(&compressed_context, context, &attention_mask)?;
        }
        self.norm.forward(&compressed_context)
    }

    /// Exports this module's tensors for serialization.
    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("norm").add(&self.norm);
        uvb.add_tensor("latents", self.latents.clone());

        for (i, layer) in self.layers.iter().enumerate() {
            let uvb_l = uvb.pp("layers").pp(i);

            uvb_l
                .pp("input_latents_norm")
                .add(&layer.input_latents_norm);
            uvb_l
                .pp("input_context_norm")
                .add(&layer.input_context_norm);
            uvb_l
                .pp("post_attention_layernorm")
                .add(&layer.post_attn_norm);
            uvb_l.pp("mlp").extend(layer.mlp.residual_tensors());
            uvb_l
                .pp("self_attn")
                .extend(layer.self_attn.residual_tensors());
        }

        uvb.to_safetensors()
    }
}
954
/// Bridge between the vision tower and the text model: projects vision
/// features to the text hidden size, then resamples to a fixed token count.
struct Connector {
    modality_projection: Mlp,
    perceiver_resampler: PerceiverResampler,
}
959
960impl Connector {
961 fn new(config: &Config, vb: ShardedVarBuilder) -> Result<Self> {
962 let modality_projection = Mlp::new(
963 config.vision_config.hidden_size,
964 config.text_config.intermediate_size,
965 config.text_config.hidden_size,
966 config.text_config.hidden_act,
967 vb.pp("modality_projection"),
968 )?;
969 let perceiver_resampler = PerceiverResampler::new(config, vb.pp("perceiver_resampler"))?;
970 Ok(Self {
971 modality_projection,
972 perceiver_resampler,
973 })
974 }
975
976 fn forward(&self, image_hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
977 let image_hidden_states = self.modality_projection.forward(image_hidden_states)?;
978 self.perceiver_resampler
979 .forward(&image_hidden_states, attention_mask)
980 }
981
982 fn residual_tensors(&self) -> Vec<(String, Tensor)> {
983 let uvb = UnVarBuilder::new();
984
985 uvb.pp("modality_projection")
986 .extend(self.modality_projection.residual_tensors());
987 uvb.pp("perceiver_resampler")
988 .extend(self.perceiver_resampler.residual_tensors());
989
990 uvb.to_safetensors()
991 }
992}
993
/// Idefics2 vision-language model: vision tower -> connector (modality
/// projection + perceiver resampler) -> Mistral text model.
pub struct Idefics2 {
    vision_model: VisionTransformer,
    connector: Connector,
    text_model: Mistral,
    /// Model compute dtype; pixel values are cast to this before the tower.
    dtype: DType,
    config: Config,
}
1005
1006impl Idefics2 {
    /// Loads all submodels from `vb`.
    ///
    /// The text model lives under `model.text_model` (with `lm_head` at the
    /// top level); the vision model and connector are loaded onto the text
    /// model's device so device-mapped setups keep them colocated.
    pub fn new(
        config: &Config,
        vb: ShardedVarBuilder,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        let vb_m = vb.pp("model");
        let text_model = Mistral::new_inner(
            &config.text_config.clone().into(),
            vb_m.pp("text_model"),
            vb.pp("lm_head"),
            is_gptx,
            normal_loading_metadata,
            attention_mechanism,
        )?;
        let vision_model = VisionTransformer::new(
            &config.vision_config,
            vb_m.pp("vision_model")
                .set_device(text_model.device().clone()),
        )?;
        let connector = Connector::new(
            config,
            vb_m.pp("connector").set_device(text_model.device().clone()),
        )?;
        Ok(Self {
            vision_model,
            connector,
            text_model,
            dtype: vb.dtype(),
            config: config.clone(),
        })
    }
1040
    /// Replaces the embedding of every image-placeholder token in
    /// `input_embeds` with the next row of `image_hidden_states`.
    ///
    /// Placeholders are located by comparing `input_ids` against
    /// `config.image_token_id` and consumed in order. Only batch size 1 is
    /// supported (asserted below).
    fn inputs_merger(
        &self,
        input_ids: &Tensor,
        input_embeds: &Tensor,
        image_hidden_states: &Tensor,
    ) -> Result<Tensor> {
        let (_, _, vision_hidden_size) = image_hidden_states.dims3()?;
        let bs = input_ids.dim(0)?;
        let special_image_token_mask = input_ids.eq(self.config.image_token_id as f64)?;
        let mut new_inputs_embeds = input_embeds.clone();
        let reshaped_image_hidden_states =
            image_hidden_states.reshape((bs, (), vision_hidden_size))?;
        // The scatter below indexes batch 0 only; guard against larger batches.
        assert_eq!(input_embeds.dim(0)?, 1);
        assert_eq!(reshaped_image_hidden_states.dim(0)?, 1);
        let special_image_token_mask = special_image_token_mask.i(0)?.to_vec1::<u8>()?;
        let mut image_hidden_state_i = 0;
        for (i, v) in special_image_token_mask.iter().enumerate() {
            if *v != 0 {
                // Overwrite the placeholder position with the next image token.
                new_inputs_embeds = new_inputs_embeds.slice_assign(
                    &[&.., &i, &..],
                    &reshaped_image_hidden_states
                        .i((.., image_hidden_state_i, ..))?
                        .unsqueeze(1)?,
                )?;
                image_hidden_state_i += 1;
            }
        }
        Ok(new_inputs_embeds)
    }
1080
    /// Full multimodal forward pass.
    ///
    /// When `pixel_values` is present (prompt phase only), runs the vision
    /// tower and connector, then splices the resulting image hidden states
    /// into the text embeddings via `inputs_merger`. Otherwise just embeds
    /// `input_ids`. The (possibly merged) embeddings are fed to the text model.
    ///
    /// * `pixel_values` - optional 5-D image batch; the first two dims are
    ///   (batch, num_images), the last three are assumed to be (C, H, W) —
    ///   TODO confirm against the input processor.
    /// * `pixel_attention_mask` - optional per-pixel validity mask; an
    ///   all-ones mask is substituted when absent.
    /// * `metadata` / `flash_params` - paged-/flash-attention state forwarded
    ///   untouched to the text model.
    ///
    /// # Errors
    /// Bails if `pixel_values` is supplied while the KV cache is non-empty
    /// (i.e. on a non-prompt/decode step).
    #[allow(clippy::too_many_arguments)]
    fn forward_inner(
        &self,
        input_ids: &Tensor,
        pixel_values: Option<Tensor>,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        pixel_attention_mask: Option<Tensor>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let input_embeds = if let Some(pixel_values) = pixel_values {
            let (batch_size, num_images, _, _, _) = pixel_values.dims5()?;
            // Collapse the batch and image dims: (batch * num_images, rest...).
            let mut s = vec![batch_size * num_images];
            s.extend(pixel_values.dims()[2..].to_vec());
            let pixel_values = pixel_values.reshape(s)?;

            // An image slot is padding iff every value in it is zero: count
            // zeros over the last three dims and compare to the total element
            // count per image. `real_images_inds` is true for non-padding.
            let nb_values_per_image = pixel_values.dims()[1..].iter().product::<usize>();
            let real_images_inds = pixel_values
                .eq(0.0f64)?
                .sum(vec![
                    pixel_values.dims().len() - 1,
                    pixel_values.dims().len() - 2,
                    pixel_values.dims().len() - 3,
                ])?
                .ne(nb_values_per_image as f64)?;
            // Keep only the real (non-padding) image slots.
            let mut batches = Vec::new();
            for (batch, use_it) in pixel_values
                .chunk(pixel_values.dim(0)?, 0)?
                .iter()
                .zip(real_images_inds.chunk(real_images_inds.dim(0)?, 0)?)
            {
                let use_it = use_it.squeeze(0)?.to_scalar::<u8>()? != 0;
                if use_it {
                    batches.push(batch.clone());
                }
            }
            let pixel_values = Tensor::cat(&batches, 0)?;

            // Filter the pixel-level attention mask the same way as the pixel
            // values; substitute an all-ones mask when the caller gave none.
            let pixel_attention_mask = if let Some(pixel_attention_mask) = pixel_attention_mask {
                let pixel_attention_mask = pixel_attention_mask.reshape((
                    batch_size * num_images,
                    pixel_attention_mask.dims()[2],
                    pixel_attention_mask.dims()[3],
                ))?;
                let mut batches = Vec::new();
                for (batch, use_it) in pixel_attention_mask
                    .chunk(pixel_attention_mask.dim(0)?, 0)?
                    .iter()
                    .zip(real_images_inds.chunk(real_images_inds.dim(0)?, 0)?)
                {
                    let use_it = use_it.squeeze(0)?.to_scalar::<u8>()? != 0;
                    if use_it {
                        batches.push(batch.clone());
                    }
                }
                Tensor::cat(&batches, 0)?
            } else {
                Tensor::ones(
                    (
                        pixel_values.dims()[0],
                        pixel_values.dims()[2],
                        pixel_values.dims()[3],
                    ),
                    DType::U8,
                    pixel_values.device(),
                )?
            };

            // Reduce the pixel mask to a per-patch mask: tile into
            // patch_size x patch_size windows; a patch is attended iff the
            // window sums to patch_size^2 (assuming a 0/1 mask).
            let patch_size = self.config.vision_config.patch_size;
            let patches_subgrid = pixel_attention_mask.unfold(1, patch_size, patch_size)?;
            let patches_subgrid = patches_subgrid.unfold(2, patch_size, patch_size)?;

            let patch_attention_mask = patches_subgrid
                .sum((D::Minus1, D::Minus2))?
                .eq((patch_size * patch_size) as f64)?
                .to_dtype(DType::U8)?;

            let pixel_values = pixel_values.to_dtype(self.dtype)?;

            let image_hidden_states = self
                .vision_model
                .forward(&pixel_values, Some(&patch_attention_mask))?;

            // Project vision hidden states into the text embedding space;
            // the patch mask is flattened to (num_images, num_patches).
            let image_hidden_states = self.connector.forward(
                &image_hidden_states,
                &patch_attention_mask.reshape((pixel_values.dim(0)?, ()))?,
            )?;

            // Images are only valid with the initial prompt (empty KV cache);
            // merge image embeddings into the text embeddings there.
            if self.text_model.cache.normal().0[0].current_seq_len() == 0 {
                self.inputs_merger(
                    input_ids,
                    &self.text_model.get_input_embeddings(input_ids)?,
                    &image_hidden_states,
                )?
            } else {
                candle_core::bail!("Pixel values were specified for a non-prompt.")
            }
        } else {
            // Text-only step: plain token embeddings.
            self.text_model.get_input_embeddings(input_ids)?
        };

        self.text_model.forward_embeds(
            input_ids,
            input_embeds,
            seqlen_offsets,
            context_lens,
            metadata,
            flash_params,
        )
    }
1197}
1198
1199impl IsqModel for Idefics2 {
1200 fn get_layers(
1201 &mut self,
1202 ) -> (
1203 Vec<(
1204 &mut std::sync::Arc<dyn mistralrs_quant::QuantMethod>,
1205 Option<usize>,
1206 )>,
1207 &dyn DeviceMapper,
1208 ) {
1209 self.text_model.get_layers()
1210 }
1211
1212 fn residual_tensors(&self) -> Vec<(String, Tensor)> {
1213 let uvb = UnVarBuilder::new();
1214
1215 let uvb_m = uvb.pp("model");
1216 uvb_m
1217 .pp("text_model")
1218 .extend(self.text_model.residual_tensors());
1219 uvb_m
1220 .pp("vision_model")
1221 .extend(self.vision_model.residual_tensors());
1222 uvb_m
1223 .pp("connector")
1224 .extend(self.connector.residual_tensors());
1225
1226 uvb.to_safetensors()
1227 }
1228}
1229
1230impl AnyMoeBaseModelMixin for Idefics2 {
1232 fn get_mlps(&self) -> Vec<&dyn MlpLayer> {
1233 self.text_model.get_mlps()
1234 }
1235 fn get_mlps_mut(&mut self) -> Vec<&mut Box<dyn MlpLayer>> {
1236 self.text_model.get_mlps_mut()
1237 }
1238 fn create_anymoe_layers(
1239 &mut self,
1240 additional_vbs: Vec<ShardedVarBuilder>,
1241 config: AnyMoeConfig,
1242 (prefix, mlp): (String, String),
1243 layers: Vec<usize>,
1244 expert_type: AnyMoeExpertType,
1245 gate_vb: Option<ShardedVarBuilder>,
1246 ) -> Result<()> {
1247 self.text_model.create_anymoe_layers(
1248 additional_vbs,
1249 config,
1250 (prefix, mlp),
1251 layers,
1252 expert_type,
1253 gate_vb,
1254 )
1255 }
1256 fn amoe_supported(&self) -> bool {
1257 true
1258 }
1259}
1260
1261impl VisionModel for Idefics2 {
1262 fn forward(
1263 &self,
1264 input_ids: &Tensor,
1265 pixel_values: Option<Tensor>,
1266 seqlen_offsets: &[usize],
1267 context_lens: Vec<(usize, usize)>,
1268 _: Vec<usize>, model_specific_args: Box<dyn Any>,
1270 metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
1271 flash_params: &FlashParams,
1272 ) -> candle_core::Result<Tensor> {
1273 let pixel_attention_mask: Option<Tensor> = *model_specific_args
1274 .downcast()
1275 .expect("Cannot downcast into `Option<Tensor>`");
1276 self.forward_inner(
1277 input_ids,
1278 pixel_values,
1279 seqlen_offsets,
1280 context_lens,
1281 pixel_attention_mask,
1282 metadata,
1283 flash_params,
1284 )
1285 }
1286 fn cache(&self) -> &EitherCache {
1287 self.text_model.cache()
1288 }
1289 fn cache_mut(&mut self) -> &mut EitherCache {
1290 self.text_model.cache_mut()
1291 }
1292 fn device(&self) -> &Device {
1293 self.text_model.device()
1294 }
1295 fn max_seq_len(&self) -> usize {
1296 self.text_model.max_seq_len()
1297 }
1298 fn has_conv2d(&self) -> bool {
1299 true
1300 }
1301 fn config(&self) -> &ModelConfigMetadata {
1302 self.text_model.config()
1303 }
1304 fn default_model_specific_args(&self, _input_ids: &Tensor) -> Box<dyn Any> {
1305 let args: Option<Tensor> = None;
1306 Box::new(args)
1307 }
1308}