use candle_core::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{Conv2d, Conv2dConfig, Embedding, LayerNorm, Linear, Module};
use mistralrs_quant::ShardedVarBuilder;
use std::ops::Mul;

use crate::{
    layers::{self, conv2d, embedding, layer_norm, Activation, CausalMasker, MatMul},
    utils::unvarbuilder::UnVarBuilder,
};

use super::config::{Idefics3Config, Idefics3VisionConfig};

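/// Projects the pixel-shuffled vision hidden states into the text model's hidden size.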
pub(crate) struct Idefics3SimpleMLP {
    pub(crate) proj: Linear,
}

impl Idefics3SimpleMLP {
    pub fn new(cfg: &Idefics3Config, vb: ShardedVarBuilder) -> Result<Self> {
        let in_dim = cfg.vision_config.hidden_size * cfg.scale_factor.pow(2);
        let out_dim = cfg.text_config.hidden_size;
        Ok(Self {
            proj: layers::linear_no_bias(in_dim, out_dim, vb.pp("proj"))?,
        })
    }

    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        x.apply(&self.proj)
    }
}

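/// Connects the vision tower to the language model: pixel-shuffles the image hidden
/// states and projects them with [`Idefics3SimpleMLP`].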
pub struct Idefics3Connector {
    scale_factor: usize,
    pub(crate) modality_projection: Idefics3SimpleMLP,
}

impl Idefics3Connector {
    pub fn new(cfg: &Idefics3Config, vb: ShardedVarBuilder) -> Result<Self> {
        Ok(Self {
            scale_factor: cfg.scale_factor,
            modality_projection: Idefics3SimpleMLP::new(cfg, vb.pp("modality_projection"))?,
        })
    }

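    /// Space-to-depth "pixel shuffle": folds each `scale_factor x scale_factor` block of
    /// patches into the channel dimension, shrinking the sequence length by
    /// `scale_factor^2` while growing the embedding dimension by the same factor.
    /// Assumes the sequence forms a square grid of patches.
    ///
    /// Shape sketch (not compiled; `connector` is a hypothetical `Idefics3Connector`):
    /// ```ignore
    /// // 64 patches with 16-dim embeddings and scale_factor = 2 become
    /// // 16 patches with 64-dim embeddings.
    /// let x = Tensor::zeros((1, 64, 16), DType::F32, &Device::Cpu)?;
    /// assert_eq!(connector.pixel_shuffle(&x, 2)?.dims(), &[1, 16, 64]);
    /// ```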
    pub fn pixel_shuffle(&self, x: &Tensor, scale_factor: usize) -> Result<Tensor> {
        let (bs, seq, embed_dim) = x.dims3()?;
        let height = (seq as f32).sqrt() as usize;
        let width = height;
        let mut x = x.reshape((bs, height, width, embed_dim))?;
        x = x.reshape((bs, height, width / scale_factor, embed_dim * scale_factor))?;
        x = x.permute((0, 2, 1, 3))?;
        x = x.reshape((
            bs,
            width / scale_factor,
            height / scale_factor,
            embed_dim * scale_factor.pow(2),
        ))?;
        x = x.permute((0, 2, 1, 3))?;
        x.reshape((
            bs,
            (seq as f32 / scale_factor.pow(2) as f32) as usize,
            embed_dim * scale_factor.pow(2),
        ))
    }

    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let image_hidden_states = self.pixel_shuffle(x, self.scale_factor)?;
        self.modality_projection.forward(&image_hidden_states)
    }
}

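/// Convolutional patch embedding plus learned position embeddings. Position ids are
/// derived per image from the patch attention mask, so only valid (unpadded) patches
/// receive grid positions.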
struct VisionEmbeddings {
    patch_size: usize,
    patch_embedding: Conv2d,
    num_patches_per_side: usize,
    position_embedding: Embedding,
}

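/// For each value in `xs`, returns the index of the bucket it falls into over the
/// ascending `boundaries`, analogous to `torch.bucketize(..., right=True)` in the
/// reference implementation.
///
/// Quick sketch (not compiled):
/// ```ignore
/// // With boundaries [0.25, 0.5, 0.75]: 0.1 -> bucket 0, 0.3 -> bucket 1, 0.9 -> bucket 3.
/// let ids = bucketize_right(&[0.1, 0.3, 0.9], &[0.25, 0.5, 0.75], &Device::Cpu)?;
/// assert_eq!(ids.to_vec1::<u32>()?, vec![0, 1, 3]);
/// ```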
fn bucketize_right(xs: &[f32], boundaries: &[f32], device: &Device) -> Result<Tensor> {
    use std::cmp::Ordering;

    let mut result = Vec::with_capacity(xs.len());

    for &x in xs {
        // binary_search_by returns Ok(i) when boundaries[i] == x and Err(i) with the
        // insertion point otherwise; either way `i` is the bucket index for an
        // ascending `boundaries` slice.
        let idx = match boundaries.binary_search_by(|&val| {
            // partial_cmp only fails on NaN, which is not expected here.
            val.partial_cmp(&x).unwrap_or(Ordering::Less)
        }) {
            Ok(i) => i,
            Err(i) => i,
        };

        result.push(idx as u32);
    }

    Tensor::from_vec(result, (xs.len(),), device)
}

impl VisionEmbeddings {
    fn new(config: &Idefics3VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        let conv_config = Conv2dConfig {
            stride: config.patch_size,
            ..Default::default()
        };
        let patch_embedding = conv2d(
            config.num_channels,
            config.hidden_size,
            config.patch_size,
            conv_config,
            vb.pp("patch_embedding"),
        )?;
        let num_patches_per_side = config.image_size / config.patch_size;
        let num_patches = num_patches_per_side.pow(2);
        Ok(Self {
            patch_size: config.patch_size,
            patch_embedding,
            num_patches_per_side,
            position_embedding: embedding(
                num_patches,
                config.hidden_size,
                vb.pp("position_embedding"),
                &None,
            )?,
        })
    }

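    /// Computes patch embeddings with the conv stem, then adds position embeddings.
    /// Position ids are bucketed fractional coordinates of each real patch (per the
    /// patch attention mask) on the `num_patches_per_side` grid; padded patches keep
    /// position id 0.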
    fn forward(&self, pixel_values: &Tensor, patch_attention_mask: &Tensor) -> Result<Tensor> {
        let (bs, _, max_im_h, max_im_w) = pixel_values.dims4()?;

        let patch_embeds = self.patch_embedding.forward(pixel_values)?;

        let embeddings = patch_embeds.flatten(2, D::Minus1)?.transpose(1, 2)?;

        let (max_nb_patches_h, max_nb_patches_w) =
            (max_im_h / self.patch_size, max_im_w / self.patch_size);
        // Bucket boundaries for fractional patch coordinates over the full grid.
        let boundaries = Tensor::arange_step(
            1.0 / self.num_patches_per_side as f32,
            1.0,
            1.0 / self.num_patches_per_side as f32,
            pixel_values.device(),
        )?
        .to_vec1::<f32>()?;
        let position_ids = Tensor::full(
            0u32,
            (bs, max_nb_patches_h * max_nb_patches_w),
            pixel_values.device(),
        )?;

        let mut new_position_ids = Vec::new();
        for (b_idx, p_attn_mask) in patch_attention_mask.chunk(bs, 0)?.iter().enumerate() {
            let p_attn_mask = p_attn_mask.squeeze(0)?;
            // Number of real (unpadded) patch rows and columns for this image.
            let nb_patches_h = p_attn_mask.i((.., 0))?.sum_all()?;
            let nb_patches_w = p_attn_mask.i((0,))?.sum_all()?;

            let fractional_coords_h = Tensor::arange_step(
                0.0,
                1.0 - 1e-6,
                1.0 / nb_patches_h.to_dtype(DType::F32)?.to_scalar::<f32>()?,
                pixel_values.device(),
            )?
            .to_vec1::<f32>()?;
            let fractional_coords_w = Tensor::arange_step(
                0.0,
                1.0 - 1e-6,
                1.0 / nb_patches_w.to_dtype(DType::F32)?.to_scalar::<f32>()?,
                pixel_values.device(),
            )?
            .to_vec1::<f32>()?;

            let bucket_coords_h =
                bucketize_right(&fractional_coords_h, &boundaries, pixel_values.device())?;
            let bucket_coords_w =
                bucketize_right(&fractional_coords_w, &boundaries, pixel_values.device())?;

            // Flattened grid position: row * num_patches_per_side + column.
            let pos_ids = bucket_coords_h
                .unsqueeze(D::Minus1)?
                .mul(self.num_patches_per_side as f64)?
                .broadcast_add(&bucket_coords_w)?
                .flatten_all()?
                .to_vec1::<u32>()?;

            // Scatter the computed positions into the slots of real patches only.
            let true_indices = p_attn_mask
                .flatten_all()?
                .to_vec1::<u8>()?
                .iter()
                .enumerate()
                .filter_map(|(i, x)| if *x != 0 { Some(i) } else { None })
                .collect::<Vec<_>>();
            let position_ids_b = position_ids.i(b_idx)?;

            let mut new_position_ids_b = position_ids_b.to_vec1::<u32>()?;
            let new_position_ids_b_len = new_position_ids_b.len();
            for (i, true_idx) in true_indices.into_iter().enumerate() {
                new_position_ids_b[true_idx] = pos_ids[i];
            }

            new_position_ids.push(Tensor::from_vec(
                new_position_ids_b,
                new_position_ids_b_len,
                pixel_values.device(),
            )?);
        }
        let position_ids = Tensor::stack(&new_position_ids, 0)?;
        let position_ids = position_ids.to_device(self.position_embedding.embeddings().device())?;
        embeddings.broadcast_add(&self.position_embedding.forward(&position_ids)?)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("patch_embedding").add(&self.patch_embedding);
        uvb.pp("position_embedding").add(&self.position_embedding);

        uvb.to_safetensors()
    }
}

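/// Multi-head self-attention over the patch sequence (bidirectional; the optional
/// mask only excludes padded patches).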
struct Attention {
    embed_dim: usize,
    num_heads: usize,
    head_dim: usize,
    scale: f64,
    q_proj: Linear,
    k_proj: Linear,
    v_proj: Linear,
    o_proj: Linear,
    neg_inf: Tensor,
}

impl Attention {
    fn new(config: Idefics3VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        let embed_dim = config.hidden_size;
        let num_heads = config.num_attention_heads;
        let head_dim = embed_dim / num_heads;
        let scale = 1.0 / (head_dim as f64).sqrt();

        let q_proj = layers::linear(embed_dim, embed_dim, vb.pp("q_proj"))?;
        let k_proj = layers::linear(embed_dim, embed_dim, vb.pp("k_proj"))?;
        let v_proj = layers::linear(embed_dim, embed_dim, vb.pp("v_proj"))?;
        let o_proj = layers::linear(embed_dim, embed_dim, vb.pp("out_proj"))?;

        Ok(Self {
            embed_dim,
            num_heads,
            head_dim,
            scale,
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            // Precomputed -inf in the model dtype, used when masking attention weights.
            neg_inf: Tensor::new(f32::NEG_INFINITY, vb.device())?.to_dtype(vb.dtype())?,
        })
    }

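    /// Scaled dot-product attention over the patch sequence. When `attention_mask`
    /// is provided, masked (zero) positions are pushed to `neg_inf` before the softmax.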
    fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let mut q = self.q_proj.forward(xs)?;
        let mut k = self.k_proj.forward(xs)?;
        let mut v = self.v_proj.forward(xs)?;

        // Reshape to (b_sz, num_heads, q_len, head_dim).
        q = q
            .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;
        k = k
            .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;
        v = v
            .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;

        let attn_weights =
            (MatMul.matmul(&q.contiguous()?, &k.transpose(2, 3)?.contiguous()?)? * self.scale)?;

        let mut attn_weights = CausalMasker.apply_mask_one_and_zero(
            &attention_mask.map(|x| x.to_dtype(DType::U8).unwrap()),
            attn_weights,
            &self.neg_inf,
        )?;
        attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
        let attn_output = MatMul.matmul(&attn_weights, &v.contiguous()?)?;

        attn_output
            .transpose(1, 2)?
            .reshape((b_sz, q_len, self.embed_dim))?
            .apply(&self.o_proj)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("q_proj").add(&self.q_proj);
        uvb.pp("k_proj").add(&self.k_proj);
        uvb.pp("v_proj").add(&self.v_proj);
        uvb.pp("out_proj").add(&self.o_proj);

        uvb.to_safetensors()
    }
}

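/// Two-layer feed-forward block with the configured hidden activation.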
struct VisionMLP {
    activation: Activation,
    fc1: Linear,
    fc2: Linear,
}

impl VisionMLP {
    fn new(config: Idefics3VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        let fc1 = layers::linear(config.hidden_size, config.intermediate_size, vb.pp("fc1"))?;
        let fc2 = layers::linear(config.intermediate_size, config.hidden_size, vb.pp("fc2"))?;
        Ok(Self {
            activation: config.hidden_act,
            fc1,
            fc2,
        })
    }

    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let mut x = self.fc1.forward(x)?;
        x = self.activation.forward(&x)?;
        self.fc2.forward(&x)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("fc1").add(&self.fc1);
        uvb.pp("fc2").add(&self.fc2);

        uvb.to_safetensors()
    }
}

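/// Pre-norm transformer block: LayerNorm -> self-attention -> residual, then
/// LayerNorm -> MLP -> residual.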
struct EncoderLayer {
    mlp: VisionMLP,
    attn: Attention,
    layer_norm_1: LayerNorm,
    layer_norm_2: LayerNorm,
}

impl EncoderLayer {
    fn new(config: Idefics3VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        let mlp = VisionMLP::new(config.clone(), vb.pp("mlp"))?;
        let attn = Attention::new(config.clone(), vb.pp("self_attn"))?;
        let layer_norm_1 = layer_norm(
            config.hidden_size,
            config.layer_norm_eps,
            vb.pp("layer_norm1"),
        )?;
        let layer_norm_2 = layer_norm(
            config.hidden_size,
            config.layer_norm_eps,
            vb.pp("layer_norm2"),
        )?;
        Ok(Self {
            mlp,
            attn,
            layer_norm_1,
            layer_norm_2,
        })
    }

    fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
        let residual = xs.clone();

        let hidden_states = self.layer_norm_1.forward(xs)?;
        let hidden_states = self.attn.forward(&hidden_states, attention_mask)?;
        let hidden_states = (hidden_states + residual)?;

        let residual = &hidden_states;
        let hidden_states = self.layer_norm_2.forward(&hidden_states)?;
        let hidden_states = self.mlp.forward(&hidden_states)?;
        hidden_states + residual
    }
}

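/// Stack of `num_hidden_layers` encoder layers applied sequentially.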
struct Encoder {
    layers: Vec<EncoderLayer>,
}

impl Encoder {
    fn new(config: &Idefics3VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        let mut layers = Vec::new();
        let vb_l = vb.pp("layers");
        for i in 0..config.num_hidden_layers {
            layers.push(EncoderLayer::new(config.clone(), vb_l.pp(i))?);
        }
        Ok(Self { layers })
    }

    fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
        let mut hidden_states = xs.clone();
        for layer in &self.layers {
            hidden_states = layer.forward(&hidden_states, attention_mask)?;
        }
        Ok(hidden_states)
    }
}

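/// The Idefics3 vision tower: patch/position embeddings, transformer encoder, and a
/// final LayerNorm.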
pub struct Idefics3VisionTransformer {
    embeddings: VisionEmbeddings,
    encoder: Encoder,
    post_layernorm: LayerNorm,
    patch_size: usize,
}

impl Idefics3VisionTransformer {
    pub fn new(config: &Idefics3VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        let embeddings = VisionEmbeddings::new(config, vb.pp("embeddings"))?;
        let post_layernorm = layer_norm(
            config.hidden_size,
            config.layer_norm_eps,
            vb.pp("post_layernorm"),
        )?;
        let encoder = Encoder::new(config, vb.pp("encoder"))?;
        Ok(Self {
            embeddings,
            encoder,
            post_layernorm,
            patch_size: config.patch_size,
        })
    }

    pub fn forward(
        &self,
        pixel_values: &Tensor,
        attention_mask: Option<&Tensor>,
    ) -> Result<Tensor> {
        let bs = pixel_values.dim(0)?;
        // Default to an all-ones patch mask when no attention mask is provided.
        let patch_attention_mask = if let Some(attn_mask) = attention_mask {
            attn_mask.clone()
        } else {
            Tensor::ones(
                (
                    bs,
                    pixel_values.dim(2)? / self.patch_size,
                    pixel_values.dim(3)? / self.patch_size,
                ),
                DType::U8,
                pixel_values.device(),
            )?
        };

        let hidden_states = self
            .embeddings
            .forward(pixel_values, &patch_attention_mask)?;

        // Only pass a mask to the encoder if one was supplied by the caller.
        let attention_mask = if attention_mask.is_none() {
            None
        } else {
            let mask = patch_attention_mask
                .reshape((patch_attention_mask.dim(0)?, ()))?
                .to_dtype(hidden_states.dtype())?;
            Some(CausalMasker.expand_mask(&mask, hidden_states.dtype(), None)?)
        };
        let hidden_states = self
            .encoder
            .forward(&hidden_states, attention_mask.as_ref())?;
        hidden_states.apply(&self.post_layernorm)
    }

    pub fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("post_layernorm").add(&self.post_layernorm);
        uvb.pp("embeddings")
            .extend(self.embeddings.residual_tensors());

        let uvb_enc = uvb.pp("encoder");
        for (i, layer) in self.encoder.layers.iter().enumerate() {
            let uvb_l = uvb_enc.pp("layers").pp(i);

            uvb_l.pp("layer_norm1").add(&layer.layer_norm_1);
            uvb_l.pp("layer_norm2").add(&layer.layer_norm_2);
            uvb_l.pp("mlp").extend(layer.mlp.residual_tensors());
            uvb_l.pp("self_attn").extend(layer.attn.residual_tensors());
        }

        uvb.to_safetensors()
    }
}