mistralrs_core/vision_models/siglip.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use candle_core::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{Conv2d, Conv2dConfig, Embedding, LayerNorm, Module};
use mistralrs_quant::{QuantMethod, ShardedVarBuilder};
use std::{ops::Mul, sync::Arc};

use crate::{
    attention::SdpaParams,
    layers::{conv2d, embedding, layer_norm, Activation, CausalMasker, Sdpa},
    serde_default_fn,
    utils::unvarbuilder::UnVarBuilder,
};

serde_default_fn!(usize, hidden_size, 768);
serde_default_fn!(usize, intermediate_size, 3072);
serde_default_fn!(usize, num_hidden_layers, 12);
serde_default_fn!(usize, num_attention_heads, 12);
serde_default_fn!(usize, num_channels, 3);
serde_default_fn!(usize, image_size, 224);
serde_default_fn!(usize, patch_size, 16);
serde_default_fn!(Activation, hidden_act, Activation::GeluPytorchTanh);
serde_default_fn!(f64, layer_norm_eps, 1e-6);

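/// Configuration for the SigLIP vision tower. The `serde` defaults above and the
/// `Default` impl below both describe the base SigLIP ViT configuration
/// (hidden size 768, 12 layers, 12 heads, 224px images, 16px patches).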
#[derive(Debug, Clone, serde::Deserialize)]
pub struct SiglipVisionConfig {
    #[serde(default = "hidden_size")]
    pub hidden_size: usize,
    #[serde(default = "intermediate_size")]
    pub intermediate_size: usize,
    #[serde(default = "num_hidden_layers")]
    pub num_hidden_layers: usize,
    #[serde(default = "num_attention_heads")]
    pub num_attention_heads: usize,
    #[serde(default = "num_channels")]
    pub num_channels: usize,
    #[serde(default = "image_size")]
    pub image_size: usize,
    #[serde(default = "patch_size")]
    pub patch_size: usize,
    #[serde(default = "hidden_act")]
    pub hidden_act: Activation,
    #[serde(default = "layer_norm_eps")]
    pub layer_norm_eps: f64,
}

impl Default for SiglipVisionConfig {
    fn default() -> Self {
        Self {
            hidden_size: 768,
            intermediate_size: 3072,
            num_hidden_layers: 12,
            num_attention_heads: 12,
            num_channels: 3,
            image_size: 224,
            patch_size: 16,
            hidden_act: Activation::GeluPytorchTanh,
            layer_norm_eps: 1e-6,
        }
    }
}

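/// Patch and position embeddings for the vision tower. Pixel values are projected
/// to `hidden_size` with a strided `Conv2d`, and per-image position ids are derived
/// by bucketizing fractional patch coordinates (see `forward`).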
pub(super) struct VisionEmbeddings {
    patch_size: usize,
    patch_embedding: Conv2d,
    num_patches_per_side: usize,
    pub(super) position_embedding: Embedding,
}

/// torch.bucketize with right=True
/// Returns a 1d tensor of shape (xs.len(),) on `device`
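///
/// Illustrative example (hypothetical values, not taken from a test):
/// `xs = [0.1, 0.6]` with `boundaries = [0.25, 0.5, 0.75]` yields `[0, 2]`.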
fn bucketize_right(xs: &[f32], boundaries: &[f32], device: &Device) -> Result<Tensor> {
    use std::cmp::Ordering;

    let mut result = Vec::with_capacity(xs.len());

    for &x in xs {
        // binary_search_by returns:
        //   Ok(i)   if boundaries[i] == x
        //   Err(i)  if x would be inserted at i
        //
        // The returned i is the "insertion point" for x to keep
        // boundaries sorted. That i is the smallest position
        // where boundaries[i] >= x (i.e. bisect_left).

        let idx = match boundaries.binary_search_by(|&val| {
            // Use partial_cmp here; assume no NaNs.
            // For robust handling of NaNs, you might need a custom comparison.
            val.partial_cmp(&x).unwrap_or(Ordering::Less)
        }) {
            Ok(i) => i,
            Err(i) => i,
        };

        result.push(idx as u32);
    }

    Tensor::from_vec(result, (xs.len(),), device)
}

impl VisionEmbeddings {
    fn new(config: &SiglipVisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        let conv_config = Conv2dConfig {
            stride: config.patch_size,
            ..Default::default()
        };
        let patch_embedding = conv2d(
            config.num_channels,
            config.hidden_size,
            config.patch_size,
            conv_config,
            vb.pp("patch_embedding"),
        )?;
        let num_patches_per_side = config.image_size / config.patch_size;
        let num_patches = num_patches_per_side.pow(2);
        Ok(Self {
            patch_size: config.patch_size,
            patch_embedding,
            num_patches_per_side,
            position_embedding: embedding(
                num_patches,
                config.hidden_size,
                vb.pp("position_embedding"),
                &None,
            )?,
        })
    }

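    /// Embed `pixel_values` with the patch convolution and add position embeddings.
    ///
    /// For each image, fractional patch coordinates in `[0, 1)` along each axis are
    /// bucketized against a fixed grid of `num_patches_per_side` boundaries; the
    /// resulting row/column buckets are combined into flat position ids, which are
    /// written at the positions marked valid by `patch_attention_mask`. The per-image
    /// patch counts come from `tgt_sizes` when provided, otherwise from the mask.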
    fn forward(
        &self,
        pixel_values: &Tensor,
        patch_attention_mask: &Tensor,
        tgt_sizes: Option<&Tensor>,
    ) -> Result<Tensor> {
        let (bs, _, max_im_h, max_im_w) = pixel_values.dims4()?;

        let patch_embeds = self.patch_embedding.forward(pixel_values)?;

        let embeddings = patch_embeds.flatten(2, D::Minus1)?.transpose(1, 2)?;

        let (max_nb_patches_h, max_nb_patches_w) =
            (max_im_h / self.patch_size, max_im_w / self.patch_size);
        let boundaries = Tensor::arange_step(
            1.0 / self.num_patches_per_side as f32,
            1.0,
            1.0 / self.num_patches_per_side as f32,
            pixel_values.device(),
        )?
        .to_vec1::<f32>()?;
        let position_ids = Tensor::full(
            0u32,
            (bs, max_nb_patches_h * max_nb_patches_w),
            pixel_values.device(),
        )?;

        let mut new_position_ids = Vec::new();
        for (b_idx, p_attn_mask) in patch_attention_mask.chunk(bs, 0)?.iter().enumerate() {
            let p_attn_mask = p_attn_mask.squeeze(0)?;
            let (nb_patches_h, nb_patches_w) = if let Some(tgt_sizes) = tgt_sizes {
                (tgt_sizes.i((b_idx, 0))?, tgt_sizes.i((b_idx, 1))?)
            } else {
                let nb_patches_h = p_attn_mask.i((.., 0))?.sum_all()?;
                let nb_patches_w = p_attn_mask.i((0,))?.sum_all()?;
                (nb_patches_h, nb_patches_w)
            };

            let fractional_coords_h = Tensor::arange_step(
                0.0,
                1.0 - 1e-6,
                1.0 / nb_patches_h.to_dtype(DType::F32)?.to_scalar::<f32>()?,
                pixel_values.device(),
            )?
            .to_vec1::<f32>()?;
            let fractional_coords_w = Tensor::arange_step(
                0.0,
                1.0 - 1e-6,
                1.0 / nb_patches_w.to_dtype(DType::F32)?.to_scalar::<f32>()?,
                pixel_values.device(),
            )?
            .to_vec1::<f32>()?;

            let bucket_coords_h =
                bucketize_right(&fractional_coords_h, &boundaries, pixel_values.device())?;
            let bucket_coords_w =
                bucketize_right(&fractional_coords_w, &boundaries, pixel_values.device())?;

            let pos_ids = bucket_coords_h
                .unsqueeze(D::Minus1)?
                .mul(self.num_patches_per_side as f64)?
                .broadcast_add(&bucket_coords_w)?
                .flatten_all()?
                .to_vec1::<u32>()?;

            let true_indices = p_attn_mask
                .flatten_all()?
                .to_vec1::<u8>()?
                .iter()
                .enumerate()
                .filter_map(|(i, x)| if *x != 0 { Some(i) } else { None })
                .collect::<Vec<_>>();
            let position_ids_b = position_ids.i(b_idx)?;

            let mut new_position_ids_b = position_ids_b.to_vec1::<u32>()?;
            let new_position_ids_b_len = new_position_ids_b.len();
            for (i, true_idx) in true_indices.into_iter().enumerate() {
                new_position_ids_b[true_idx] = pos_ids[i];
            }

            new_position_ids.push(Tensor::from_vec(
                new_position_ids_b,
                new_position_ids_b_len,
                pixel_values.device(),
            )?);
        }
        let position_ids = Tensor::stack(&new_position_ids, 0)?;
        let position_ids = position_ids.to_device(self.position_embedding.embeddings().device())?;
        embeddings.broadcast_add(&self.position_embedding.forward(&position_ids)?)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("patch_embedding").add(&self.patch_embedding);
        uvb.pp("position_embedding").add(&self.position_embedding);

        uvb.to_safetensors()
    }
}

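/// Multi-head self-attention over the patch sequence. All four projections are
/// `embed_dim x embed_dim`; attention runs through the shared `Sdpa` helper with
/// no KV grouping, no sliding window, and scale `1 / sqrt(head_dim)`.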
struct Attention {
    embed_dim: usize,
    num_heads: usize,
    head_dim: usize,
    scale: f32,
    q_proj: Arc<dyn QuantMethod>,
    k_proj: Arc<dyn QuantMethod>,
    v_proj: Arc<dyn QuantMethod>,
    o_proj: Arc<dyn QuantMethod>,
}

impl Attention {
    fn new(config: SiglipVisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        let embed_dim = config.hidden_size;
        let num_heads = config.num_attention_heads;
        let head_dim = embed_dim / num_heads;
        let scale = 1.0 / (head_dim as f32).sqrt();

        let q_proj = mistralrs_quant::linear(embed_dim, embed_dim, &None, vb.pp("q_proj"))?;
        let k_proj = mistralrs_quant::linear(embed_dim, embed_dim, &None, vb.pp("k_proj"))?;
        let v_proj = mistralrs_quant::linear(embed_dim, embed_dim, &None, vb.pp("v_proj"))?;
        let o_proj = mistralrs_quant::linear(embed_dim, embed_dim, &None, vb.pp("out_proj"))?;

        Ok(Self {
            embed_dim,
            num_heads,
            head_dim,
            scale,
            q_proj,
            k_proj,
            v_proj,
            o_proj,
        })
    }

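    // Project to Q/K/V, split into heads of shape (b, num_heads, seq, head_dim),
    // run scaled dot-product attention, then merge heads and apply `o_proj`.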
    fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let mut q = self.q_proj.forward(xs)?;
        let mut k = self.k_proj.forward(xs)?;
        let mut v = self.v_proj.forward(xs)?;

        q = q
            .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;
        k = k
            .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;
        v = v
            .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;

        let attn_weights = Sdpa.run_attention(
            &q,
            &k,
            &v,
            attention_mask,
            None,
            &SdpaParams {
                n_kv_groups: 1,
                sliding_window: None,
                softcap: None,
                softmax_scale: self.scale,
            },
        )?;

        self.o_proj.forward(&attn_weights.transpose(1, 2)?.reshape((
            b_sz,
            q_len,
            self.embed_dim,
        ))?)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("q_proj").add(&self.q_proj);
        uvb.pp("k_proj").add(&self.k_proj);
        uvb.pp("v_proj").add(&self.v_proj);
        uvb.pp("out_proj").add(&self.o_proj);

        uvb.to_safetensors()
    }
}

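/// Feed-forward block: `hidden_size -> intermediate_size -> hidden_size`, with
/// the configured activation (GELU-tanh by default) between the two projections.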
struct VisionMLP {
    activation: Activation,
    fc1: Arc<dyn QuantMethod>,
    fc2: Arc<dyn QuantMethod>,
}

impl VisionMLP {
    fn new(config: SiglipVisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        let fc1 = mistralrs_quant::linear(
            config.hidden_size,
            config.intermediate_size,
            &None,
            vb.pp("fc1"),
        )?;
        let fc2 = mistralrs_quant::linear(
            config.intermediate_size,
            config.hidden_size,
            &None,
            vb.pp("fc2"),
        )?;
        Ok(Self {
            activation: config.hidden_act,
            fc1,
            fc2,
        })
    }

    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let mut x = self.fc1.forward(x)?;
        x = self.activation.forward(&x)?;
        self.fc2.forward(&x)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("fc1").add(&self.fc1);
        uvb.pp("fc2").add(&self.fc2);

        uvb.to_safetensors()
    }
}

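/// Pre-norm transformer block: LayerNorm -> self-attention -> residual add,
/// then LayerNorm -> MLP -> residual add.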
struct EncoderLayer {
    mlp: VisionMLP,
    attn: Attention,
    layer_norm_1: LayerNorm,
    layer_norm_2: LayerNorm,
}

impl EncoderLayer {
    fn new(config: SiglipVisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        let mlp = VisionMLP::new(config.clone(), vb.pp("mlp"))?;
        let attn = Attention::new(config.clone(), vb.pp("self_attn"))?;
        let layer_norm_1 = layer_norm(
            config.hidden_size,
            config.layer_norm_eps,
            vb.pp("layer_norm1"),
        )?;
        let layer_norm_2 = layer_norm(
            config.hidden_size,
            config.layer_norm_eps,
            vb.pp("layer_norm2"),
        )?;
        Ok(Self {
            mlp,
            attn,
            layer_norm_1,
            layer_norm_2,
        })
    }

    fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
        let residual = xs.clone();

        let hidden_states = self.layer_norm_1.forward(xs)?;
        let hidden_states = self.attn.forward(&hidden_states, attention_mask)?;
        let hidden_states = (hidden_states + residual)?;

        let residual = &hidden_states;
        let hidden_states = self.layer_norm_2.forward(&hidden_states)?;
        let hidden_states = self.mlp.forward(&hidden_states)?;
        hidden_states + residual
    }
}

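/// Stack of `num_hidden_layers` identical encoder layers.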
struct Encoder {
    layers: Vec<EncoderLayer>,
}

impl Encoder {
    fn new(config: &SiglipVisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        let mut layers = Vec::new();
        let vb_l = vb.pp("layers");
        for i in 0..config.num_hidden_layers {
            layers.push(EncoderLayer::new(config.clone(), vb_l.pp(i))?);
        }
        Ok(Self { layers })
    }

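    /// Run the layer stack, optionally stopping early: a negative
    /// `hidden_states_index` returns the output of the (zero-based) layer
    /// `layers.len() + hidden_states_index`, while `0` runs every layer.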
    fn forward_get_hidden_states(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        hidden_states_index: isize,
    ) -> Result<Tensor> {
        let mut hidden_states = xs.clone();
        for (layer_idx, layer) in self.layers.iter().enumerate() {
            hidden_states = layer.forward(&hidden_states, attention_mask)?;
            if (self.layers.len() as isize + hidden_states_index) as usize == layer_idx {
                return Ok(hidden_states);
            }
        }
        Ok(hidden_states)
    }
}

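/// SigLIP vision transformer: patch/position embeddings, the encoder stack,
/// and a final `post_layernorm`.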
pub struct SiglipVisionTransformer {
    pub(super) embeddings: VisionEmbeddings,
    encoder: Encoder,
    post_layernorm: LayerNorm,
    config: SiglipVisionConfig,
}

impl SiglipVisionTransformer {
    pub fn new(config: &SiglipVisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
        let embeddings = VisionEmbeddings::new(config, vb.pp("embeddings"))?;
        let post_layernorm = layer_norm(
            config.hidden_size,
            config.layer_norm_eps,
            vb.pp("post_layernorm"),
        )?;
        let encoder = Encoder::new(config, vb.pp("encoder"))?;
        Ok(Self {
            embeddings,
            encoder,
            post_layernorm,
            config: config.clone(),
        })
    }

    pub fn forward(
        &self,
        pixel_values: &Tensor,
        attention_mask: Option<&Tensor>,
        tgt_sizes: Option<&Tensor>,
    ) -> Result<Tensor> {
        self.forward_get_hidden_states(pixel_values, attention_mask, tgt_sizes, -1)
    }

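    /// Forward pass that can stop at an intermediate encoder layer; the index is
    /// shifted by one before being passed to `Encoder::forward_get_hidden_states`,
    /// so `-1` (as used by `forward`) runs the full stack. When `attention_mask`
    /// is `None`, a patch mask of ones is synthesized and no encoder mask is built.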
    pub fn forward_get_hidden_states(
        &self,
        pixel_values: &Tensor,
        attention_mask: Option<&Tensor>,
        tgt_sizes: Option<&Tensor>,
        hidden_states_index: isize,
    ) -> Result<Tensor> {
        let bs = pixel_values.dim(0)?;
        let patch_attention_mask = if let Some(attn_mask) = attention_mask {
            attn_mask.clone()
        } else {
            let patch_size = self.config.patch_size;
            Tensor::ones(
                (
                    bs,
                    pixel_values.dim(2)? / patch_size,
                    pixel_values.dim(3)? / patch_size,
                ),
                DType::U8,
                pixel_values.device(),
            )?
        };

        let hidden_states =
            self.embeddings
                .forward(pixel_values, &patch_attention_mask, tgt_sizes)?;

        let attention_mask = if attention_mask.is_none() {
            None
        } else {
            let mask = patch_attention_mask
                .reshape((patch_attention_mask.dim(0)?, ()))?
                .to_dtype(hidden_states.dtype())?;
            Some(CausalMasker.expand_mask(&mask, hidden_states.dtype(), None)?)
        };
        let hidden_states = self.encoder.forward_get_hidden_states(
            &hidden_states,
            attention_mask.as_ref(),
            hidden_states_index + 1,
        )?;
        hidden_states.apply(&self.post_layernorm)
    }

    pub fn dtype(&self) -> DType {
        self.embeddings.patch_embedding.weight().dtype()
    }

    pub fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("post_layernorm").add(&self.post_layernorm);
        uvb.pp("embeddings")
            .extend(self.embeddings.residual_tensors());

        let uvb_enc = uvb.pp("encoder");
        for (i, layer) in self.encoder.layers.iter().enumerate() {
            let uvb_l = uvb_enc.pp("layers").pp(i);

            uvb_l.pp("layer_norm1").add(&layer.layer_norm_1);
            uvb_l.pp("layer_norm2").add(&layer.layer_norm_2);
            uvb_l.pp("mlp").extend(layer.mlp.residual_tensors());
            uvb_l.pp("self_attn").extend(layer.attn.residual_tensors());
        }

        uvb.to_safetensors()
    }
}