mistralrs_core/vision_models/siglip.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
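//! SigLIP vision transformer: patch and position embeddings, a pre-norm
//! transformer encoder, and a final layer norm, with weights loaded through
//! `ShardedVarBuilder`.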

use candle_core::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{Conv2d, Conv2dConfig, Embedding, LayerNorm, Module};
use mistralrs_quant::{QuantMethod, ShardedVarBuilder};
use std::{ops::Mul, sync::Arc};

use crate::{
    attention::SdpaParams,
    layers::{conv2d, embedding, layer_norm, Activation, CausalMasker, Sdpa},
    serde_default_fn,
    utils::unvarbuilder::UnVarBuilder,
};

serde_default_fn!(usize, hidden_size, 768);
serde_default_fn!(usize, intermediate_size, 3072);
serde_default_fn!(usize, num_hidden_layers, 12);
serde_default_fn!(usize, num_attention_heads, 12);
serde_default_fn!(usize, num_channels, 3);
serde_default_fn!(usize, image_size, 224);
serde_default_fn!(usize, patch_size, 16);
serde_default_fn!(Activation, hidden_act, Activation::GeluPytorchTanh);
serde_default_fn!(f64, layer_norm_eps, 1e-6);

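/// Vision configuration for the SigLIP tower.
///
/// The serde defaults above (and `Default`) describe a ViT-Base style model:
/// 12 layers, 12 heads, hidden size 768, 224x224 images split into 16x16
/// patches, with the `gelu_pytorch_tanh` activation.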
#[derive(Debug, Clone, serde::Deserialize)]
pub struct SiglipVisionConfig {
    #[serde(default = "hidden_size")]
    pub hidden_size: usize,
    #[serde(default = "intermediate_size")]
    pub intermediate_size: usize,
    #[serde(default = "num_hidden_layers")]
    pub num_hidden_layers: usize,
    #[serde(default = "num_attention_heads")]
    pub num_attention_heads: usize,
    #[serde(default = "num_channels")]
    pub num_channels: usize,
    #[serde(default = "image_size")]
    pub image_size: usize,
    #[serde(default = "patch_size")]
    pub patch_size: usize,
    #[serde(default = "hidden_act")]
    pub hidden_act: Activation,
    #[serde(default = "layer_norm_eps")]
    pub layer_norm_eps: f64,
}

impl Default for SiglipVisionConfig {
    fn default() -> Self {
        Self {
            hidden_size: 768,
            intermediate_size: 3072,
            num_hidden_layers: 12,
            num_attention_heads: 12,
            num_channels: 3,
            image_size: 224,
            patch_size: 16,
            hidden_act: Activation::GeluPytorchTanh,
            layer_norm_eps: 1e-6,
        }
    }
}

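/// Patch and position embeddings for the vision tower.
///
/// Patches are embedded with a strided `Conv2d`; position ids are resolved
/// against a fixed grid of `num_patches_per_side^2` learned embeddings, so
/// images with fewer valid patches are handled by bucketizing fractional
/// patch coordinates onto that grid (see `forward`).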
pub(super) struct VisionEmbeddings {
    patch_size: usize,
    patch_embedding: Conv2d,
    num_patches_per_side: usize,
    pub(super) position_embedding: Embedding,
}

/// Analogue of `torch.bucketize` with `right=True`; an exact boundary hit
/// resolves to the matching index (bisect_left behaviour), which only
/// matters when a coordinate lands exactly on a boundary.
/// Returns a 1d u32 tensor of shape (xs.len(),) on the given device.
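///
/// For example (illustrative): with boundaries `[0.25, 0.5, 0.75]`, the
/// inputs `[0.0, 0.25, 0.6]` map to indices `[0, 0, 2]`.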
fn bucketize_right(xs: &[f32], boundaries: &[f32], device: &Device) -> Result<Tensor> {
    use std::cmp::Ordering;

    let mut result = Vec::with_capacity(xs.len());

    for &x in xs {
        // binary_search_by returns:
        //   Ok(i)   if boundaries[i] == x
        //   Err(i)  if x would be inserted at i
        //
        // The returned i is the "insertion point" for x to keep
        // boundaries sorted. That i is the smallest position
        // where boundaries[i] >= x (i.e. bisect_left).

        let idx = match boundaries.binary_search_by(|&val| {
            // Use partial_cmp here; assume no NaNs.
            // For robust handling of NaNs, you might need a custom comparison.
            val.partial_cmp(&x).unwrap_or(Ordering::Less)
        }) {
            Ok(i) => i,
            Err(i) => i,
        };

        result.push(idx as u32);
    }

    Tensor::from_vec(result, (xs.len(),), device)
}

impl VisionEmbeddings {
    fn new(config: &SiglipVisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        let conv_config = Conv2dConfig {
            stride: config.patch_size,
            ..Default::default()
        };
        let patch_embedding = conv2d(
            config.num_channels,
            config.hidden_size,
            config.patch_size,
            conv_config,
            vb.pp("patch_embedding"),
        )?;
        let num_patches_per_side = config.image_size / config.patch_size;
        let num_patches = num_patches_per_side.pow(2);
        Ok(Self {
            patch_size: config.patch_size,
            patch_embedding,
            num_patches_per_side,
            position_embedding: embedding(
                num_patches,
                config.hidden_size,
                vb.pp("position_embedding"),
                &None,
            )?,
        })
    }

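    /// Embed `pixel_values` of shape (batch, channels, height, width) and add
    /// position embeddings.
    ///
    /// Per image, fractional patch coordinates are bucketized onto the
    /// pretrained `num_patches_per_side` grid and written only at positions
    /// marked valid in `patch_attention_mask`; if `tgt_sizes` is given it
    /// supplies the per-image patch counts instead of the mask sums.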
    fn forward(
        &self,
        pixel_values: &Tensor,
        patch_attention_mask: &Tensor,
        tgt_sizes: Option<&Tensor>,
    ) -> Result<Tensor> {
        let (bs, _, max_im_h, max_im_w) = pixel_values.dims4()?;

        let patch_embeds = self.patch_embedding.forward(pixel_values)?;

        let embeddings = patch_embeds.flatten(2, D::Minus1)?.transpose(1, 2)?;

        let (max_nb_patches_h, max_nb_patches_w) =
            (max_im_h / self.patch_size, max_im_w / self.patch_size);
        let boundaries = Tensor::arange_step(
            1.0 / self.num_patches_per_side as f32,
            1.0,
            1.0 / self.num_patches_per_side as f32,
            pixel_values.device(),
        )?
        .to_vec1::<f32>()?;
        let position_ids = Tensor::full(
            0u32,
            (bs, max_nb_patches_h * max_nb_patches_w),
            pixel_values.device(),
        )?;

        let mut new_position_ids = Vec::new();
        for (b_idx, p_attn_mask) in patch_attention_mask.chunk(bs, 0)?.iter().enumerate() {
            let p_attn_mask = p_attn_mask.squeeze(0)?;
            let (nb_patches_h, nb_patches_w) = if let Some(tgt_sizes) = tgt_sizes {
                (tgt_sizes.i((b_idx, 0))?, tgt_sizes.i((b_idx, 1))?)
            } else {
                let nb_patches_h = p_attn_mask.i((.., 0))?.sum_all()?;
                let nb_patches_w = p_attn_mask.i((0,))?.sum_all()?;
                (nb_patches_h, nb_patches_w)
            };

            let fractional_coords_h = Tensor::arange_step(
                0.0,
                1.0 - 1e-6,
                1.0 / nb_patches_h.to_dtype(DType::F32)?.to_scalar::<f32>()?,
                pixel_values.device(),
            )?
            .to_vec1::<f32>()?;
            let fractional_coords_w = Tensor::arange_step(
                0.0,
                1.0 - 1e-6,
                1.0 / nb_patches_w.to_dtype(DType::F32)?.to_scalar::<f32>()?,
                pixel_values.device(),
            )?
            .to_vec1::<f32>()?;

            let bucket_coords_h =
                bucketize_right(&fractional_coords_h, &boundaries, pixel_values.device())?;
            let bucket_coords_w =
                bucketize_right(&fractional_coords_w, &boundaries, pixel_values.device())?;

            let pos_ids = bucket_coords_h
                .unsqueeze(D::Minus1)?
                .mul(self.num_patches_per_side as f64)?
                .broadcast_add(&bucket_coords_w)?
                .flatten_all()?
                .to_vec1::<u32>()?;

            let true_indices = p_attn_mask
                .flatten_all()?
                .to_vec1::<u8>()?
                .iter()
                .enumerate()
                .filter_map(|(i, x)| if *x != 0 { Some(i) } else { None })
                .collect::<Vec<_>>();
            let position_ids_b = position_ids.i(b_idx)?;

            let mut new_position_ids_b = position_ids_b.to_vec1::<u32>()?;
            let new_position_ids_b_len = new_position_ids_b.len();
            for (i, true_idx) in true_indices.into_iter().enumerate() {
                new_position_ids_b[true_idx] = pos_ids[i];
            }

            new_position_ids.push(Tensor::from_vec(
                new_position_ids_b,
                new_position_ids_b_len,
                pixel_values.device(),
            )?);
        }
        let position_ids = Tensor::stack(&new_position_ids, 0)?;
        let position_ids = position_ids.to_device(self.position_embedding.embeddings().device())?;
        embeddings.broadcast_add(&self.position_embedding.forward(&position_ids)?)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("patch_embedding").add(&self.patch_embedding);
        uvb.pp("position_embedding").add(&self.position_embedding);

        uvb.to_safetensors()
    }
}

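/// Multi-head self-attention over the patch sequence.
///
/// Separate q/k/v/out projections of size `hidden_size`, evaluated through
/// the shared `Sdpa` helper with scale `1 / sqrt(head_dim)` and no KV
/// grouping, sliding window, or softcapping.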
struct Attention {
    embed_dim: usize,
    num_heads: usize,
    head_dim: usize,
    scale: f32,
    q_proj: Arc<dyn QuantMethod>,
    k_proj: Arc<dyn QuantMethod>,
    v_proj: Arc<dyn QuantMethod>,
    o_proj: Arc<dyn QuantMethod>,
}

impl Attention {
    fn new(config: SiglipVisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        let embed_dim = config.hidden_size;
        let num_heads = config.num_attention_heads;
        let head_dim = embed_dim / num_heads;
        let scale = 1.0 / (head_dim as f32).sqrt();

        let q_proj = mistralrs_quant::linear(embed_dim, embed_dim, &None, vb.pp("q_proj"))?;
        let k_proj = mistralrs_quant::linear(embed_dim, embed_dim, &None, vb.pp("k_proj"))?;
        let v_proj = mistralrs_quant::linear(embed_dim, embed_dim, &None, vb.pp("v_proj"))?;
        let o_proj = mistralrs_quant::linear(embed_dim, embed_dim, &None, vb.pp("out_proj"))?;

        Ok(Self {
            embed_dim,
            num_heads,
            head_dim,
            scale,
            q_proj,
            k_proj,
            v_proj,
            o_proj,
        })
    }

    fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let mut q = self.q_proj.forward(xs)?;
        let mut k = self.k_proj.forward(xs)?;
        let mut v = self.v_proj.forward(xs)?;

        q = q
            .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
            .transpose(1, 2)?
            .contiguous()?;
        k = k
            .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
            .transpose(1, 2)?
            .contiguous()?;
        v = v
            .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
            .transpose(1, 2)?
            .contiguous()?;

        let attn_weights = Sdpa.run_attention(
            &q,
            &k,
            &v,
            attention_mask,
            None,
            &SdpaParams {
                n_kv_groups: 1,
                use_flash_attn: false,
                sliding_window: None,
                softcap: None,
                softmax_scale: self.scale,
            },
        )?;

        self.o_proj.forward(&attn_weights.transpose(1, 2)?.reshape((
            b_sz,
            q_len,
            self.embed_dim,
        ))?)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("q_proj").add(&self.q_proj);
        uvb.pp("k_proj").add(&self.k_proj);
        uvb.pp("v_proj").add(&self.v_proj);
        uvb.pp("out_proj").add(&self.o_proj);

        uvb.to_safetensors()
    }
}

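/// Two-layer feed-forward block: `fc1` -> hidden activation -> `fc2`.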
struct VisionMLP {
    activation: Activation,
    fc1: Arc<dyn QuantMethod>,
    fc2: Arc<dyn QuantMethod>,
}

impl VisionMLP {
    fn new(config: SiglipVisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        let fc1 = mistralrs_quant::linear(
            config.hidden_size,
            config.intermediate_size,
            &None,
            vb.pp("fc1"),
        )?;
        let fc2 = mistralrs_quant::linear(
            config.intermediate_size,
            config.hidden_size,
            &None,
            vb.pp("fc2"),
        )?;
        Ok(Self {
            activation: config.hidden_act,
            fc1,
            fc2,
        })
    }

    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let mut x = self.fc1.forward(x)?;
        x = self.activation.forward(&x)?;
        self.fc2.forward(&x)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("fc1").add(&self.fc1);
        uvb.pp("fc2").add(&self.fc2);

        uvb.to_safetensors()
    }
}

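/// Pre-norm transformer block: `layer_norm1` -> attention -> residual add,
/// followed by `layer_norm2` -> MLP -> residual add.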
struct EncoderLayer {
    mlp: VisionMLP,
    attn: Attention,
    layer_norm_1: LayerNorm,
    layer_norm_2: LayerNorm,
}

impl EncoderLayer {
    fn new(config: SiglipVisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        let mlp = VisionMLP::new(config.clone(), vb.pp("mlp"))?;
        let attn = Attention::new(config.clone(), vb.pp("self_attn"))?;
        let layer_norm_1 = layer_norm(
            config.hidden_size,
            config.layer_norm_eps,
            vb.pp("layer_norm1"),
        )?;
        let layer_norm_2 = layer_norm(
            config.hidden_size,
            config.layer_norm_eps,
            vb.pp("layer_norm2"),
        )?;
        Ok(Self {
            mlp,
            attn,
            layer_norm_1,
            layer_norm_2,
        })
    }

    fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
        let residual = xs.clone();

        let hidden_states = self.layer_norm_1.forward(xs)?;
        let hidden_states = self.attn.forward(&hidden_states, attention_mask)?;
        let hidden_states = (hidden_states + residual)?;

        let residual = &hidden_states;
        let hidden_states = self.layer_norm_2.forward(&hidden_states)?;
        let hidden_states = self.mlp.forward(&hidden_states)?;
        hidden_states + residual
    }
}

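/// Stack of `num_hidden_layers` encoder layers.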
struct Encoder {
    layers: Vec<EncoderLayer>,
}

impl Encoder {
    fn new(config: &SiglipVisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        let mut layers = Vec::new();
        let vb_l = vb.pp("layers");
        for i in 0..config.num_hidden_layers {
            layers.push(EncoderLayer::new(config.clone(), vb_l.pp(i))?);
        }
        Ok(Self { layers })
    }

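    /// Run the layer stack, returning early once the layer at index
    /// `layers.len() + hidden_states_index` has been applied; otherwise the
    /// output of the final layer is returned.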
    fn forward_get_hidden_states(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        hidden_states_index: isize,
    ) -> Result<Tensor> {
        let mut hidden_states = xs.clone();
        for (layer_idx, layer) in self.layers.iter().enumerate() {
            hidden_states = layer.forward(&hidden_states, attention_mask)?;
            if (self.layers.len() as isize + hidden_states_index) as usize == layer_idx {
                return Ok(hidden_states);
            }
        }
        Ok(hidden_states)
    }
}

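/// The SigLIP vision tower: embeddings, encoder, and a post layer norm.
///
/// A minimal usage sketch (illustrative only; the `vb` prefix, device, and
/// `pixel_values` are assumptions of the example, not part of this file):
///
/// ```ignore
/// let cfg = SiglipVisionConfig::default();
/// let tower = SiglipVisionTransformer::new(&cfg, vb.pp("vision_model"))?;
/// // pixel_values: (batch, channels, height, width)
/// let feats = tower.forward(&pixel_values, None, None)?;
/// // feats: (batch, num_patches, hidden_size)
/// ```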
pub struct SiglipVisionTransformer {
    pub(super) embeddings: VisionEmbeddings,
    encoder: Encoder,
    post_layernorm: LayerNorm,
    config: SiglipVisionConfig,
}

impl SiglipVisionTransformer {
    pub fn new(config: &SiglipVisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        let embeddings = VisionEmbeddings::new(config, vb.pp("embeddings"))?;
        let post_layernorm = layer_norm(
            config.hidden_size,
            config.layer_norm_eps,
            vb.pp("post_layernorm"),
        )?;
        let encoder = Encoder::new(config, vb.pp("encoder"))?;
        Ok(Self {
            embeddings,
            encoder,
            post_layernorm,
            config: config.clone(),
        })
    }

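    /// Forward pass returning the final hidden states; equivalent to
    /// `forward_get_hidden_states` with `hidden_states_index = -1`.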
    pub fn forward(
        &self,
        pixel_values: &Tensor,
        attention_mask: Option<&Tensor>,
        tgt_sizes: Option<&Tensor>,
    ) -> Result<Tensor> {
        self.forward_get_hidden_states(pixel_values, attention_mask, tgt_sizes, -1)
    }

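    /// Forward pass that can surface an intermediate encoder output.
    ///
    /// When `attention_mask` is `None`, every patch is treated as valid and
    /// no mask is passed to the encoder. `hidden_states_index` is forwarded
    /// (shifted by one) to `Encoder::forward_get_hidden_states`, with `-1`
    /// selecting the final layer; the post layer norm is applied to whatever
    /// hidden states come back.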
    pub fn forward_get_hidden_states(
        &self,
        pixel_values: &Tensor,
        attention_mask: Option<&Tensor>,
        tgt_sizes: Option<&Tensor>,
        hidden_states_index: isize,
    ) -> Result<Tensor> {
        let bs = pixel_values.dim(0)?;
        let patch_attention_mask = if let Some(attn_mask) = attention_mask {
            attn_mask.clone()
        } else {
            let patch_size = self.config.patch_size;
            Tensor::ones(
                (
                    bs,
                    pixel_values.dim(2)? / patch_size,
                    pixel_values.dim(3)? / patch_size,
                ),
                DType::U8,
                pixel_values.device(),
            )?
        };

        let hidden_states =
            self.embeddings
                .forward(pixel_values, &patch_attention_mask, tgt_sizes)?;

        let attention_mask = if attention_mask.is_none() {
            None
        } else {
            let mask = patch_attention_mask
                .reshape((patch_attention_mask.dim(0)?, ()))?
                .to_dtype(hidden_states.dtype())?;
            Some(CausalMasker.expand_mask(&mask, hidden_states.dtype(), None)?)
        };
        let hidden_states = self.encoder.forward_get_hidden_states(
            &hidden_states,
            attention_mask.as_ref(),
            hidden_states_index + 1,
        )?;
        hidden_states.apply(&self.post_layernorm)
    }

    pub fn dtype(&self) -> DType {
        self.embeddings.patch_embedding.weight().dtype()
    }

    pub fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("post_layernorm").add(&self.post_layernorm);
        uvb.pp("embeddings")
            .extend(self.embeddings.residual_tensors());

        let uvb_enc = uvb.pp("encoder");
        for (i, layer) in self.encoder.layers.iter().enumerate() {
            let uvb_l = uvb_enc.pp("layers").pp(i);

            uvb_l.pp("layer_norm1").add(&layer.layer_norm_1);
            uvb_l.pp("layer_norm2").add(&layer.layer_norm_2);
            uvb_l.pp("mlp").extend(layer.mlp.residual_tensors());
            uvb_l.pp("self_attn").extend(layer.attn.residual_tensors());
        }

        uvb.to_safetensors()
    }
}