mistralrs_core/vision_models/
clip.rs

1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use std::sync::Arc;
4
5// Sourced from https://github.com/huggingface/candle/blob/main/candle-transformers/src/models/clip/vision_model.rs
6use candle_core::{IndexOp, Result, Shape, Tensor, D};
7use candle_nn::{Conv2dConfig, Module};
8use mistralrs_quant::{QuantMethod, ShardedVarBuilder};
9
10use crate::{
11    layers::{self, MatMul},
12    serde_default_fn,
13    utils::unvarbuilder::UnVarBuilder,
14};
15
16#[derive(Debug, Clone, Copy, serde::Deserialize)]
17pub enum Activation {
18    QuickGelu,
19}
20
21impl Module for Activation {
22    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
23        match self {
24            Activation::QuickGelu => xs * candle_nn::ops::sigmoid(&(xs * 1.702f64)?),
25        }
26    }
27}
28
29serde_default_fn!(usize, d_hidden_size, 768);
30serde_default_fn!(usize, d_intermediate_size, 3072);
31serde_default_fn!(usize, d_num_hidden_layers, 12);
32serde_default_fn!(usize, d_num_attention_heads, 12);
33serde_default_fn!(usize, d_num_channels, 3);
34serde_default_fn!(usize, d_image_size, 224);
35serde_default_fn!(usize, d_patch_size, 32);
36serde_default_fn!(Activation, d_act, Activation::QuickGelu);
37
38#[derive(Debug, Clone, serde::Deserialize)]
39pub struct ClipConfig {
40    #[serde(default = "d_hidden_size")]
41    pub hidden_size: usize,
42    #[serde(default = "d_intermediate_size")]
43    pub intermediate_size: usize,
44    #[serde(default = "d_num_hidden_layers")]
45    pub num_hidden_layers: usize,
46    #[serde(default = "d_num_attention_heads")]
47    pub num_attention_heads: usize,
48    #[serde(default = "d_num_channels")]
49    pub num_channels: usize,
50    #[serde(default = "d_image_size")]
51    pub image_size: usize,
52    #[serde(default = "d_patch_size")]
53    pub patch_size: usize,
54    #[serde(default = "d_act")]
55    pub hidden_act: Activation,
56}
57
58// https://github.com/huggingface/transformers/blob/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip/modeling_clip.py#L112
59#[derive(Clone, Debug)]
60struct ClipVisionEmbeddings {
61    patch_embedding: candle_nn::Conv2d,
62    position_ids: Tensor,
63    class_embedding: Tensor,
64    position_embedding: candle_nn::Embedding,
65}
66
67impl ClipVisionEmbeddings {
68    fn new(vs: ShardedVarBuilder, c: &ClipConfig) -> Result<Self> {
69        // originally nn.Parameter
70        let class_embedding = if vs.contains_tensor("class_embedding") {
71            vs.get(c.hidden_size, "class_embedding")?
72        } else {
73            Tensor::randn(0f32, 1f32, c.hidden_size, vs.device())?
74        };
75
76        let num_patches = (c.image_size / c.patch_size).pow(2);
77        let num_positions = num_patches + 1;
78        let position_ids = Tensor::arange(0, num_positions as i64, vs.device())?;
79
80        let conv2dconfig = Conv2dConfig {
81            stride: c.patch_size,
82            ..Default::default()
83        };
84        let position_embedding =
85            layers::embedding(num_positions, c.hidden_size, vs.pp("position_embedding"))?;
86        let patch_embedding = layers::conv2d_no_bias(
87            c.num_channels,
88            c.hidden_size,
89            c.patch_size,
90            conv2dconfig,
91            vs.pp("patch_embedding"),
92        )?;
93        Ok(Self {
94            patch_embedding,
95            position_ids,
96            class_embedding,
97            position_embedding,
98        })
99    }
100}
101
102impl Module for ClipVisionEmbeddings {
103    fn forward(&self, pixel_values: &Tensor) -> Result<Tensor> {
104        let batch_size = pixel_values.shape().dims();
105        let patch_embeds = self
106            .patch_embedding
107            .forward(pixel_values)?
108            .flatten_from(2)?
109            .transpose(1, 2)?;
110        let shape = Shape::from((batch_size[0], 1, self.class_embedding.dim(D::Minus1)?));
111        let class_embeds = self.class_embedding.expand(shape)?;
112        let embeddings = Tensor::cat(&[class_embeds, patch_embeds], 1)?;
113        let position_embedding = self.position_embedding.forward(&self.position_ids)?;
114        embeddings.broadcast_add(&position_embedding)
115    }
116}
117
118#[derive(Clone, Debug)]
119struct ClipAttention {
120    k_proj: Arc<dyn QuantMethod>,
121    v_proj: Arc<dyn QuantMethod>,
122    q_proj: Arc<dyn QuantMethod>,
123    out_proj: Arc<dyn QuantMethod>,
124    head_dim: usize,
125    scale: f64,
126    num_attention_heads: usize,
127}
128
129impl ClipAttention {
130    fn new(vs: ShardedVarBuilder, c: &ClipConfig) -> Result<Self> {
131        let hidden_size = c.hidden_size;
132        let num_attention_heads = c.num_attention_heads;
133        let k_proj = mistralrs_quant::linear(hidden_size, hidden_size, &None, vs.pp("k_proj"))?;
134        let v_proj = mistralrs_quant::linear(hidden_size, hidden_size, &None, vs.pp("v_proj"))?;
135        let q_proj = mistralrs_quant::linear(hidden_size, hidden_size, &None, vs.pp("q_proj"))?;
136        let out_proj = mistralrs_quant::linear(hidden_size, hidden_size, &None, vs.pp("out_proj"))?;
137        let head_dim = hidden_size / num_attention_heads;
138        let scale = (head_dim as f64).powf(-0.5);
139
140        Ok(ClipAttention {
141            k_proj,
142            v_proj,
143            q_proj,
144            out_proj,
145            head_dim,
146            scale,
147            num_attention_heads,
148        })
149    }
150
151    fn shape(&self, xs: &Tensor, seq_len: usize, bsz: usize) -> Result<Tensor> {
152        xs.reshape((bsz, seq_len, self.num_attention_heads, self.head_dim))?
153            .transpose(1, 2)?
154            .contiguous()
155    }
156
157    fn forward(&self, xs: &Tensor, causal_attention_mask: Option<&Tensor>) -> Result<Tensor> {
158        let (bsz, seq_len, hidden_size) = xs.dims3()?;
159
160        let query_states = (self.q_proj.forward(xs)? * self.scale)?;
161        let proj_shape = (bsz * self.num_attention_heads, seq_len, self.head_dim);
162        let query_states = self
163            .shape(&query_states, seq_len, bsz)?
164            .reshape(proj_shape)?;
165        let key_states = self
166            .shape(&self.k_proj.forward(xs)?, seq_len, bsz)?
167            .reshape(proj_shape)?;
168        let value_states = self
169            .shape(&self.v_proj.forward(xs)?, seq_len, bsz)?
170            .reshape(proj_shape)?;
171        let attn_weights = MatMul.matmul(&query_states, &key_states.transpose(1, 2)?)?;
172
173        let src_len = key_states.dim(1)?;
174
175        let attn_weights = if let Some(causal_attention_mask) = causal_attention_mask {
176            attn_weights
177                .reshape((bsz, self.num_attention_heads, seq_len, src_len))?
178                .broadcast_add(causal_attention_mask)?
179                .reshape((bsz * self.num_attention_heads, seq_len, src_len))?
180        } else {
181            attn_weights
182        };
183
184        let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
185
186        let attn_output = MatMul.matmul(&attn_weights, &value_states)?;
187        let attn_output = attn_output
188            .reshape((bsz, self.num_attention_heads, seq_len, self.head_dim))?
189            .transpose(1, 2)?
190            .reshape((bsz, seq_len, hidden_size))?;
191        self.out_proj.forward(&attn_output)
192    }
193}
194
195#[derive(Clone, Debug)]
196struct ClipMlp {
197    fc1: Arc<dyn QuantMethod>,
198    fc2: Arc<dyn QuantMethod>,
199    activation: Activation,
200}
201
202impl ClipMlp {
203    fn new(vs: ShardedVarBuilder, c: &ClipConfig) -> Result<Self> {
204        let fc1 = mistralrs_quant::linear(c.hidden_size, c.intermediate_size, &None, vs.pp("fc1"))?;
205        let fc2 = mistralrs_quant::linear(c.intermediate_size, c.hidden_size, &None, vs.pp("fc2"))?;
206
207        Ok(ClipMlp {
208            fc1,
209            fc2,
210            activation: c.hidden_act,
211        })
212    }
213}
214
215impl ClipMlp {
216    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
217        let xs = self.fc1.forward(xs)?;
218        self.fc2.forward(&self.activation.forward(&xs)?)
219    }
220}
221
222#[derive(Clone, Debug)]
223struct ClipEncoderLayer {
224    self_attn: ClipAttention,
225    layer_norm1: candle_nn::LayerNorm,
226    mlp: ClipMlp,
227    layer_norm2: candle_nn::LayerNorm,
228}
229
230impl ClipEncoderLayer {
231    fn new(vs: ShardedVarBuilder, c: &ClipConfig) -> Result<Self> {
232        let self_attn = ClipAttention::new(vs.pp("self_attn"), c)?;
233        let layer_norm1 = layers::layer_norm(c.hidden_size, 1e-5, vs.pp("layer_norm1"))?;
234        let mlp = ClipMlp::new(vs.pp("mlp"), c)?;
235        let layer_norm2 = layers::layer_norm(c.hidden_size, 1e-5, vs.pp("layer_norm2"))?;
236
237        Ok(ClipEncoderLayer {
238            self_attn,
239            layer_norm1,
240            mlp,
241            layer_norm2,
242        })
243    }
244
245    fn forward(&self, xs: &Tensor, causal_attention_mask: Option<&Tensor>) -> Result<Tensor> {
246        let residual = xs;
247        let xs = self.layer_norm1.forward(xs)?;
248        let xs = self.self_attn.forward(&xs, causal_attention_mask)?;
249        let xs = (xs + residual)?;
250
251        let residual = &xs;
252        let xs = self.layer_norm2.forward(&xs)?;
253        let xs = self.mlp.forward(&xs)?;
254        xs + residual
255    }
256}
257
258#[derive(Clone, Debug)]
259pub struct ClipEncoder {
260    layers: Vec<ClipEncoderLayer>,
261}
262
263impl ClipEncoder {
264    pub fn new(vs: ShardedVarBuilder, c: &ClipConfig) -> Result<Self> {
265        let vs = vs.pp("layers");
266        let mut layers: Vec<ClipEncoderLayer> = Vec::new();
267        for index in 0..c.num_hidden_layers {
268            let layer = ClipEncoderLayer::new(vs.pp(index.to_string()), c)?;
269            layers.push(layer)
270        }
271        Ok(ClipEncoder { layers })
272    }
273
274    pub fn forward_get_hidden_states(
275        &self,
276        xs: &Tensor,
277        causal_attention_mask: Option<&Tensor>,
278    ) -> Result<Vec<Tensor>> {
279        let mut xs = xs.clone();
280        let mut hidden_states = Vec::new();
281        for layer in self.layers.iter() {
282            xs = layer.forward(&xs, causal_attention_mask)?;
283            hidden_states.push(xs.clone());
284        }
285        Ok(hidden_states)
286    }
287}
288
289// https://github.com/huggingface/transformers/blob/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip/modeling_clip.py#L743
290#[derive(Clone, Debug)]
291pub struct ClipVisionTransformer {
292    embeddings: ClipVisionEmbeddings,
293    encoder: ClipEncoder,
294    pre_layer_norm: candle_nn::LayerNorm,
295    final_layer_norm: candle_nn::LayerNorm,
296}
297
298impl ClipVisionTransformer {
299    /// Create a CLIP vision transformer model. Expects the vb to point to the root (not model)
300    /// where (for example) `.pp("embeddings")` is valid.
301    pub fn new(vb: ShardedVarBuilder, c: &ClipConfig) -> Result<Self> {
302        let embeddings = ClipVisionEmbeddings::new(vb.pp("embeddings"), c)?;
303        let pre_layer_norm = layers::layer_norm(c.hidden_size, 1e-5, vb.pp("pre_layrnorm"))?;
304        let encoder = ClipEncoder::new(vb.pp("encoder"), c)?;
305        let final_layer_norm = layers::layer_norm(c.hidden_size, 1e-5, vb.pp("post_layernorm"))?;
306        Ok(Self {
307            embeddings,
308            encoder,
309            final_layer_norm,
310            pre_layer_norm,
311        })
312    }
313
314    pub fn forward_get_hidden_states(&self, pixel_values: &Tensor) -> Result<Vec<Tensor>> {
315        let hidden_states = pixel_values
316            .apply(&self.embeddings)?
317            .apply(&self.pre_layer_norm)?;
318        let mut result = self
319            .encoder
320            .forward_get_hidden_states(&hidden_states, None)?;
321        let encoder_outputs = result.last().unwrap();
322        let pooled_output = encoder_outputs.i((.., 0, ..))?;
323        result.push(self.final_layer_norm.forward(&pooled_output)?.clone());
324        Ok(result)
325    }
326
327    pub fn residual_tensors(&self) -> Vec<(String, Tensor)> {
328        let uvb = UnVarBuilder::new();
329
330        uvb.pp("pre_layrnorm").add(&self.pre_layer_norm);
331        uvb.pp("post_layernorm").add(&self.final_layer_norm);
332
333        // vision embeddings
334        {
335            let uvb_emb = uvb.pp("embeddings");
336
337            uvb_emb.add_tensor("class_embedding", self.embeddings.class_embedding.clone());
338            uvb_emb
339                .pp("position_embedding")
340                .add(&self.embeddings.position_embedding);
341            uvb_emb
342                .pp("patch_embedding")
343                .add(&self.embeddings.patch_embedding);
344        }
345
346        // encoder
347        {
348            let uvb_enc = uvb.pp("encoder");
349
350            for (i, layer) in self.encoder.layers.iter().enumerate() {
351                let uvb_l = uvb_enc.pp("layers").pp(i);
352
353                uvb_l.pp("layer_norm1").add(&layer.layer_norm1);
354                uvb_l.pp("layer_norm2").add(&layer.layer_norm2);
355
356                let uvb_mlp = uvb_l.pp("mlp");
357                uvb_mlp.pp("fc1").add(&layer.mlp.fc1);
358                uvb_mlp.pp("fc2").add(&layer.mlp.fc2);
359
360                let uvb_attn = uvb_l.pp("self_attn");
361                uvb_attn.pp("q_proj").add(&layer.self_attn.q_proj);
362                uvb_attn.pp("k_proj").add(&layer.self_attn.k_proj);
363                uvb_attn.pp("v_proj").add(&layer.self_attn.v_proj);
364                uvb_attn.pp("out_proj").add(&layer.self_attn.out_proj);
365            }
366        }
367
368        uvb.to_safetensors()
369    }
370}