mistralrs_core/vision_models/

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use candle_core::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{Conv2d, Conv2dConfig, Embedding, LayerNorm, Module};
use mistralrs_quant::{QuantMethod, ShardedVarBuilder};
use std::{ops::Mul, sync::Arc};

use crate::{
    attention::SdpaParams,
    layers::{conv2d, embedding, layer_norm, Activation, CausalMasker, Sdpa},
    serde_default_fn,
    utils::unvarbuilder::UnVarBuilder,
};

serde_default_fn!(usize, hidden_size, 768);
serde_default_fn!(usize, intermediate_size, 3072);
serde_default_fn!(usize, num_hidden_layers, 12);
serde_default_fn!(usize, num_attention_heads, 12);
serde_default_fn!(usize, num_channels, 3);
serde_default_fn!(usize, image_size, 224);
serde_default_fn!(usize, patch_size, 16);
serde_default_fn!(Activation, hidden_act, Activation::GeluPytorchTanh);
serde_default_fn!(f64, layer_norm_eps, 1e-6);

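/// Configuration for the SigLIP vision tower. The serde defaults mirror the
/// `Default` impl below (768-dim, 12-layer, 12-head encoder over 224px images
/// split into 16px patches).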
#[derive(Debug, Clone, serde::Deserialize)]
pub struct SiglipVisionConfig {
    #[serde(default = "hidden_size")]
    pub hidden_size: usize,
    #[serde(default = "intermediate_size")]
    pub intermediate_size: usize,
    #[serde(default = "num_hidden_layers")]
    pub num_hidden_layers: usize,
    #[serde(default = "num_attention_heads")]
    pub num_attention_heads: usize,
    #[serde(default = "num_channels")]
    pub num_channels: usize,
    #[serde(default = "image_size")]
    pub image_size: usize,
    #[serde(default = "patch_size")]
    pub patch_size: usize,
    #[serde(default = "hidden_act")]
    pub hidden_act: Activation,
    #[serde(default = "layer_norm_eps")]
    pub layer_norm_eps: f64,
}

impl Default for SiglipVisionConfig {
    fn default() -> Self {
        Self {
            hidden_size: 768,
            intermediate_size: 3072,
            num_hidden_layers: 12,
            num_attention_heads: 12,
            num_channels: 3,
            image_size: 224,
            patch_size: 16,
            hidden_act: Activation::GeluPytorchTanh,
            layer_norm_eps: 1e-6,
        }
    }
}

pub(super) struct VisionEmbeddings {
    patch_size: usize,
    patch_embedding: Conv2d,
    num_patches_per_side: usize,
    pub(super) position_embedding: Embedding,
}

/// torch.bucketize with right=True
/// Returns a 1d tensor of shape (xs.len(),) on the given device
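///
/// For example (illustrative values), with `boundaries = [0.25, 0.5, 0.75]`
/// the inputs `[0.1, 0.6, 0.9]` map to bucket indices `[0, 2, 3]`: values
/// below the first boundary land in bucket 0 and values past the last
/// boundary land in the final bucket.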
fn bucketize_right(xs: &[f32], boundaries: &[f32], device: &Device) -> Result<Tensor> {
    use std::cmp::Ordering;

    let mut result = Vec::with_capacity(xs.len());

    for &x in xs {
        // binary_search_by returns:
        //   Ok(i)   if boundaries[i] == x
        //   Err(i)  if x would be inserted at i
        //
        // The returned i is the "insertion point" for x to keep
        // boundaries sorted. That i is the smallest position
        // where boundaries[i] >= x (i.e. bisect_left).

        let idx = match boundaries.binary_search_by(|&val| {
            // Use partial_cmp here; assume no NaNs.
            // For robust handling of NaNs, you might need a custom comparison.
            val.partial_cmp(&x).unwrap_or(Ordering::Less)
        }) {
            Ok(i) => i,
            Err(i) => i,
        };

        result.push(idx as u32);
    }

    Tensor::from_vec(result, (xs.len(),), device)
}

impl VisionEmbeddings {
    fn new(config: &SiglipVisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        let conv_config = Conv2dConfig {
            stride: config.patch_size,
            ..Default::default()
        };
        let patch_embedding = conv2d(
            config.num_channels,
            config.hidden_size,
            config.patch_size,
            conv_config,
            vb.pp("patch_embedding"),
        )?;
        let num_patches_per_side = config.image_size / config.patch_size;
        let num_patches = num_patches_per_side.pow(2);
        Ok(Self {
            patch_size: config.patch_size,
            patch_embedding,
            num_patches_per_side,
            position_embedding: embedding(
                num_patches,
                config.hidden_size,
                vb.pp("position_embedding"),
            )?,
        })
    }

    fn forward(
        &self,
        pixel_values: &Tensor,
        patch_attention_mask: &Tensor,
        tgt_sizes: Option<&Tensor>,
    ) -> Result<Tensor> {
        let (bs, _, max_im_h, max_im_w) = pixel_values.dims4()?;

        let patch_embeds = self.patch_embedding.forward(pixel_values)?;

        let embeddings = patch_embeds.flatten(2, D::Minus1)?.transpose(1, 2)?;

        let (max_nb_patches_h, max_nb_patches_w) =
            (max_im_h / self.patch_size, max_im_w / self.patch_size);
        let boundaries = Tensor::arange_step(
            1.0 / self.num_patches_per_side as f32,
            1.0,
            1.0 / self.num_patches_per_side as f32,
            pixel_values.device(),
        )?
        .to_vec1::<f32>()?;
        let position_ids = Tensor::full(
            0u32,
            (bs, max_nb_patches_h * max_nb_patches_w),
            pixel_values.device(),
        )?;

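        // For each image in the batch, compute a position id for every valid
        // patch on the `num_patches_per_side` x `num_patches_per_side`
        // reference grid: fractional row/column coordinates are bucketized
        // against `boundaries`, combined as
        // `row_bucket * num_patches_per_side + col_bucket`, and written into
        // the zero-initialized `position_ids` buffer at the locations where
        // the patch attention mask is set.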
        let mut new_position_ids = Vec::new();
        for (b_idx, p_attn_mask) in patch_attention_mask.chunk(bs, 0)?.iter().enumerate() {
            let p_attn_mask = p_attn_mask.squeeze(0)?;
            let (nb_patches_h, nb_patches_w) = if let Some(tgt_sizes) = tgt_sizes {
                (tgt_sizes.i((b_idx, 0))?, tgt_sizes.i((b_idx, 1))?)
            } else {
                let nb_patches_h = p_attn_mask.i((.., 0))?.sum_all()?;
                let nb_patches_w = p_attn_mask.i((0,))?.sum_all()?;
                (nb_patches_h, nb_patches_w)
            };

            let fractional_coords_h = Tensor::arange_step(
                0.0,
                1.0 - 1e-6,
                1.0 / nb_patches_h.to_dtype(DType::F32)?.to_scalar::<f32>()?,
                pixel_values.device(),
            )?
            .to_vec1::<f32>()?;
            let fractional_coords_w = Tensor::arange_step(
                0.0,
                1.0 - 1e-6,
                1.0 / nb_patches_w.to_dtype(DType::F32)?.to_scalar::<f32>()?,
                pixel_values.device(),
            )?
            .to_vec1::<f32>()?;

            let bucket_coords_h =
                bucketize_right(&fractional_coords_h, &boundaries, pixel_values.device())?;
            let bucket_coords_w =
                bucketize_right(&fractional_coords_w, &boundaries, pixel_values.device())?;

            let pos_ids = bucket_coords_h
                .unsqueeze(D::Minus1)?
                .mul(self.num_patches_per_side as f64)?
                .broadcast_add(&bucket_coords_w)?
                .flatten_all()?
                .to_vec1::<u32>()?;

            let true_indices = p_attn_mask
                .flatten_all()?
                .to_vec1::<u8>()?
                .iter()
                .enumerate()
                .filter_map(|(i, x)| if *x != 0 { Some(i) } else { None })
                .collect::<Vec<_>>();
            let position_ids_b = position_ids.i(b_idx)?;

            let mut new_position_ids_b = position_ids_b.to_vec1::<u32>()?;
            let new_position_ids_b_len = new_position_ids_b.len();
            for (i, true_idx) in true_indices.into_iter().enumerate() {
                new_position_ids_b[true_idx] = pos_ids[i];
            }

            new_position_ids.push(Tensor::from_vec(
                new_position_ids_b,
                new_position_ids_b_len,
                pixel_values.device(),
            )?);
        }
        let position_ids = Tensor::stack(&new_position_ids, 0)?;
        let position_ids = position_ids.to_device(self.position_embedding.embeddings().device())?;
        embeddings.broadcast_add(&self.position_embedding.forward(&position_ids)?)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("patch_embedding").add(&self.patch_embedding);
        uvb.pp("position_embedding").add(&self.position_embedding);

        uvb.to_safetensors()
    }
}

struct Attention {
    embed_dim: usize,
    num_heads: usize,
    head_dim: usize,
    scale: f32,
    q_proj: Arc<dyn QuantMethod>,
    k_proj: Arc<dyn QuantMethod>,
    v_proj: Arc<dyn QuantMethod>,
    o_proj: Arc<dyn QuantMethod>,
}

impl Attention {
    fn new(config: SiglipVisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        let embed_dim = config.hidden_size;
        let num_heads = config.num_attention_heads;
        let head_dim = embed_dim / num_heads;
        let scale = 1.0 / (head_dim as f32).sqrt();

        let q_proj = mistralrs_quant::linear(embed_dim, embed_dim, &None, vb.pp("q_proj"))?;
        let k_proj = mistralrs_quant::linear(embed_dim, embed_dim, &None, vb.pp("k_proj"))?;
        let v_proj = mistralrs_quant::linear(embed_dim, embed_dim, &None, vb.pp("v_proj"))?;
        let o_proj = mistralrs_quant::linear(embed_dim, embed_dim, &None, vb.pp("out_proj"))?;

        Ok(Self {
            embed_dim,
            num_heads,
            head_dim,
            scale,
            q_proj,
            k_proj,
            v_proj,
            o_proj,
        })
    }

    fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let mut q = self.q_proj.forward(xs)?;
        let mut k = self.k_proj.forward(xs)?;
        let mut v = self.v_proj.forward(xs)?;

        q = q
            .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
            .transpose(1, 2)?
            .contiguous()?;
        k = k
            .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
            .transpose(1, 2)?
            .contiguous()?;
        v = v
            .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
            .transpose(1, 2)?
            .contiguous()?;

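        // Standard multi-head attention: one KV head per query head
        // (`n_kv_groups: 1`), no sliding window or logit softcapping, with
        // scores scaled by 1 / sqrt(head_dim).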
        let attn_weights = Sdpa.run_attention(
            &q,
            &k,
            &v,
            attention_mask,
            None,
            &SdpaParams {
                n_kv_groups: 1,
                use_flash_attn: false,
                sliding_window: None,
                softcap: None,
                softmax_scale: self.scale,
            },
        )?;

        self.o_proj.forward(&attn_weights.transpose(1, 2)?.reshape((
            b_sz,
            q_len,
            self.embed_dim,
        ))?)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("q_proj").add(&self.q_proj);
        uvb.pp("k_proj").add(&self.k_proj);
        uvb.pp("v_proj").add(&self.v_proj);
        uvb.pp("out_proj").add(&self.o_proj);

        uvb.to_safetensors()
    }
}

struct VisionMLP {
    activation: Activation,
    fc1: Arc<dyn QuantMethod>,
    fc2: Arc<dyn QuantMethod>,
}

impl VisionMLP {
    fn new(config: SiglipVisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        let fc1 = mistralrs_quant::linear(
            config.hidden_size,
            config.intermediate_size,
            &None,
            vb.pp("fc1"),
        )?;
        let fc2 = mistralrs_quant::linear(
            config.intermediate_size,
            config.hidden_size,
            &None,
            vb.pp("fc2"),
        )?;
        Ok(Self {
            activation: config.hidden_act,
            fc1,
            fc2,
        })
    }

    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let mut x = self.fc1.forward(x)?;
        x = self.activation.forward(&x)?;
        self.fc2.forward(&x)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("fc1").add(&self.fc1);
        uvb.pp("fc2").add(&self.fc2);

        uvb.to_safetensors()
    }
}

struct EncoderLayer {
    mlp: VisionMLP,
    attn: Attention,
    layer_norm_1: LayerNorm,
    layer_norm_2: LayerNorm,
}

impl EncoderLayer {
    fn new(config: SiglipVisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        let mlp = VisionMLP::new(config.clone(), vb.pp("mlp"))?;
        let attn = Attention::new(config.clone(), vb.pp("self_attn"))?;
        let layer_norm_1 = layer_norm(
            config.hidden_size,
            config.layer_norm_eps,
            vb.pp("layer_norm1"),
        )?;
        let layer_norm_2 = layer_norm(
            config.hidden_size,
            config.layer_norm_eps,
            vb.pp("layer_norm2"),
        )?;
        Ok(Self {
            mlp,
            attn,
            layer_norm_1,
            layer_norm_2,
        })
    }

    fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
        let residual = xs.clone();

        let hidden_states = self.layer_norm_1.forward(xs)?;
        let hidden_states = self.attn.forward(&hidden_states, attention_mask)?;
        let hidden_states = (hidden_states + residual)?;

        let residual = &hidden_states;
        let hidden_states = self.layer_norm_2.forward(&hidden_states)?;
        let hidden_states = self.mlp.forward(&hidden_states)?;
        hidden_states + residual
    }
}

struct Encoder {
    layers: Vec<EncoderLayer>,
}

impl Encoder {
    fn new(config: &SiglipVisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        let mut layers = Vec::new();
        let vb_l = vb.pp("layers");
        for i in 0..config.num_hidden_layers {
            layers.push(EncoderLayer::new(config.clone(), vb_l.pp(i))?);
        }
        Ok(Self { layers })
    }

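    /// Runs the layers in order and returns the hidden states produced by
    /// layer `layers.len() + hidden_states_index`; if that index never
    /// matches (e.g. an offset of 0), the output of the final layer is
    /// returned.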
    fn forward_get_hidden_states(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        hidden_states_index: isize,
    ) -> Result<Tensor> {
        let mut hidden_states = xs.clone();
        for (layer_idx, layer) in self.layers.iter().enumerate() {
            hidden_states = layer.forward(&hidden_states, attention_mask)?;
            if (self.layers.len() as isize + hidden_states_index) as usize == layer_idx {
                return Ok(hidden_states);
            }
        }
        Ok(hidden_states)
    }
}

pub struct SiglipVisionTransformer {
    pub(super) embeddings: VisionEmbeddings,
    encoder: Encoder,
    post_layernorm: LayerNorm,
    config: SiglipVisionConfig,
}

impl SiglipVisionTransformer {
    pub fn new(config: &SiglipVisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        let embeddings = VisionEmbeddings::new(config, vb.pp("embeddings"))?;
        let post_layernorm = layer_norm(
            config.hidden_size,
            config.layer_norm_eps,
            vb.pp("post_layernorm"),
        )?;
        let encoder = Encoder::new(config, vb.pp("encoder"))?;
        Ok(Self {
            embeddings,
            encoder,
            post_layernorm,
            config: config.clone(),
        })
    }

    pub fn forward(
        &self,
        pixel_values: &Tensor,
        attention_mask: Option<&Tensor>,
        tgt_sizes: Option<&Tensor>,
    ) -> Result<Tensor> {
        self.forward_get_hidden_states(pixel_values, attention_mask, tgt_sizes, -1)
    }

    pub fn forward_get_hidden_states(
        &self,
        pixel_values: &Tensor,
        attention_mask: Option<&Tensor>,
        tgt_sizes: Option<&Tensor>,
        hidden_states_index: isize,
    ) -> Result<Tensor> {
        let bs = pixel_values.dim(0)?;
        let patch_attention_mask = if let Some(attn_mask) = attention_mask {
            attn_mask.clone()
        } else {
            let patch_size = self.config.patch_size;
            Tensor::ones(
                (
                    bs,
                    pixel_values.dim(2)? / patch_size,
                    pixel_values.dim(3)? / patch_size,
                ),
                DType::U8,
                pixel_values.device(),
            )?
        };

        let hidden_states =
            self.embeddings
                .forward(pixel_values, &patch_attention_mask, tgt_sizes)?;

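        // If a mask was supplied, flatten the per-patch mask and expand it
        // into the attention mask consumed by the encoder layers; otherwise
        // every patch attends to every other patch.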
        let attention_mask = if attention_mask.is_none() {
            None
        } else {
            let mask = patch_attention_mask
                .reshape((patch_attention_mask.dim(0)?, ()))?
                .to_dtype(hidden_states.dtype())?;
            Some(CausalMasker.expand_mask(&mask, hidden_states.dtype(), None)?)
        };
        let hidden_states = self.encoder.forward_get_hidden_states(
            &hidden_states,
            attention_mask.as_ref(),
            hidden_states_index + 1,
        )?;
        hidden_states.apply(&self.post_layernorm)
    }

    pub fn dtype(&self) -> DType {
        self.embeddings.patch_embedding.weight().dtype()
    }

    pub fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("post_layernorm").add(&self.post_layernorm);
        uvb.pp("embeddings")
            .extend(self.embeddings.residual_tensors());

        let uvb_enc = uvb.pp("encoder");
        for (i, layer) in self.encoder.layers.iter().enumerate() {
            let uvb_l = uvb_enc.pp("layers").pp(i);

            uvb_l.pp("layer_norm1").add(&layer.layer_norm_1);
            uvb_l.pp("layer_norm2").add(&layer.layer_norm_2);
            uvb_l.pp("mlp").extend(layer.mlp.residual_tensors());
            uvb_l.pp("self_attn").extend(layer.attn.residual_tensors());
        }

        uvb.to_safetensors()
    }
}
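
// Minimal sanity checks (an illustrative sketch, not part of the upstream test
// suite): the expected values below were worked out by hand from the bisection
// logic in `bucketize_right` and from the serde defaults above.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn bucketize_right_places_values_between_boundaries() -> Result<()> {
        // 0.1 falls before the first boundary, 0.6 between the second and
        // third, and 0.9 past the last boundary.
        let boundaries = [0.25f32, 0.5, 0.75];
        let xs = [0.1f32, 0.6, 0.9];
        let buckets = bucketize_right(&xs, &boundaries, &Device::Cpu)?.to_vec1::<u32>()?;
        assert_eq!(buckets, vec![0, 2, 3]);
        Ok(())
    }

    #[test]
    fn config_defaults_match_serde_defaults() {
        let cfg = SiglipVisionConfig::default();
        assert_eq!(cfg.hidden_size, 768);
        assert_eq!(cfg.num_hidden_layers, 12);
        // 224px images with 16px patches give a 14x14 patch grid.
        assert_eq!(cfg.image_size / cfg.patch_size, 14);
    }
}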