#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use candle_core::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{Conv2d, Conv2dConfig, Embedding, LayerNorm, Module};
use mistralrs_quant::{QuantMethod, ShardedVarBuilder};
use std::{ops::Mul, sync::Arc};

use crate::{
    attention::SdpaParams,
    layers::{conv2d, embedding, layer_norm, Activation, CausalMasker, Sdpa},
    serde_default_fn,
    utils::unvarbuilder::UnVarBuilder,
};

serde_default_fn!(usize, hidden_size, 768);
serde_default_fn!(usize, intermediate_size, 3072);
serde_default_fn!(usize, num_hidden_layers, 12);
serde_default_fn!(usize, num_attention_heads, 12);
serde_default_fn!(usize, num_channels, 3);
serde_default_fn!(usize, image_size, 224);
serde_default_fn!(usize, patch_size, 16);
serde_default_fn!(Activation, hidden_act, Activation::GeluPytorchTanh);
serde_default_fn!(f64, layer_norm_eps, 1e-6);

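/// Vision-tower configuration for SigLIP.
///
/// The serde defaults above correspond to a SigLIP-base style vision encoder
/// (hidden size 768, 12 layers, 12 heads, 16x16 patches on 224x224 images),
/// so a missing or partial config section still yields a usable model.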
#[derive(Debug, Clone, serde::Deserialize)]
pub struct SiglipVisionConfig {
    #[serde(default = "hidden_size")]
    pub hidden_size: usize,
    #[serde(default = "intermediate_size")]
    pub intermediate_size: usize,
    #[serde(default = "num_hidden_layers")]
    pub num_hidden_layers: usize,
    #[serde(default = "num_attention_heads")]
    pub num_attention_heads: usize,
    #[serde(default = "num_channels")]
    pub num_channels: usize,
    #[serde(default = "image_size")]
    pub image_size: usize,
    #[serde(default = "patch_size")]
    pub patch_size: usize,
    #[serde(default = "hidden_act")]
    pub hidden_act: Activation,
    #[serde(default = "layer_norm_eps")]
    pub layer_norm_eps: f64,
}

impl Default for SiglipVisionConfig {
    fn default() -> Self {
        Self {
            hidden_size: 768,
            intermediate_size: 3072,
            num_hidden_layers: 12,
            num_attention_heads: 12,
            num_channels: 3,
            image_size: 224,
            patch_size: 16,
            hidden_act: Activation::GeluPytorchTanh,
            layer_norm_eps: 1e-6,
        }
    }
}

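/// Patch and position embeddings.
///
/// Pixels are embedded with a non-overlapping `patch_size` x `patch_size`
/// convolution, and each patch receives a learned position embedding. For
/// variable-sized (padded) inputs, position ids are re-derived per image by
/// bucketizing fractional patch coordinates (see `forward`), so the same
/// `num_patches_per_side` x `num_patches_per_side` position table serves
/// images of different resolutions.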
pub(super) struct VisionEmbeddings {
    patch_size: usize,
    patch_embedding: Conv2d,
    num_patches_per_side: usize,
    pub(super) position_embedding: Embedding,
}

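/// Scalar reimplementation of `torch.bucketize(xs, boundaries, right=True)`:
/// for each `x`, return the index of the first boundary strictly greater than
/// `x`, so a value equal to a boundary falls into the bucket to its right.
///
/// Worked example (illustrative): with `boundaries = [0.25, 0.5, 0.75]`,
/// `xs = [0.1, 0.5, 0.9]` maps to `[0, 2, 3]`.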
fn bucketize_right(xs: &[f32], boundaries: &[f32], device: &Device) -> Result<Tensor> {
    use std::cmp::Ordering;

    let mut result = Vec::with_capacity(xs.len());

    for &x in xs {
        // `binary_search_by` yields Ok(i) when boundaries[i] == x and Err(i)
        // with the insertion point otherwise. `boundaries` is sorted and
        // NaN-free, so treating an incomparable pair as Less is only a
        // defensive fallback.
        let idx = match boundaries
            .binary_search_by(|&val| val.partial_cmp(&x).unwrap_or(Ordering::Less))
        {
            // Exact match: right=True semantics place x in the next bucket.
            Ok(i) => i + 1,
            // No exact match: the insertion point is the first boundary > x.
            Err(i) => i,
        };

        result.push(idx as u32);
    }

    Tensor::from_vec(result, (xs.len(),), device)
}

impl VisionEmbeddings {
    fn new(config: &SiglipVisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        let conv_config = Conv2dConfig {
            stride: config.patch_size,
            ..Default::default()
        };
        let patch_embedding = conv2d(
            config.num_channels,
            config.hidden_size,
            config.patch_size,
            conv_config,
            vb.pp("patch_embedding"),
        )?;
        let num_patches_per_side = config.image_size / config.patch_size;
        let num_patches = num_patches_per_side.pow(2);
        Ok(Self {
            patch_size: config.patch_size,
            patch_embedding,
            num_patches_per_side,
            position_embedding: embedding(
                num_patches,
                config.hidden_size,
                vb.pp("position_embedding"),
                &None,
            )?,
        })
    }

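    // Embeds `pixel_values` (bs, channels, h, w) into patch embeddings of
    // shape (bs, num_patches, hidden_size) and adds bucketized position
    // embeddings.
    //
    // For each image, the valid patch grid (from `tgt_sizes` or from
    // `patch_attention_mask`) is mapped onto the full
    // `num_patches_per_side` x `num_patches_per_side` position table:
    // row/column fractions k / nb_patches are bucketized against the
    // boundaries j / num_patches_per_side, and the position id is
    // `bucket_h * num_patches_per_side + bucket_w`. Padded patches keep
    // position id 0. Illustrative example: with num_patches_per_side = 4 and
    // a 3x3-patch image, the row fractions 0, 1/3, 2/3 fall into buckets
    // 0, 1, 2 of the 4-wide table.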
    fn forward(
        &self,
        pixel_values: &Tensor,
        patch_attention_mask: &Tensor,
        tgt_sizes: Option<&Tensor>,
    ) -> Result<Tensor> {
        let (bs, _, max_im_h, max_im_w) = pixel_values.dims4()?;

        let patch_embeds = self.patch_embedding.forward(pixel_values)?;

        let embeddings = patch_embeds.flatten(2, D::Minus1)?.transpose(1, 2)?;

        let (max_nb_patches_h, max_nb_patches_w) =
            (max_im_h / self.patch_size, max_im_w / self.patch_size);
        let boundaries = Tensor::arange_step(
            1.0 / self.num_patches_per_side as f32,
            1.0,
            1.0 / self.num_patches_per_side as f32,
            pixel_values.device(),
        )?
        .to_vec1::<f32>()?;
        let position_ids = Tensor::full(
            0u32,
            (bs, max_nb_patches_h * max_nb_patches_w),
            pixel_values.device(),
        )?;

        let mut new_position_ids = Vec::new();
        for (b_idx, p_attn_mask) in patch_attention_mask.chunk(bs, 0)?.iter().enumerate() {
            let p_attn_mask = p_attn_mask.squeeze(0)?;
            let (nb_patches_h, nb_patches_w) = if let Some(tgt_sizes) = tgt_sizes {
                (tgt_sizes.i((b_idx, 0))?, tgt_sizes.i((b_idx, 1))?)
            } else {
                let nb_patches_h = p_attn_mask.i((.., 0))?.sum_all()?;
                let nb_patches_w = p_attn_mask.i((0,))?.sum_all()?;
                (nb_patches_h, nb_patches_w)
            };

            let fractional_coords_h = Tensor::arange_step(
                0.0,
                1.0 - 1e-6,
                1.0 / nb_patches_h.to_dtype(DType::F32)?.to_scalar::<f32>()?,
                pixel_values.device(),
            )?
            .to_vec1::<f32>()?;
            let fractional_coords_w = Tensor::arange_step(
                0.0,
                1.0 - 1e-6,
                1.0 / nb_patches_w.to_dtype(DType::F32)?.to_scalar::<f32>()?,
                pixel_values.device(),
            )?
            .to_vec1::<f32>()?;

            let bucket_coords_h =
                bucketize_right(&fractional_coords_h, &boundaries, pixel_values.device())?;
            let bucket_coords_w =
                bucketize_right(&fractional_coords_w, &boundaries, pixel_values.device())?;

            let pos_ids = bucket_coords_h
                .unsqueeze(D::Minus1)?
                .mul(self.num_patches_per_side as f64)?
                .broadcast_add(&bucket_coords_w)?
                .flatten_all()?
                .to_vec1::<u32>()?;

            let true_indices = p_attn_mask
                .flatten_all()?
                .to_vec1::<u8>()?
                .iter()
                .enumerate()
                .filter_map(|(i, x)| if *x != 0 { Some(i) } else { None })
                .collect::<Vec<_>>();
            let position_ids_b = position_ids.i(b_idx)?;

            let mut new_position_ids_b = position_ids_b.to_vec1::<u32>()?;
            let new_position_ids_b_len = new_position_ids_b.len();
            for (i, true_idx) in true_indices.into_iter().enumerate() {
                new_position_ids_b[true_idx] = pos_ids[i];
            }

            new_position_ids.push(Tensor::from_vec(
                new_position_ids_b,
                new_position_ids_b_len,
                pixel_values.device(),
            )?);
        }
        let position_ids = Tensor::stack(&new_position_ids, 0)?;
        let position_ids =
            position_ids.to_device(self.position_embedding.embeddings().device())?;
        embeddings.broadcast_add(&self.position_embedding.forward(&position_ids)?)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("patch_embedding").add(&self.patch_embedding);
        uvb.pp("position_embedding").add(&self.position_embedding);

        uvb.to_safetensors()
    }
}

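/// Multi-head self-attention over the patch sequence (no KV grouping: the
/// query, key, and value projections all use `num_attention_heads` heads).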
struct Attention {
    embed_dim: usize,
    num_heads: usize,
    head_dim: usize,
    scale: f32,
    q_proj: Arc<dyn QuantMethod>,
    k_proj: Arc<dyn QuantMethod>,
    v_proj: Arc<dyn QuantMethod>,
    o_proj: Arc<dyn QuantMethod>,
}

impl Attention {
    fn new(config: SiglipVisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        let embed_dim = config.hidden_size;
        let num_heads = config.num_attention_heads;
        let head_dim = embed_dim / num_heads;
        let scale = 1.0 / (head_dim as f32).sqrt();

        let q_proj = mistralrs_quant::linear(embed_dim, embed_dim, &None, vb.pp("q_proj"))?;
        let k_proj = mistralrs_quant::linear(embed_dim, embed_dim, &None, vb.pp("k_proj"))?;
        let v_proj = mistralrs_quant::linear(embed_dim, embed_dim, &None, vb.pp("v_proj"))?;
        let o_proj = mistralrs_quant::linear(embed_dim, embed_dim, &None, vb.pp("out_proj"))?;

        Ok(Self {
            embed_dim,
            num_heads,
            head_dim,
            scale,
            q_proj,
            k_proj,
            v_proj,
            o_proj,
        })
    }

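    // Shape flow: xs (bs, seq, embed_dim) is projected and reshaped to
    // (bs, num_heads, seq, head_dim) for SDPA, then the attention output is
    // transposed back and reshaped to (bs, seq, embed_dim) before out_proj.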
    fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let mut q = self.q_proj.forward(xs)?;
        let mut k = self.k_proj.forward(xs)?;
        let mut v = self.v_proj.forward(xs)?;

        q = q
            .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
            .transpose(1, 2)?
            .contiguous()?;
        k = k
            .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
            .transpose(1, 2)?
            .contiguous()?;
        v = v
            .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
            .transpose(1, 2)?
            .contiguous()?;

        let attn_output = Sdpa.run_attention(
            &q,
            &k,
            &v,
            attention_mask,
            None,
            &SdpaParams {
                n_kv_groups: 1,
                use_flash_attn: false,
                sliding_window: None,
                softcap: None,
                softmax_scale: self.scale,
            },
        )?;

        self.o_proj.forward(&attn_output.transpose(1, 2)?.reshape((
            b_sz,
            q_len,
            self.embed_dim,
        ))?)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("q_proj").add(&self.q_proj);
        uvb.pp("k_proj").add(&self.k_proj);
        uvb.pp("v_proj").add(&self.v_proj);
        uvb.pp("out_proj").add(&self.o_proj);

        uvb.to_safetensors()
    }
}

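/// Position-wise feed-forward block: `fc2(act(fc1(x)))`, where the activation
/// defaults to `gelu_pytorch_tanh`.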
struct VisionMLP {
    activation: Activation,
    fc1: Arc<dyn QuantMethod>,
    fc2: Arc<dyn QuantMethod>,
}

impl VisionMLP {
    fn new(config: SiglipVisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        let fc1 = mistralrs_quant::linear(
            config.hidden_size,
            config.intermediate_size,
            &None,
            vb.pp("fc1"),
        )?;
        let fc2 = mistralrs_quant::linear(
            config.intermediate_size,
            config.hidden_size,
            &None,
            vb.pp("fc2"),
        )?;
        Ok(Self {
            activation: config.hidden_act,
            fc1,
            fc2,
        })
    }

    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let mut x = self.fc1.forward(x)?;
        x = self.activation.forward(&x)?;
        self.fc2.forward(&x)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("fc1").add(&self.fc1);
        uvb.pp("fc2").add(&self.fc2);

        uvb.to_safetensors()
    }
}

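/// Pre-norm transformer block: `x = x + attn(layer_norm1(x))`, then
/// `x = x + mlp(layer_norm2(x))`.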
struct EncoderLayer {
    mlp: VisionMLP,
    attn: Attention,
    layer_norm_1: LayerNorm,
    layer_norm_2: LayerNorm,
}

impl EncoderLayer {
    fn new(config: SiglipVisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        let mlp = VisionMLP::new(config.clone(), vb.pp("mlp"))?;
        let attn = Attention::new(config.clone(), vb.pp("self_attn"))?;
        let layer_norm_1 = layer_norm(
            config.hidden_size,
            config.layer_norm_eps,
            vb.pp("layer_norm1"),
        )?;
        let layer_norm_2 = layer_norm(
            config.hidden_size,
            config.layer_norm_eps,
            vb.pp("layer_norm2"),
        )?;
        Ok(Self {
            mlp,
            attn,
            layer_norm_1,
            layer_norm_2,
        })
    }

    fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
        let residual = xs.clone();

        let hidden_states = self.layer_norm_1.forward(xs)?;
        let hidden_states = self.attn.forward(&hidden_states, attention_mask)?;
        let hidden_states = (hidden_states + residual)?;

        let residual = &hidden_states;
        let hidden_states = self.layer_norm_2.forward(&hidden_states)?;
        let hidden_states = self.mlp.forward(&hidden_states)?;
        hidden_states + residual
    }
}

struct Encoder {
    layers: Vec<EncoderLayer>,
}

impl Encoder {
    fn new(config: &SiglipVisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        let mut layers = Vec::new();
        let vb_l = vb.pp("layers");
        for i in 0..config.num_hidden_layers {
            layers.push(EncoderLayer::new(config.clone(), vb_l.pp(i))?);
        }
        Ok(Self { layers })
    }

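    // Runs the layers in order. A negative `hidden_states_index` requests an
    // intermediate hidden state: with L layers, index `-k` returns the output
    // after `L - k + 1` layers have run (so `-1`, like `0`, yields the final
    // layer's output, and `-2` stops one layer early).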
    fn forward_get_hidden_states(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        hidden_states_index: isize,
    ) -> Result<Tensor> {
        let mut hidden_states = xs.clone();
        for (layer_idx, layer) in self.layers.iter().enumerate() {
            hidden_states = layer.forward(&hidden_states, attention_mask)?;
            if (self.layers.len() as isize + hidden_states_index) as usize == layer_idx {
                return Ok(hidden_states);
            }
        }
        Ok(hidden_states)
    }
}

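/// The full SigLIP vision tower: patch/position embeddings, the transformer
/// encoder, and a final layer norm.
///
/// Illustrative usage sketch (shapes assume the default 224x224 / patch-16
/// config, so the output is `(bs, 196, 768)`):
///
/// ```ignore
/// let tower = SiglipVisionTransformer::new(&config, vb)?;
/// // pixel_values: (bs, 3, 224, 224); no attention mask or target sizes.
/// let last_hidden_state = tower.forward(&pixel_values, None, None)?;
/// ```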
pub struct SiglipVisionTransformer {
    pub(super) embeddings: VisionEmbeddings,
    encoder: Encoder,
    post_layernorm: LayerNorm,
    config: SiglipVisionConfig,
}

impl SiglipVisionTransformer {
    pub fn new(config: &SiglipVisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        let embeddings = VisionEmbeddings::new(config, vb.pp("embeddings"))?;
        let post_layernorm = layer_norm(
            config.hidden_size,
            config.layer_norm_eps,
            vb.pp("post_layernorm"),
        )?;
        let encoder = Encoder::new(config, vb.pp("encoder"))?;
        Ok(Self {
            embeddings,
            encoder,
            post_layernorm,
            config: config.clone(),
        })
    }

    pub fn forward(
        &self,
        pixel_values: &Tensor,
        attention_mask: Option<&Tensor>,
        tgt_sizes: Option<&Tensor>,
    ) -> Result<Tensor> {
        self.forward_get_hidden_states(pixel_values, attention_mask, tgt_sizes, -1)
    }

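    // Like `forward`, but `hidden_states_index` selects which encoder hidden
    // state is returned; the post layer norm is then applied to it. When no
    // attention mask is supplied, every patch is treated as valid (a mask of
    // ones).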
    pub fn forward_get_hidden_states(
        &self,
        pixel_values: &Tensor,
        attention_mask: Option<&Tensor>,
        tgt_sizes: Option<&Tensor>,
        hidden_states_index: isize,
    ) -> Result<Tensor> {
        let bs = pixel_values.dim(0)?;
        let patch_attention_mask = if let Some(attn_mask) = attention_mask {
            attn_mask.clone()
        } else {
            let patch_size = self.config.patch_size;
            Tensor::ones(
                (
                    bs,
                    pixel_values.dim(2)? / patch_size,
                    pixel_values.dim(3)? / patch_size,
                ),
                DType::U8,
                pixel_values.device(),
            )?
        };

        let hidden_states =
            self.embeddings
                .forward(pixel_values, &patch_attention_mask, tgt_sizes)?;

        let attention_mask = if attention_mask.is_none() {
            None
        } else {
            let mask = patch_attention_mask
                .reshape((patch_attention_mask.dim(0)?, ()))?
                .to_dtype(hidden_states.dtype())?;
            Some(CausalMasker.expand_mask(&mask, hidden_states.dtype(), None)?)
        };
        let hidden_states = self.encoder.forward_get_hidden_states(
            &hidden_states,
            attention_mask.as_ref(),
            hidden_states_index + 1,
        )?;
        hidden_states.apply(&self.post_layernorm)
    }

    pub fn dtype(&self) -> DType {
        self.embeddings.patch_embedding.weight().dtype()
    }

    pub fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("post_layernorm").add(&self.post_layernorm);
        uvb.pp("embeddings")
            .extend(self.embeddings.residual_tensors());

        let uvb_enc = uvb.pp("encoder");
        for (i, layer) in self.encoder.layers.iter().enumerate() {
            let uvb_l = uvb_enc.pp("layers").pp(i);

            uvb_l.pp("layer_norm1").add(&layer.layer_norm_1);
            uvb_l.pp("layer_norm2").add(&layer.layer_norm_2);
            uvb_l.pp("mlp").extend(layer.mlp.residual_tensors());
            uvb_l.pp("self_attn").extend(layer.attn.residual_tensors());
        }

        uvb.to_safetensors()
    }
}