1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use std::sync::Arc;
4
5use candle_core::{IndexOp, Result, Shape, Tensor, D};
7use candle_nn::{Conv2dConfig, Module};
8use mistralrs_quant::{QuantMethod, ShardedVarBuilder};
9
10use crate::{
11 layers::{self, MatMul},
12 serde_default_fn,
13 utils::unvarbuilder::UnVarBuilder,
14};
15
16#[derive(Debug, Clone, Copy, serde::Deserialize)]
17pub enum Activation {
18 QuickGelu,
19}
20
21impl Module for Activation {
22 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
23 match self {
24 Activation::QuickGelu => xs * candle_nn::ops::sigmoid(&(xs * 1.702f64)?),
25 }
26 }
27}
28
29serde_default_fn!(usize, d_hidden_size, 768);
30serde_default_fn!(usize, d_intermediate_size, 3072);
31serde_default_fn!(usize, d_num_hidden_layers, 12);
32serde_default_fn!(usize, d_num_attention_heads, 12);
33serde_default_fn!(usize, d_num_channels, 3);
34serde_default_fn!(usize, d_image_size, 224);
35serde_default_fn!(usize, d_patch_size, 32);
36serde_default_fn!(Activation, d_act, Activation::QuickGelu);
37
38#[derive(Debug, Clone, serde::Deserialize)]
39pub struct ClipConfig {
40 #[serde(default = "d_hidden_size")]
41 pub hidden_size: usize,
42 #[serde(default = "d_intermediate_size")]
43 pub intermediate_size: usize,
44 #[serde(default = "d_num_hidden_layers")]
45 pub num_hidden_layers: usize,
46 #[serde(default = "d_num_attention_heads")]
47 pub num_attention_heads: usize,
48 #[serde(default = "d_num_channels")]
49 pub num_channels: usize,
50 #[serde(default = "d_image_size")]
51 pub image_size: usize,
52 #[serde(default = "d_patch_size")]
53 pub patch_size: usize,
54 #[serde(default = "d_act")]
55 pub hidden_act: Activation,
56}
57
58#[derive(Clone, Debug)]
60struct ClipVisionEmbeddings {
61 patch_embedding: candle_nn::Conv2d,
62 position_ids: Tensor,
63 class_embedding: Tensor,
64 position_embedding: candle_nn::Embedding,
65}
66
67impl ClipVisionEmbeddings {
68 fn new(vs: ShardedVarBuilder, c: &ClipConfig) -> Result<Self> {
69 let class_embedding = if vs.contains_tensor("class_embedding") {
71 vs.get(c.hidden_size, "class_embedding")?
72 } else {
73 Tensor::randn(0f32, 1f32, c.hidden_size, vs.device())?
74 };
75
76 let num_patches = (c.image_size / c.patch_size).pow(2);
77 let num_positions = num_patches + 1;
78 let position_ids = Tensor::arange(0, num_positions as i64, vs.device())?;
79
80 let conv2dconfig = Conv2dConfig {
81 stride: c.patch_size,
82 ..Default::default()
83 };
84 let position_embedding =
85 layers::embedding(num_positions, c.hidden_size, vs.pp("position_embedding"))?;
86 let patch_embedding = layers::conv2d_no_bias(
87 c.num_channels,
88 c.hidden_size,
89 c.patch_size,
90 conv2dconfig,
91 vs.pp("patch_embedding"),
92 )?;
93 Ok(Self {
94 patch_embedding,
95 position_ids,
96 class_embedding,
97 position_embedding,
98 })
99 }
100}
101
102impl Module for ClipVisionEmbeddings {
103 fn forward(&self, pixel_values: &Tensor) -> Result<Tensor> {
104 let batch_size = pixel_values.shape().dims();
105 let patch_embeds = self
106 .patch_embedding
107 .forward(pixel_values)?
108 .flatten_from(2)?
109 .transpose(1, 2)?;
110 let shape = Shape::from((batch_size[0], 1, self.class_embedding.dim(D::Minus1)?));
111 let class_embeds = self.class_embedding.expand(shape)?;
112 let embeddings = Tensor::cat(&[class_embeds, patch_embeds], 1)?;
113 let position_embedding = self.position_embedding.forward(&self.position_ids)?;
114 embeddings.broadcast_add(&position_embedding)
115 }
116}
117
118#[derive(Clone, Debug)]
119struct ClipAttention {
120 k_proj: Arc<dyn QuantMethod>,
121 v_proj: Arc<dyn QuantMethod>,
122 q_proj: Arc<dyn QuantMethod>,
123 out_proj: Arc<dyn QuantMethod>,
124 head_dim: usize,
125 scale: f64,
126 num_attention_heads: usize,
127}
128
129impl ClipAttention {
130 fn new(vs: ShardedVarBuilder, c: &ClipConfig) -> Result<Self> {
131 let hidden_size = c.hidden_size;
132 let num_attention_heads = c.num_attention_heads;
133 let k_proj = mistralrs_quant::linear(hidden_size, hidden_size, &None, vs.pp("k_proj"))?;
134 let v_proj = mistralrs_quant::linear(hidden_size, hidden_size, &None, vs.pp("v_proj"))?;
135 let q_proj = mistralrs_quant::linear(hidden_size, hidden_size, &None, vs.pp("q_proj"))?;
136 let out_proj = mistralrs_quant::linear(hidden_size, hidden_size, &None, vs.pp("out_proj"))?;
137 let head_dim = hidden_size / num_attention_heads;
138 let scale = (head_dim as f64).powf(-0.5);
139
140 Ok(ClipAttention {
141 k_proj,
142 v_proj,
143 q_proj,
144 out_proj,
145 head_dim,
146 scale,
147 num_attention_heads,
148 })
149 }
150
151 fn shape(&self, xs: &Tensor, seq_len: usize, bsz: usize) -> Result<Tensor> {
152 xs.reshape((bsz, seq_len, self.num_attention_heads, self.head_dim))?
153 .transpose(1, 2)?
154 .contiguous()
155 }
156
157 fn forward(&self, xs: &Tensor, causal_attention_mask: Option<&Tensor>) -> Result<Tensor> {
158 let (bsz, seq_len, hidden_size) = xs.dims3()?;
159
160 let query_states = (self.q_proj.forward(xs)? * self.scale)?;
161 let proj_shape = (bsz * self.num_attention_heads, seq_len, self.head_dim);
162 let query_states = self
163 .shape(&query_states, seq_len, bsz)?
164 .reshape(proj_shape)?;
165 let key_states = self
166 .shape(&self.k_proj.forward(xs)?, seq_len, bsz)?
167 .reshape(proj_shape)?;
168 let value_states = self
169 .shape(&self.v_proj.forward(xs)?, seq_len, bsz)?
170 .reshape(proj_shape)?;
171 let attn_weights = MatMul.matmul(&query_states, &key_states.transpose(1, 2)?)?;
172
173 let src_len = key_states.dim(1)?;
174
175 let attn_weights = if let Some(causal_attention_mask) = causal_attention_mask {
176 attn_weights
177 .reshape((bsz, self.num_attention_heads, seq_len, src_len))?
178 .broadcast_add(causal_attention_mask)?
179 .reshape((bsz * self.num_attention_heads, seq_len, src_len))?
180 } else {
181 attn_weights
182 };
183
184 let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
185
186 let attn_output = MatMul.matmul(&attn_weights, &value_states)?;
187 let attn_output = attn_output
188 .reshape((bsz, self.num_attention_heads, seq_len, self.head_dim))?
189 .transpose(1, 2)?
190 .reshape((bsz, seq_len, hidden_size))?;
191 self.out_proj.forward(&attn_output)
192 }
193}
194
195#[derive(Clone, Debug)]
196struct ClipMlp {
197 fc1: Arc<dyn QuantMethod>,
198 fc2: Arc<dyn QuantMethod>,
199 activation: Activation,
200}
201
202impl ClipMlp {
203 fn new(vs: ShardedVarBuilder, c: &ClipConfig) -> Result<Self> {
204 let fc1 = mistralrs_quant::linear(c.hidden_size, c.intermediate_size, &None, vs.pp("fc1"))?;
205 let fc2 = mistralrs_quant::linear(c.intermediate_size, c.hidden_size, &None, vs.pp("fc2"))?;
206
207 Ok(ClipMlp {
208 fc1,
209 fc2,
210 activation: c.hidden_act,
211 })
212 }
213}
214
215impl ClipMlp {
216 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
217 let xs = self.fc1.forward(xs)?;
218 self.fc2.forward(&self.activation.forward(&xs)?)
219 }
220}
221
222#[derive(Clone, Debug)]
223struct ClipEncoderLayer {
224 self_attn: ClipAttention,
225 layer_norm1: candle_nn::LayerNorm,
226 mlp: ClipMlp,
227 layer_norm2: candle_nn::LayerNorm,
228}
229
230impl ClipEncoderLayer {
231 fn new(vs: ShardedVarBuilder, c: &ClipConfig) -> Result<Self> {
232 let self_attn = ClipAttention::new(vs.pp("self_attn"), c)?;
233 let layer_norm1 = layers::layer_norm(c.hidden_size, 1e-5, vs.pp("layer_norm1"))?;
234 let mlp = ClipMlp::new(vs.pp("mlp"), c)?;
235 let layer_norm2 = layers::layer_norm(c.hidden_size, 1e-5, vs.pp("layer_norm2"))?;
236
237 Ok(ClipEncoderLayer {
238 self_attn,
239 layer_norm1,
240 mlp,
241 layer_norm2,
242 })
243 }
244
245 fn forward(&self, xs: &Tensor, causal_attention_mask: Option<&Tensor>) -> Result<Tensor> {
246 let residual = xs;
247 let xs = self.layer_norm1.forward(xs)?;
248 let xs = self.self_attn.forward(&xs, causal_attention_mask)?;
249 let xs = (xs + residual)?;
250
251 let residual = &xs;
252 let xs = self.layer_norm2.forward(&xs)?;
253 let xs = self.mlp.forward(&xs)?;
254 xs + residual
255 }
256}
257
258#[derive(Clone, Debug)]
259pub struct ClipEncoder {
260 layers: Vec<ClipEncoderLayer>,
261}
262
263impl ClipEncoder {
264 pub fn new(vs: ShardedVarBuilder, c: &ClipConfig) -> Result<Self> {
265 let vs = vs.pp("layers");
266 let mut layers: Vec<ClipEncoderLayer> = Vec::new();
267 for index in 0..c.num_hidden_layers {
268 let layer = ClipEncoderLayer::new(vs.pp(index.to_string()), c)?;
269 layers.push(layer)
270 }
271 Ok(ClipEncoder { layers })
272 }
273
274 pub fn forward_get_hidden_states(
275 &self,
276 xs: &Tensor,
277 causal_attention_mask: Option<&Tensor>,
278 ) -> Result<Vec<Tensor>> {
279 let mut xs = xs.clone();
280 let mut hidden_states = Vec::new();
281 for layer in self.layers.iter() {
282 xs = layer.forward(&xs, causal_attention_mask)?;
283 hidden_states.push(xs.clone());
284 }
285 Ok(hidden_states)
286 }
287}
288
289#[derive(Clone, Debug)]
291pub struct ClipVisionTransformer {
292 embeddings: ClipVisionEmbeddings,
293 encoder: ClipEncoder,
294 pre_layer_norm: candle_nn::LayerNorm,
295 final_layer_norm: candle_nn::LayerNorm,
296}
297
298impl ClipVisionTransformer {
299 pub fn new(vb: ShardedVarBuilder, c: &ClipConfig) -> Result<Self> {
302 let embeddings = ClipVisionEmbeddings::new(vb.pp("embeddings"), c)?;
303 let pre_layer_norm = layers::layer_norm(c.hidden_size, 1e-5, vb.pp("pre_layrnorm"))?;
304 let encoder = ClipEncoder::new(vb.pp("encoder"), c)?;
305 let final_layer_norm = layers::layer_norm(c.hidden_size, 1e-5, vb.pp("post_layernorm"))?;
306 Ok(Self {
307 embeddings,
308 encoder,
309 final_layer_norm,
310 pre_layer_norm,
311 })
312 }
313
314 pub fn forward_get_hidden_states(&self, pixel_values: &Tensor) -> Result<Vec<Tensor>> {
315 let hidden_states = pixel_values
316 .apply(&self.embeddings)?
317 .apply(&self.pre_layer_norm)?;
318 let mut result = self
319 .encoder
320 .forward_get_hidden_states(&hidden_states, None)?;
321 let encoder_outputs = result.last().unwrap();
322 let pooled_output = encoder_outputs.i((.., 0, ..))?;
323 result.push(self.final_layer_norm.forward(&pooled_output)?.clone());
324 Ok(result)
325 }
326
327 pub fn residual_tensors(&self) -> Vec<(String, Tensor)> {
328 let uvb = UnVarBuilder::new();
329
330 uvb.pp("pre_layrnorm").add(&self.pre_layer_norm);
331 uvb.pp("post_layernorm").add(&self.final_layer_norm);
332
333 {
335 let uvb_emb = uvb.pp("embeddings");
336
337 uvb_emb.add_tensor("class_embedding", self.embeddings.class_embedding.clone());
338 uvb_emb
339 .pp("position_embedding")
340 .add(&self.embeddings.position_embedding);
341 uvb_emb
342 .pp("patch_embedding")
343 .add(&self.embeddings.patch_embedding);
344 }
345
346 {
348 let uvb_enc = uvb.pp("encoder");
349
350 for (i, layer) in self.encoder.layers.iter().enumerate() {
351 let uvb_l = uvb_enc.pp("layers").pp(i);
352
353 uvb_l.pp("layer_norm1").add(&layer.layer_norm1);
354 uvb_l.pp("layer_norm2").add(&layer.layer_norm2);
355
356 let uvb_mlp = uvb_l.pp("mlp");
357 uvb_mlp.pp("fc1").add(&layer.mlp.fc1);
358 uvb_mlp.pp("fc2").add(&layer.mlp.fc2);
359
360 let uvb_attn = uvb_l.pp("self_attn");
361 uvb_attn.pp("q_proj").add(&layer.self_attn.q_proj);
362 uvb_attn.pp("k_proj").add(&layer.self_attn.k_proj);
363 uvb_attn.pp("v_proj").add(&layer.self_attn.v_proj);
364 uvb_attn.pp("out_proj").add(&layer.self_attn.out_proj);
365 }
366 }
367
368 uvb.to_safetensors()
369 }
370}