use candle_core::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{Conv2d, Conv2dConfig, Embedding, LayerNorm, Linear, Module};
use mistralrs_quant::{QuantMethod, ShardedVarBuilder};
use std::{ops::Mul, sync::Arc};

use crate::{
    attention::SdpaParams,
    layers::{self, conv2d, embedding, layer_norm, Activation, CausalMasker, Sdpa},
    utils::unvarbuilder::UnVarBuilder,
};

use super::config::{Idefics3Config, Idefics3VisionConfig};

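/// Single linear projection that maps pixel-shuffled vision features
/// (`vision_hidden_size * scale_factor^2`) into the text model's hidden size.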
pub(crate) struct Idefics3SimpleMLP {
    pub(crate) proj: Linear,
}

impl Idefics3SimpleMLP {
    pub fn new(cfg: &Idefics3Config, vb: ShardedVarBuilder) -> Result<Self> {
        let in_dim = cfg.vision_config.hidden_size * cfg.scale_factor.pow(2);
        let out_dim = cfg.text_config.hidden_size;
        Ok(Self {
            proj: layers::linear_no_bias(in_dim, out_dim, vb.pp("proj"))?,
        })
    }

    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        x.apply(&self.proj)
    }
}

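/// Connects the vision tower to the language model: pixel-shuffles the image
/// hidden states to shorten the sequence, then projects them with
/// [`Idefics3SimpleMLP`].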
pub struct Idefics3Connector {
    scale_factor: usize,
    pub(crate) modality_projection: Idefics3SimpleMLP,
}

impl Idefics3Connector {
    pub fn new(cfg: &Idefics3Config, vb: ShardedVarBuilder) -> Result<Self> {
        Ok(Self {
            scale_factor: cfg.scale_factor,
            modality_projection: Idefics3SimpleMLP::new(cfg, vb.pp("modality_projection"))?,
        })
    }

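    /// Pixel shuffle: trade sequence length for embedding width. The `seq`
    /// patches are treated as a `sqrt(seq) x sqrt(seq)` grid and folded into
    /// `scale_factor x scale_factor` blocks, so the output has
    /// `seq / scale_factor^2` tokens of width `embed_dim * scale_factor^2`
    /// (e.g. with `scale_factor = 2`, `(bs, 1024, d)` becomes `(bs, 256, 4 * d)`).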
    pub fn pixel_shuffle(&self, x: &Tensor, scale_factor: usize) -> Result<Tensor> {
        let (bs, seq, embed_dim) = x.dims3()?;
        let height = (seq as f32).sqrt() as usize;
        let width = height;
        let mut x = x.reshape((bs, height, width, embed_dim))?;
        x = x.reshape((bs, height, width / scale_factor, embed_dim * scale_factor))?;
        x = x.permute((0, 2, 1, 3))?;
        x = x.reshape((
            bs,
            width / scale_factor,
            height / scale_factor,
            embed_dim * scale_factor.pow(2),
        ))?;
        x = x.permute((0, 2, 1, 3))?;
        x.reshape((
            bs,
            (seq as f32 / scale_factor.pow(2) as f32) as usize,
            embed_dim * scale_factor.pow(2),
        ))
    }

    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let image_hidden_states = self.pixel_shuffle(x, self.scale_factor)?;
        self.modality_projection.forward(&image_hidden_states)
    }
}

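/// Patch and position embeddings for the vision tower. Patches are embedded
/// with a strided convolution; position ids are assigned per sample by
/// bucketizing each image's fractional patch coordinates, so images of
/// different sizes map onto the same learned position grid.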
struct VisionEmbeddings {
    patch_size: usize,
    patch_embedding: Conv2d,
    num_patches_per_side: usize,
    position_embedding: Embedding,
}

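/// For each `x`, return its insertion index in the sorted `boundaries` slice,
/// i.e. the number of boundaries strictly below `x` (an exact match resolves to
/// the index of the matching boundary). Used to bucket fractional patch
/// coordinates into position-embedding indices.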
fn bucketize_right(xs: &[f32], boundaries: &[f32], device: &Device) -> Result<Tensor> {
    use std::cmp::Ordering;

    let mut result = Vec::with_capacity(xs.len());

    for &x in xs {
        // `binary_search_by` returns Ok(i) when boundaries[i] == x and Err(i)
        // with the insertion point otherwise; either way `i` is the bucket index.
        let idx = match boundaries.binary_search_by(|&val| {
            // `partial_cmp` only returns `None` for NaN inputs, which never occur here.
            val.partial_cmp(&x).unwrap_or(Ordering::Less)
        }) {
            Ok(i) => i,
            Err(i) => i,
        };

        result.push(idx as u32);
    }

    Tensor::from_vec(result, (xs.len(),), device)
}

impl VisionEmbeddings {
    fn new(config: &Idefics3VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        let conv_config = Conv2dConfig {
            stride: config.patch_size,
            ..Default::default()
        };
        let patch_embedding = conv2d(
            config.num_channels,
            config.hidden_size,
            config.patch_size,
            conv_config,
            vb.pp("patch_embedding"),
        )?;
        let num_patches_per_side = config.image_size / config.patch_size;
        let num_patches = num_patches_per_side.pow(2);
        Ok(Self {
            patch_size: config.patch_size,
            patch_embedding,
            num_patches_per_side,
            position_embedding: embedding(
                num_patches,
                config.hidden_size,
                vb.pp("position_embedding"),
                &None,
            )?,
        })
    }

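    /// Embed `pixel_values` into patch tokens and add position embeddings.
    /// `patch_attention_mask` (shape `(bs, h_patches, w_patches)`) marks the
    /// valid patches of each (possibly padded) image; position ids are computed
    /// only over those valid rows and columns.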
    fn forward(&self, pixel_values: &Tensor, patch_attention_mask: &Tensor) -> Result<Tensor> {
        let (bs, _, max_im_h, max_im_w) = pixel_values.dims4()?;

        let patch_embeds = self.patch_embedding.forward(pixel_values)?;

        let embeddings = patch_embeds.flatten(2, D::Minus1)?.transpose(1, 2)?;

        let (max_nb_patches_h, max_nb_patches_w) =
            (max_im_h / self.patch_size, max_im_w / self.patch_size);
        // Bucket boundaries at k / num_patches_per_side for k in 1..num_patches_per_side.
        let boundaries = Tensor::arange_step(
            1.0 / self.num_patches_per_side as f32,
            1.0,
            1.0 / self.num_patches_per_side as f32,
            pixel_values.device(),
        )?
        .to_vec1::<f32>()?;
        let position_ids = Tensor::full(
            0u32,
            (bs, max_nb_patches_h * max_nb_patches_w),
            pixel_values.device(),
        )?;

        // For each sample, bucketize the valid patches' fractional coordinates onto the
        // full position grid and scatter the resulting ids into the unmasked slots.
        let mut new_position_ids = Vec::new();
        for (b_idx, p_attn_mask) in patch_attention_mask.chunk(bs, 0)?.iter().enumerate() {
            let p_attn_mask = p_attn_mask.squeeze(0)?;
            let nb_patches_h = p_attn_mask.i((.., 0))?.sum_all()?;
            let nb_patches_w = p_attn_mask.i((0,))?.sum_all()?;

            let fractional_coords_h = Tensor::arange_step(
                0.0,
                1.0 - 1e-6,
                1.0 / nb_patches_h.to_dtype(DType::F32)?.to_scalar::<f32>()?,
                pixel_values.device(),
            )?
            .to_vec1::<f32>()?;
            let fractional_coords_w = Tensor::arange_step(
                0.0,
                1.0 - 1e-6,
                1.0 / nb_patches_w.to_dtype(DType::F32)?.to_scalar::<f32>()?,
                pixel_values.device(),
            )?
            .to_vec1::<f32>()?;

            let bucket_coords_h =
                bucketize_right(&fractional_coords_h, &boundaries, pixel_values.device())?;
            let bucket_coords_w =
                bucketize_right(&fractional_coords_w, &boundaries, pixel_values.device())?;

            let pos_ids = bucket_coords_h
                .unsqueeze(D::Minus1)?
                .mul(self.num_patches_per_side as f64)?
                .broadcast_add(&bucket_coords_w)?
                .flatten_all()?
                .to_vec1::<u32>()?;

            // Flattened indices of the valid (unmasked) patches for this sample.
            let true_indices = p_attn_mask
                .flatten_all()?
                .to_vec1::<u8>()?
                .iter()
                .enumerate()
                .filter_map(|(i, x)| if *x != 0 { Some(i) } else { None })
                .collect::<Vec<_>>();
            let position_ids_b = position_ids.i(b_idx)?;

            let mut new_position_ids_b = position_ids_b.to_vec1::<u32>()?;
            let new_position_ids_b_len = new_position_ids_b.len();
            for (i, true_idx) in true_indices.into_iter().enumerate() {
                new_position_ids_b[true_idx] = pos_ids[i];
            }

            new_position_ids.push(Tensor::from_vec(
                new_position_ids_b,
                new_position_ids_b_len,
                pixel_values.device(),
            )?);
        }
        let position_ids = Tensor::stack(&new_position_ids, 0)?;
        let position_ids = position_ids.to_device(self.position_embedding.embeddings().device())?;
        embeddings.broadcast_add(&self.position_embedding.forward(&position_ids)?)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("patch_embedding").add(&self.patch_embedding);
        uvb.pp("position_embedding").add(&self.position_embedding);

        uvb.to_safetensors()
    }
}

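/// Standard multi-head self-attention over the patch sequence (no KV grouping,
/// no sliding window), with the usual `1 / sqrt(head_dim)` softmax scaling.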
struct Attention {
    embed_dim: usize,
    num_heads: usize,
    head_dim: usize,
    q_proj: Arc<dyn QuantMethod>,
    k_proj: Arc<dyn QuantMethod>,
    v_proj: Arc<dyn QuantMethod>,
    o_proj: Arc<dyn QuantMethod>,
    sdpa_params: SdpaParams,
}

impl Attention {
    fn new(config: Idefics3VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        let embed_dim = config.hidden_size;
        let num_heads = config.num_attention_heads;
        let head_dim = embed_dim / num_heads;
        let scale = 1.0 / (head_dim as f32).sqrt();

        let q_proj = mistralrs_quant::linear(embed_dim, embed_dim, &None, vb.pp("q_proj"))?;
        let k_proj = mistralrs_quant::linear(embed_dim, embed_dim, &None, vb.pp("k_proj"))?;
        let v_proj = mistralrs_quant::linear(embed_dim, embed_dim, &None, vb.pp("v_proj"))?;
        let o_proj = mistralrs_quant::linear(embed_dim, embed_dim, &None, vb.pp("out_proj"))?;

        Ok(Self {
            embed_dim,
            num_heads,
            head_dim,
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            sdpa_params: SdpaParams {
                n_kv_groups: 1,
                softcap: None,
                softmax_scale: scale,
                sliding_window: None,
            },
        })
    }

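    /// Project to Q/K/V, reshape to `(bs, num_heads, seq, head_dim)`, run SDPA
    /// with the optional additive attention mask, then apply the output projection.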
    fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let mut q = self.q_proj.forward_autocast(xs)?;
        let mut k = self.k_proj.forward_autocast(xs)?;
        let mut v = self.v_proj.forward_autocast(xs)?;

        q = q
            .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;
        k = k
            .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;
        v = v
            .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;

        let attn_output =
            Sdpa.run_attention(&q, &k, &v, attention_mask, None, &self.sdpa_params)?;

        self.o_proj
            .forward_autocast(&attn_output.transpose(1, 2)?.reshape((
                b_sz,
                q_len,
                self.embed_dim,
            ))?)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("q_proj").add(&self.q_proj);
        uvb.pp("k_proj").add(&self.k_proj);
        uvb.pp("v_proj").add(&self.v_proj);
        uvb.pp("out_proj").add(&self.o_proj);

        uvb.to_safetensors()
    }
}

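/// Two-layer feed-forward block (`hidden_size -> intermediate_size -> hidden_size`)
/// with the configured hidden activation in between.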
struct VisionMLP {
    activation: Activation,
    fc1: Arc<dyn QuantMethod>,
    fc2: Arc<dyn QuantMethod>,
}

impl VisionMLP {
    fn new(config: Idefics3VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        let fc1 = mistralrs_quant::linear(
            config.hidden_size,
            config.intermediate_size,
            &None,
            vb.pp("fc1"),
        )?;
        let fc2 = mistralrs_quant::linear(
            config.intermediate_size,
            config.hidden_size,
            &None,
            vb.pp("fc2"),
        )?;
        Ok(Self {
            activation: config.hidden_act,
            fc1,
            fc2,
        })
    }

    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let mut x = self.fc1.forward_autocast(x)?;
        x = self.activation.forward(&x)?;
        self.fc2.forward_autocast(&x)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("fc1").add(&self.fc1);
        uvb.pp("fc2").add(&self.fc2);

        uvb.to_safetensors()
    }
}

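/// Pre-norm transformer block: `x + attn(ln1(x))`, then `x + mlp(ln2(x))`.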
struct EncoderLayer {
    mlp: VisionMLP,
    attn: Attention,
    layer_norm_1: LayerNorm,
    layer_norm_2: LayerNorm,
}

impl EncoderLayer {
    fn new(config: Idefics3VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        let mlp = VisionMLP::new(config.clone(), vb.pp("mlp"))?;
        let attn = Attention::new(config.clone(), vb.pp("self_attn"))?;
        let layer_norm_1 = layer_norm(
            config.hidden_size,
            config.layer_norm_eps,
            vb.pp("layer_norm1"),
        )?;
        let layer_norm_2 = layer_norm(
            config.hidden_size,
            config.layer_norm_eps,
            vb.pp("layer_norm2"),
        )?;
        Ok(Self {
            mlp,
            attn,
            layer_norm_1,
            layer_norm_2,
        })
    }

    fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
        let residual = xs.clone();

        let mut hidden_states = self.layer_norm_1.forward(xs)?;
        hidden_states = self.attn.forward(&hidden_states, attention_mask)?;
        hidden_states = (hidden_states + residual)?;

        let residual = hidden_states.clone();
        hidden_states = self.layer_norm_2.forward(&hidden_states)?;
        hidden_states = self.mlp.forward(&hidden_states)?;
        hidden_states + residual
    }
}

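/// Stack of `num_hidden_layers` encoder layers applied sequentially.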
struct Encoder {
    layers: Vec<EncoderLayer>,
}

impl Encoder {
    fn new(config: &Idefics3VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        let mut layers = Vec::new();
        let vb_l = vb.pp("layers");
        for i in 0..config.num_hidden_layers {
            layers.push(EncoderLayer::new(config.clone(), vb_l.pp(i))?);
        }
        Ok(Self { layers })
    }

    fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
        let mut hidden_states = xs.clone();
        for layer in &self.layers {
            hidden_states = layer.forward(&hidden_states, attention_mask)?;
        }
        Ok(hidden_states)
    }
}

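/// The Idefics3 vision tower: patch/position embeddings, the transformer
/// encoder, and a final layer norm.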
pub struct Idefics3VisionTransformer {
    embeddings: VisionEmbeddings,
    encoder: Encoder,
    post_layernorm: LayerNorm,
    patch_size: usize,
    neg_inf: Tensor,
}

impl Idefics3VisionTransformer {
    pub fn new(config: &Idefics3VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        let embeddings = VisionEmbeddings::new(config, vb.pp("embeddings"))?;
        let post_layernorm = layer_norm(
            config.hidden_size,
            config.layer_norm_eps,
            vb.pp("post_layernorm"),
        )?;
        let encoder = Encoder::new(config, vb.pp("encoder"))?;
        Ok(Self {
            embeddings,
            encoder,
            post_layernorm,
            patch_size: config.patch_size,
            neg_inf: Tensor::new(f32::NEG_INFINITY, vb.device())?.to_dtype(vb.dtype())?,
        })
    }

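    /// Run the vision tower. If `attention_mask` is `None`, every patch is
    /// treated as valid; otherwise the patch mask is expanded into an additive
    /// attention bias (`-inf` at excluded positions, 0 elsewhere) broadcast over
    /// the attention heads.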
    pub fn forward(
        &self,
        pixel_values: &Tensor,
        attention_mask: Option<&Tensor>,
    ) -> Result<Tensor> {
        let bs = pixel_values.dim(0)?;
        let patch_attention_mask = if let Some(attn_mask) = attention_mask {
            attn_mask.clone()
        } else {
            Tensor::ones(
                (
                    bs,
                    pixel_values.dim(2)? / self.patch_size,
                    pixel_values.dim(3)? / self.patch_size,
                ),
                DType::U8,
                pixel_values.device(),
            )?
        };

        let hidden_states = self
            .embeddings
            .forward(pixel_values, &patch_attention_mask)?;

        let attention_mask = if attention_mask.is_none() {
            None
        } else {
            let mask = patch_attention_mask
                .reshape((patch_attention_mask.dim(0)?, ()))?
                .to_dtype(hidden_states.dtype())?;
            Some(CausalMasker.expand_mask(&mask, hidden_states.dtype(), None)?)
        };

        let attention_mask = match &attention_mask {
            Some(mask) => {
                let (b, _h, q, k) = mask.dims4()?;
                let dims = [b, self.encoder.layers[0].attn.num_heads, q, k];
                let mask = mask.broadcast_as(&dims)?;
                let mask = mask.to_dtype(DType::U8)?.where_cond(
                    &self.neg_inf.to_device(mask.device())?.broadcast_as(&dims)?,
                    &Tensor::zeros(&dims, self.neg_inf.dtype(), self.neg_inf.device())?,
                )?;

                Some(mask)
            }
            None => None,
        };
        let hidden_states = self
            .encoder
            .forward(&hidden_states, attention_mask.as_ref())?;
        hidden_states.apply(&self.post_layernorm)
    }

    pub fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("post_layernorm").add(&self.post_layernorm);
        uvb.pp("embeddings")
            .extend(self.embeddings.residual_tensors());

        let uvb_enc = uvb.pp("encoder");
        for (i, layer) in self.encoder.layers.iter().enumerate() {
            let uvb_l = uvb_enc.pp("layers").pp(i);

            uvb_l.pp("layer_norm1").add(&layer.layer_norm_1);
            uvb_l.pp("layer_norm2").add(&layer.layer_norm_2);
            uvb_l.pp("mlp").extend(layer.mlp.residual_tensors());
            uvb_l.pp("self_attn").extend(layer.attn.residual_tensors());
        }

        uvb.to_safetensors()
    }
}