#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use candle_core::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn as nn;
use candle_nn::Module;
use mistralrs_quant::ShardedVarBuilder;
use serde::Deserialize;

use crate::layers::{self, MatMul};

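/// Activations used by the CLIP text encoder. Only QuickGELU,
/// `x * sigmoid(1.702 * x)`, is needed here.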
#[derive(Debug, Clone, Copy, Deserialize)]
pub enum Activation {
    #[serde(rename = "quick_gelu")]
    QuickGelu,
}

impl Module for Activation {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        match self {
            Activation::QuickGelu => xs * nn::ops::sigmoid(&(xs * 1.702f64)?)?,
        }
    }
}

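/// Configuration of the CLIP text tower; `projection_dim` is used as the
/// hidden width throughout this implementation.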
#[derive(Debug, Clone, Deserialize)]
pub struct ClipTextConfig {
    pub vocab_size: usize,
    pub projection_dim: usize,
    pub hidden_act: Activation,
    pub intermediate_size: usize,
    pub max_position_embeddings: usize,
    pub num_hidden_layers: usize,
    pub num_attention_heads: usize,
}

#[derive(Debug, Clone, Deserialize)]
pub struct ClipConfig {
    pub text_config: ClipTextConfig,
}

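/// Token embeddings plus learned absolute position embeddings.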
#[derive(Clone, Debug)]
struct ClipTextEmbeddings {
    token_embedding: candle_nn::Embedding,
    position_embedding: candle_nn::Embedding,
    position_ids: Tensor,
}

impl ClipTextEmbeddings {
    fn new(vs: ShardedVarBuilder, c: &ClipTextConfig) -> Result<Self> {
        let token_embedding = layers::embedding(
            c.vocab_size,
            c.projection_dim,
            vs.pp("token_embedding"),
            &None,
        )?;
        let position_embedding: nn::Embedding = layers::embedding(
            c.max_position_embeddings,
            c.projection_dim,
            vs.pp("position_embedding"),
            &None,
        )?;
        // Precompute the position ids [0, max_position_embeddings) once.
        let position_ids =
            Tensor::arange(0u32, c.max_position_embeddings as u32, vs.device())?.unsqueeze(0)?;
        Ok(ClipTextEmbeddings {
            token_embedding,
            position_embedding,
            position_ids,
        })
    }
}

impl Module for ClipTextEmbeddings {
    fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
        let seq_length = input_ids.dim(D::Minus1)?;
        let inputs_embeds = self.token_embedding.forward(input_ids)?;
        // Slice the cached position ids down to the current sequence length.
        let position_ids = self.position_ids.narrow(1, 0, seq_length)?;
        let position_embedding = self.position_embedding.forward(&position_ids)?;
        inputs_embeds.broadcast_add(&position_embedding)
    }
}

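/// Multi-head self-attention with an optional additive causal mask.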
#[derive(Clone, Debug)]
struct ClipAttention {
    k_proj: candle_nn::Linear,
    v_proj: candle_nn::Linear,
    q_proj: candle_nn::Linear,
    out_proj: candle_nn::Linear,
    head_dim: usize,
    scale: f64,
    num_attention_heads: usize,
}

impl ClipAttention {
    fn new(vs: ShardedVarBuilder, c: &ClipTextConfig) -> Result<Self> {
        let projection_dim = c.projection_dim;
        let num_attention_heads = c.num_attention_heads;
        let k_proj = layers::linear(projection_dim, projection_dim, vs.pp("k_proj"))?;
        let v_proj = layers::linear(projection_dim, projection_dim, vs.pp("v_proj"))?;
        let q_proj = layers::linear(projection_dim, projection_dim, vs.pp("q_proj"))?;
        let out_proj = layers::linear(projection_dim, projection_dim, vs.pp("out_proj"))?;
        let head_dim = projection_dim / num_attention_heads;
        let scale = (head_dim as f64).powf(-0.5);

        Ok(ClipAttention {
            k_proj,
            v_proj,
            q_proj,
            out_proj,
            head_dim,
            scale,
            num_attention_heads,
        })
    }

    /// Split the last dimension into heads and move the head dimension next to
    /// the batch dimension: (bsz, seq, dim) -> (bsz, heads, seq, head_dim).
    fn shape(&self, xs: &Tensor, seq_len: usize, bsz: usize) -> Result<Tensor> {
        xs.reshape((bsz, seq_len, self.num_attention_heads, self.head_dim))?
            .transpose(1, 2)?
            .contiguous()
    }

    fn forward(&self, xs: &Tensor, causal_attention_mask: Option<&Tensor>) -> Result<Tensor> {
        let in_dtype = xs.dtype();
        let (bsz, seq_len, projection_dim) = xs.dims3()?;

        let query_states = (self.q_proj.forward(xs)? * self.scale)?;
        // Fold the head dimension into the batch dimension for the matmuls.
        let proj_shape = (bsz * self.num_attention_heads, seq_len, self.head_dim);
        let query_states = self
            .shape(&query_states, seq_len, bsz)?
            .reshape(proj_shape)?
            .to_dtype(DType::F32)?;
        let key_states = self
            .shape(&self.k_proj.forward(xs)?, seq_len, bsz)?
            .reshape(proj_shape)?
            .to_dtype(DType::F32)?;
        let value_states = self
            .shape(&self.v_proj.forward(xs)?, seq_len, bsz)?
            .reshape(proj_shape)?
            .to_dtype(DType::F32)?;
        let attn_weights = MatMul.matmul(&query_states, &key_states.transpose(1, 2)?)?;

        let src_len = key_states.dim(1)?;

        let attn_weights = if let Some(causal_attention_mask) = causal_attention_mask {
            attn_weights
                .reshape((bsz, self.num_attention_heads, seq_len, src_len))?
                .broadcast_add(causal_attention_mask)?
                .reshape((bsz * self.num_attention_heads, seq_len, src_len))?
        } else {
            attn_weights
        };

        // Attention weights are computed and normalized in f32 for stability.
        let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;

        let attn_output = MatMul
            .matmul(&attn_weights, &value_states)?
            .to_dtype(in_dtype)?;
        let attn_output = attn_output
            .reshape((bsz, self.num_attention_heads, seq_len, self.head_dim))?
            .transpose(1, 2)?
            .reshape((bsz, seq_len, projection_dim))?;
        self.out_proj.forward(&attn_output)
    }
}

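/// Two-layer feed-forward block using the configured activation.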
#[derive(Clone, Debug)]
struct ClipMlp {
    fc1: candle_nn::Linear,
    fc2: candle_nn::Linear,
    activation: Activation,
}

impl ClipMlp {
    fn new(vs: ShardedVarBuilder, c: &ClipTextConfig) -> Result<Self> {
        let fc1 = layers::linear(c.projection_dim, c.intermediate_size, vs.pp("fc1"))?;
        let fc2 = layers::linear(c.intermediate_size, c.projection_dim, vs.pp("fc2"))?;

        Ok(ClipMlp {
            fc1,
            fc2,
            activation: c.hidden_act,
        })
    }

    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let xs = self.fc1.forward(xs)?;
        self.fc2.forward(&self.activation.forward(&xs)?)
    }
}

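/// A pre-LayerNorm transformer block: self-attention and MLP, each wrapped in
/// a residual connection.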
#[derive(Clone, Debug)]
struct ClipEncoderLayer {
    self_attn: ClipAttention,
    layer_norm1: candle_nn::LayerNorm,
    mlp: ClipMlp,
    layer_norm2: candle_nn::LayerNorm,
}

impl ClipEncoderLayer {
    fn new(vs: ShardedVarBuilder, c: &ClipTextConfig) -> Result<Self> {
        let self_attn = ClipAttention::new(vs.pp("self_attn"), c)?;
        let layer_norm1 = layers::layer_norm(c.projection_dim, 1e-5, vs.pp("layer_norm1"))?;
        let mlp = ClipMlp::new(vs.pp("mlp"), c)?;
        let layer_norm2 = layers::layer_norm(c.projection_dim, 1e-5, vs.pp("layer_norm2"))?;

        Ok(ClipEncoderLayer {
            self_attn,
            layer_norm1,
            mlp,
            layer_norm2,
        })
    }

    fn forward(&self, xs: &Tensor, causal_attention_mask: Option<&Tensor>) -> Result<Tensor> {
        let residual = xs;
        let xs = self.layer_norm1.forward(xs)?;
        let xs = self.self_attn.forward(&xs, causal_attention_mask)?;
        let xs = (xs + residual)?;

        let residual = &xs;
        let xs = self.layer_norm2.forward(&xs)?;
        let xs = self.mlp.forward(&xs)?;
        xs + residual
    }
}

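/// A stack of `num_hidden_layers` encoder layers applied in sequence.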
#[derive(Clone, Debug)]
pub struct ClipEncoder {
    layers: Vec<ClipEncoderLayer>,
}

impl ClipEncoder {
    pub fn new(vs: ShardedVarBuilder, c: &ClipTextConfig) -> Result<Self> {
        let vs = vs.pp("layers");
        let mut layers: Vec<ClipEncoderLayer> = Vec::new();
        for index in 0..c.num_hidden_layers {
            let layer = ClipEncoderLayer::new(vs.pp(index.to_string()), c)?;
            layers.push(layer);
        }
        Ok(ClipEncoder { layers })
    }

    pub fn forward(&self, xs: &Tensor, causal_attention_mask: Option<&Tensor>) -> Result<Tensor> {
        let mut xs = xs.clone();
        for layer in self.layers.iter() {
            xs = layer.forward(&xs, causal_attention_mask)?;
        }
        Ok(xs)
    }
}

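/// The CLIP text transformer: embeddings, encoder stack, and final layer norm.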
#[derive(Clone, Debug)]
pub struct ClipTextTransformer {
    embeddings: ClipTextEmbeddings,
    encoder: ClipEncoder,
    final_layer_norm: candle_nn::LayerNorm,
}

impl ClipTextTransformer {
    pub fn new(vs: ShardedVarBuilder, c: &ClipTextConfig) -> Result<Self> {
        let embeddings = ClipTextEmbeddings::new(vs.pp("embeddings"), c)?;
        let encoder = ClipEncoder::new(vs.pp("encoder"), c)?;
        let final_layer_norm =
            layers::layer_norm(c.projection_dim, 1e-5, vs.pp("final_layer_norm"))?;
        Ok(ClipTextTransformer {
            embeddings,
            encoder,
            final_layer_norm,
        })
    }

    /// Build an additive causal mask of shape (bsz, 1, seq_len, seq_len).
    /// Position `j` is masked for query `i` when `j > i` (causality) or when
    /// `j > mask_after` (everything past `mask_after` is hidden).
    fn build_causal_attention_mask(
        bsz: usize,
        seq_len: usize,
        mask_after: usize,
        device: &Device,
    ) -> Result<Tensor> {
        let mask: Vec<_> = (0..seq_len)
            .flat_map(|i| {
                (0..seq_len).map(move |j| {
                    if j > i || j > mask_after {
                        f32::MIN
                    } else {
                        0.
                    }
                })
            })
            .collect();
        let mask = Tensor::from_slice(&mask, (seq_len, seq_len), device)?;
        mask.broadcast_as((bsz, 1, seq_len, seq_len))
    }

    /// Run the encoder and return the full sequence of hidden states after the
    /// final layer norm.
    pub fn forward_with_mask(&self, input_ids: &Tensor, mask_after: usize) -> Result<Tensor> {
        let (bsz, seq_len) = input_ids.dims2()?;
        let input_ids = self.embeddings.forward(input_ids)?;
        let causal_attention_mask =
            Self::build_causal_attention_mask(bsz, seq_len, mask_after, input_ids.device())?;
        let input_ids = self
            .encoder
            .forward(&input_ids, Some(&causal_attention_mask))?;
        self.final_layer_norm.forward(&input_ids)
    }
}

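// This `Module` impl returns the pooled embedding (the hidden state at the EOS
// position) rather than the full hidden-state sequence; use `forward_with_mask`
// for per-token outputs.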
impl Module for ClipTextTransformer {
    fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
        let output = self.forward_with_mask(input_ids, usize::MAX)?;
        // CLIP pools the hidden state at the EOS token, which has the highest
        // id in the standard CLIP vocabulary, so argmax over the input ids
        // gives its position in each sequence.
        let sequence_max_indices = input_ids.argmax(D::Minus1)?.to_dtype(DType::I64)?;

        let mut indices = Vec::new();
        for (batch_idx, &seq_idx) in sequence_max_indices.to_vec1::<i64>()?.iter().enumerate() {
            let index = output.i((batch_idx, seq_idx as usize))?.unsqueeze(0)?;
            indices.push(index);
        }
        Tensor::cat(&indices, 0)
    }
}
322}