mistralrs_core/vision_models/clip.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::sync::Arc;

// Sourced from https://github.com/huggingface/candle/blob/main/candle-transformers/src/models/clip/vision_model.rs
use candle_core::{IndexOp, Result, Shape, Tensor, D};
use candle_nn::{Conv2dConfig, Module};
use mistralrs_quant::{QuantMethod, ShardedVarBuilder};

use crate::{
    layers::{self, MatMul},
    serde_default_fn,
    utils::unvarbuilder::UnVarBuilder,
};

#[derive(Debug, Clone, Copy, serde::Deserialize)]
pub enum Activation {
    QuickGelu,
}

impl Module for Activation {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        match self {
            Activation::QuickGelu => xs * candle_nn::ops::sigmoid(&(xs * 1.702f64)?),
        }
    }
}
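
// QuickGELU approximates GELU as `x * sigmoid(1.702 * x)`, the activation used
// by the original CLIP vision tower. A minimal standalone sketch of the same
// computation (the `Device` import and values are illustrative, not part of
// this module):
//
//     let xs = Tensor::new(&[0.5f32, -1.0f32], &Device::Cpu)?;
//     let ys = (&xs * candle_nn::ops::sigmoid(&(&xs * 1.702f64)?))?;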

serde_default_fn!(usize, d_hidden_size, 768);
serde_default_fn!(usize, d_intermediate_size, 3072);
serde_default_fn!(usize, d_num_hidden_layers, 12);
serde_default_fn!(usize, d_num_attention_heads, 12);
serde_default_fn!(usize, d_num_channels, 3);
serde_default_fn!(usize, d_image_size, 224);
serde_default_fn!(usize, d_patch_size, 32);
serde_default_fn!(Activation, d_act, Activation::QuickGelu);

#[derive(Debug, Clone, serde::Deserialize)]
pub struct ClipConfig {
    #[serde(default = "d_hidden_size")]
    pub hidden_size: usize,
    #[serde(default = "d_intermediate_size")]
    pub intermediate_size: usize,
    #[serde(default = "d_num_hidden_layers")]
    pub num_hidden_layers: usize,
    #[serde(default = "d_num_attention_heads")]
    pub num_attention_heads: usize,
    #[serde(default = "d_num_channels")]
    pub num_channels: usize,
    #[serde(default = "d_image_size")]
    pub image_size: usize,
    #[serde(default = "d_patch_size")]
    pub patch_size: usize,
    #[serde(default = "d_act")]
    pub hidden_act: Activation,
}
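
// The defaults above match the CLIP ViT-B/32 vision tower. Every field has a
// serde default, so a config deserializes from a partial (or empty) JSON
// object. A minimal sketch, assuming the crate also has `serde_json` available
// (not imported by this module):
//
//     let cfg: ClipConfig = serde_json::from_str(r#"{"patch_size": 14}"#)?;
//     assert_eq!(cfg.hidden_size, 768); // falls back to the default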

// https://github.com/huggingface/transformers/blob/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip/modeling_clip.py#L112
#[derive(Clone, Debug)]
struct ClipVisionEmbeddings {
    patch_embedding: candle_nn::Conv2d,
    position_ids: Tensor,
    class_embedding: Tensor,
    position_embedding: candle_nn::Embedding,
}

impl ClipVisionEmbeddings {
    fn new(vs: ShardedVarBuilder, c: &ClipConfig) -> Result<Self> {
        // Originally an nn.Parameter upstream; fall back to a random init when
        // the checkpoint does not provide the tensor.
        let class_embedding = if vs.contains_tensor("class_embedding") {
            vs.get(c.hidden_size, "class_embedding")?
        } else {
            Tensor::randn(0f32, 1f32, c.hidden_size, vs.device())?
        };

        let num_patches = (c.image_size / c.patch_size).pow(2);
        let num_positions = num_patches + 1;
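        // With the default config (image_size = 224, patch_size = 32) this is
        // (224 / 32)^2 = 49 patches plus one class token, i.e. 50 positions.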
        let position_ids = Tensor::arange(0, num_positions as i64, vs.device())?;

        let conv2dconfig = Conv2dConfig {
            stride: c.patch_size,
            ..Default::default()
        };
        let position_embedding = layers::embedding(
            num_positions,
            c.hidden_size,
            vs.pp("position_embedding"),
            &None,
        )?;
        let patch_embedding = layers::conv2d_no_bias(
            c.num_channels,
            c.hidden_size,
            c.patch_size,
            conv2dconfig,
            vs.pp("patch_embedding"),
        )?;
        Ok(Self {
            patch_embedding,
            position_ids,
            class_embedding,
            position_embedding,
        })
    }
}

impl Module for ClipVisionEmbeddings {
    fn forward(&self, pixel_values: &Tensor) -> Result<Tensor> {
        // pixel_values: (batch, channels, image_size, image_size)
        let (batch_size, _, _, _) = pixel_values.dims4()?;
        // Conv2d patchification gives (batch, hidden_size, grid, grid); flatten
        // the grid and move hidden_size last: (batch, num_patches, hidden_size).
        let patch_embeds = self
            .patch_embedding
            .forward(pixel_values)?
            .flatten_from(2)?
            .transpose(1, 2)?;
        // Prepend the class token: (batch, num_patches + 1, hidden_size).
        let shape = Shape::from((batch_size, 1, self.class_embedding.dim(D::Minus1)?));
        let class_embeds = self.class_embedding.expand(shape)?;
        let embeddings = Tensor::cat(&[class_embeds, patch_embeds], 1)?;
        let position_embedding = self.position_embedding.forward(&self.position_ids)?;
        embeddings.broadcast_add(&position_embedding)
    }
}
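
// Note: `Embedding::forward` on the 1-D `position_ids` yields a
// (num_positions, hidden_size) tensor, which the final `broadcast_add` above
// expands across the batch dimension.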

#[derive(Clone, Debug)]
struct ClipAttention {
    k_proj: Arc<dyn QuantMethod>,
    v_proj: Arc<dyn QuantMethod>,
    q_proj: Arc<dyn QuantMethod>,
    out_proj: Arc<dyn QuantMethod>,
    head_dim: usize,
    scale: f64,
    num_attention_heads: usize,
}

impl ClipAttention {
    fn new(vs: ShardedVarBuilder, c: &ClipConfig) -> Result<Self> {
        let hidden_size = c.hidden_size;
        let num_attention_heads = c.num_attention_heads;
        let k_proj = mistralrs_quant::linear(hidden_size, hidden_size, &None, vs.pp("k_proj"))?;
        let v_proj = mistralrs_quant::linear(hidden_size, hidden_size, &None, vs.pp("v_proj"))?;
        let q_proj = mistralrs_quant::linear(hidden_size, hidden_size, &None, vs.pp("q_proj"))?;
        let out_proj = mistralrs_quant::linear(hidden_size, hidden_size, &None, vs.pp("out_proj"))?;
        let head_dim = hidden_size / num_attention_heads;
        let scale = (head_dim as f64).powf(-0.5);
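        // With the defaults, head_dim = 768 / 12 = 64 and scale = 64^-0.5 = 0.125.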

        Ok(ClipAttention {
            k_proj,
            v_proj,
            q_proj,
            out_proj,
            head_dim,
            scale,
            num_attention_heads,
        })
    }

    // Split heads: (bsz, seq_len, hidden_size) -> (bsz, num_heads, seq_len, head_dim).
    fn shape(&self, xs: &Tensor, seq_len: usize, bsz: usize) -> Result<Tensor> {
        xs.reshape((bsz, seq_len, self.num_attention_heads, self.head_dim))?
            .transpose(1, 2)?
            .contiguous()
    }

    fn forward(&self, xs: &Tensor, causal_attention_mask: Option<&Tensor>) -> Result<Tensor> {
        let (bsz, seq_len, hidden_size) = xs.dims3()?;

        // Pre-scale the queries so the QK^T matmul already includes 1/sqrt(head_dim).
        let query_states = (self.q_proj.forward(xs)? * self.scale)?;
        let proj_shape = (bsz * self.num_attention_heads, seq_len, self.head_dim);
        let query_states = self
            .shape(&query_states, seq_len, bsz)?
            .reshape(proj_shape)?;
        let key_states = self
            .shape(&self.k_proj.forward(xs)?, seq_len, bsz)?
            .reshape(proj_shape)?;
        let value_states = self
            .shape(&self.v_proj.forward(xs)?, seq_len, bsz)?
            .reshape(proj_shape)?;
        let attn_weights = MatMul.matmul(&query_states, &key_states.transpose(1, 2)?)?;

        let src_len = key_states.dim(1)?;

        let attn_weights = if let Some(causal_attention_mask) = causal_attention_mask {
            attn_weights
                .reshape((bsz, self.num_attention_heads, seq_len, src_len))?
                .broadcast_add(causal_attention_mask)?
                .reshape((bsz * self.num_attention_heads, seq_len, src_len))?
        } else {
            attn_weights
        };

        let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;

        let attn_output = MatMul.matmul(&attn_weights, &value_states)?;
        let attn_output = attn_output
            .reshape((bsz, self.num_attention_heads, seq_len, self.head_dim))?
            .transpose(1, 2)?
            .reshape((bsz, seq_len, hidden_size))?;
        self.out_proj.forward(&attn_output)
    }
}
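
// Shape flow through ClipAttention::forward with the default config and the
// vision tower's 50-token sequence (a worked example, not extra behavior):
//   xs:            (bsz, 50, 768)
//   q/k/v states:  (bsz * 12, 50, 64)   after `shape` + reshape
//   attn_weights:  (bsz * 12, 50, 50)   softmax over the last dim
//   attn_output:   (bsz, 50, 768)       after merging heads and out_proj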

#[derive(Clone, Debug)]
struct ClipMlp {
    fc1: Arc<dyn QuantMethod>,
    fc2: Arc<dyn QuantMethod>,
    activation: Activation,
}

impl ClipMlp {
    fn new(vs: ShardedVarBuilder, c: &ClipConfig) -> Result<Self> {
        let fc1 = mistralrs_quant::linear(c.hidden_size, c.intermediate_size, &None, vs.pp("fc1"))?;
        let fc2 = mistralrs_quant::linear(c.intermediate_size, c.hidden_size, &None, vs.pp("fc2"))?;

        Ok(ClipMlp {
            fc1,
            fc2,
            activation: c.hidden_act,
        })
    }

    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let xs = self.fc1.forward(xs)?;
        self.fc2.forward(&self.activation.forward(&xs)?)
    }
}
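
// The MLP expands hidden_size -> intermediate_size (768 -> 3072 by default),
// applies QuickGELU, then projects back, matching the standard CLIP/ViT
// feed-forward block.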

#[derive(Clone, Debug)]
struct ClipEncoderLayer {
    self_attn: ClipAttention,
    layer_norm1: candle_nn::LayerNorm,
    mlp: ClipMlp,
    layer_norm2: candle_nn::LayerNorm,
}

impl ClipEncoderLayer {
    fn new(vs: ShardedVarBuilder, c: &ClipConfig) -> Result<Self> {
        let self_attn = ClipAttention::new(vs.pp("self_attn"), c)?;
        let layer_norm1 = layers::layer_norm(c.hidden_size, 1e-5, vs.pp("layer_norm1"))?;
        let mlp = ClipMlp::new(vs.pp("mlp"), c)?;
        let layer_norm2 = layers::layer_norm(c.hidden_size, 1e-5, vs.pp("layer_norm2"))?;

        Ok(ClipEncoderLayer {
            self_attn,
            layer_norm1,
            mlp,
            layer_norm2,
        })
    }

    // Pre-norm transformer block: LayerNorm -> attention -> residual,
    // then LayerNorm -> MLP -> residual.
    fn forward(&self, xs: &Tensor, causal_attention_mask: Option<&Tensor>) -> Result<Tensor> {
        let residual = xs;
        let xs = self.layer_norm1.forward(xs)?;
        let xs = self.self_attn.forward(&xs, causal_attention_mask)?;
        let xs = (xs + residual)?;

        let residual = &xs;
        let xs = self.layer_norm2.forward(&xs)?;
        let xs = self.mlp.forward(&xs)?;
        xs + residual
    }
}

#[derive(Clone, Debug)]
pub struct ClipEncoder {
    layers: Vec<ClipEncoderLayer>,
}

impl ClipEncoder {
    pub fn new(vs: ShardedVarBuilder, c: &ClipConfig) -> Result<Self> {
        let vs = vs.pp("layers");
        let mut layers: Vec<ClipEncoderLayer> = Vec::with_capacity(c.num_hidden_layers);
        for index in 0..c.num_hidden_layers {
            let layer = ClipEncoderLayer::new(vs.pp(index.to_string()), c)?;
            layers.push(layer);
        }
        Ok(ClipEncoder { layers })
    }

    /// Run every encoder layer, collecting the hidden state produced after
    /// each layer (the input embeddings themselves are not included).
    pub fn forward_get_hidden_states(
        &self,
        xs: &Tensor,
        causal_attention_mask: Option<&Tensor>,
    ) -> Result<Vec<Tensor>> {
        let mut xs = xs.clone();
        let mut hidden_states = Vec::new();
        for layer in self.layers.iter() {
            xs = layer.forward(&xs, causal_attention_mask)?;
            hidden_states.push(xs.clone());
        }
        Ok(hidden_states)
    }
}
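
// LLaVA-style consumers often select a specific entry from this vector rather
// than only the final layer. A hedged sketch (`encoder` and `embeds` are
// hypothetical names):
//
//     let states = encoder.forward_get_hidden_states(&embeds, None)?;
//     let penultimate = &states[states.len() - 2]; // second-to-last layer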

// https://github.com/huggingface/transformers/blob/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip/modeling_clip.py#L743
#[derive(Clone, Debug)]
pub struct ClipVisionTransformer {
    embeddings: ClipVisionEmbeddings,
    encoder: ClipEncoder,
    pre_layer_norm: candle_nn::LayerNorm,
    final_layer_norm: candle_nn::LayerNorm,
}

impl ClipVisionTransformer {
    /// Create a CLIP vision transformer model. Expects the vb to point to the root (not model)
    /// where (for example) `.pp("embeddings")` is valid.
    pub fn new(vb: ShardedVarBuilder, c: &ClipConfig) -> Result<Self> {
        let embeddings = ClipVisionEmbeddings::new(vb.pp("embeddings"), c)?;
        // "pre_layrnorm" (sic) matches the tensor name used in CLIP checkpoints.
        let pre_layer_norm = layers::layer_norm(c.hidden_size, 1e-5, vb.pp("pre_layrnorm"))?;
        let encoder = ClipEncoder::new(vb.pp("encoder"), c)?;
        let final_layer_norm = layers::layer_norm(c.hidden_size, 1e-5, vb.pp("post_layernorm"))?;
        Ok(Self {
            embeddings,
            encoder,
            final_layer_norm,
            pre_layer_norm,
        })
    }

    /// Returns the hidden states after each encoder layer, with the pooled,
    /// post-layernormed class-token embedding appended as the final element.
    pub fn forward_get_hidden_states(&self, pixel_values: &Tensor) -> Result<Vec<Tensor>> {
        let hidden_states = pixel_values
            .apply(&self.embeddings)?
            .apply(&self.pre_layer_norm)?;
        let mut result = self
            .encoder
            .forward_get_hidden_states(&hidden_states, None)?;
        let encoder_outputs = result.last().unwrap();
        // Pool by taking the class token (position 0).
        let pooled_output = encoder_outputs.i((.., 0, ..))?;
        result.push(self.final_layer_norm.forward(&pooled_output)?);
        Ok(result)
    }
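
    // A hedged usage sketch (`model` and `pixel_values` are hypothetical): the
    // last entry is the pooled class-token embedding; the entries before it are
    // per-layer hidden states of shape (batch, num_positions, hidden_size):
    //
    //     let states = model.forward_get_hidden_states(&pixel_values)?;
    //     let pooled = states.last().unwrap();        // (batch, hidden_size)
    //     let last_layer = &states[states.len() - 2]; // (batch, 50, 768) by default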

    pub fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("pre_layrnorm").add(&self.pre_layer_norm);
        uvb.pp("post_layernorm").add(&self.final_layer_norm);

        // vision embeddings
        {
            let uvb_emb = uvb.pp("embeddings");

            uvb_emb.add_tensor("class_embedding", self.embeddings.class_embedding.clone());
            uvb_emb
                .pp("position_embedding")
                .add(&self.embeddings.position_embedding);
            uvb_emb
                .pp("patch_embedding")
                .add(&self.embeddings.patch_embedding);
        }

        // encoder
        {
            let uvb_enc = uvb.pp("encoder");

            for (i, layer) in self.encoder.layers.iter().enumerate() {
                let uvb_l = uvb_enc.pp("layers").pp(i);

                uvb_l.pp("layer_norm1").add(&layer.layer_norm1);
                uvb_l.pp("layer_norm2").add(&layer.layer_norm2);

                let uvb_mlp = uvb_l.pp("mlp");
                uvb_mlp.pp("fc1").add(&layer.mlp.fc1);
                uvb_mlp.pp("fc2").add(&layer.mlp.fc2);

                let uvb_attn = uvb_l.pp("self_attn");
                uvb_attn.pp("q_proj").add(&layer.self_attn.q_proj);
                uvb_attn.pp("k_proj").add(&layer.self_attn.k_proj);
                uvb_attn.pp("v_proj").add(&layer.self_attn.v_proj);
                uvb_attn.pp("out_proj").add(&layer.self_attn.out_proj);
            }
        }

        uvb.to_safetensors()
    }
}
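
// A minimal end-to-end sketch (`vb` is a ShardedVarBuilder rooted at the vision
// tower; `serde_json` and the variable names are assumptions for illustration):
//
//     let cfg: ClipConfig = serde_json::from_str("{}")?; // all defaults (ViT-B/32)
//     let model = ClipVisionTransformer::new(vb, &cfg)?;
//     let states = model.forward_get_hidden_states(&pixel_values)?; // pixel_values: (b, 3, 224, 224)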