#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use candle_core::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn as nn;
use candle_nn::Module;
use mistralrs_quant::ShardedVarBuilder;
use serde::Deserialize;

use crate::layers::{self, MatMul};

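/// Activation functions used by the CLIP text encoder. `quick_gelu` computes
/// `x * sigmoid(1.702 * x)`, a cheap approximation of GELU.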
#[derive(Debug, Clone, Copy, Deserialize)]
pub enum Activation {
    #[serde(rename = "quick_gelu")]
    QuickGelu,
}

impl Module for Activation {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        match self {
            Activation::QuickGelu => xs * nn::ops::sigmoid(&(xs * 1.702f64)?)?,
        }
    }
}

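/// Configuration of the CLIP text encoder, typically deserialized from the model's `config.json`.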
#[derive(Debug, Clone, Deserialize)]
pub struct ClipTextConfig {
    pub vocab_size: usize,
    pub projection_dim: usize,
    pub hidden_act: Activation,
    pub intermediate_size: usize,
    pub max_position_embeddings: usize,
    pub num_hidden_layers: usize,
    pub num_attention_heads: usize,
}

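/// Top-level CLIP configuration; only the text tower is modeled here.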
#[derive(Debug, Clone, Deserialize)]
pub struct ClipConfig {
    pub text_config: ClipTextConfig,
}

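/// Learned token embeddings plus learned absolute position embeddings.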
41#[derive(Clone, Debug)]
44struct ClipTextEmbeddings {
45 token_embedding: candle_nn::Embedding,
46 position_embedding: candle_nn::Embedding,
47 position_ids: Tensor,
48}
49
50impl ClipTextEmbeddings {
51 fn new(vs: ShardedVarBuilder, c: &ClipTextConfig) -> Result<Self> {
52 let token_embedding =
53 layers::embedding(c.vocab_size, c.projection_dim, vs.pp("token_embedding"))?;
54 let position_embedding: nn::Embedding = layers::embedding(
55 c.max_position_embeddings,
56 c.projection_dim,
57 vs.pp("position_embedding"),
58 )?;
59 let position_ids =
60 Tensor::arange(0u32, c.max_position_embeddings as u32, vs.device())?.unsqueeze(0)?;
61 Ok(ClipTextEmbeddings {
62 token_embedding,
63 position_embedding,
64 position_ids,
65 })
66 }
67}
68
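// Embeds the token ids and adds the position embeddings for the first `seq_length` positions.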
impl Module for ClipTextEmbeddings {
    fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
        let seq_length = input_ids.dim(D::Minus1)?;
        let inputs_embeds = self.token_embedding.forward(input_ids)?;
        let position_ids = self.position_ids.narrow(1, 0, seq_length)?;
        let position_embedding = self.position_embedding.forward(&position_ids)?;
        inputs_embeds.broadcast_add(&position_embedding)
    }
}

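/// Multi-head self-attention over the token sequence.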
#[derive(Clone, Debug)]
struct ClipAttention {
    k_proj: candle_nn::Linear,
    v_proj: candle_nn::Linear,
    q_proj: candle_nn::Linear,
    out_proj: candle_nn::Linear,
    head_dim: usize,
    scale: f64,
    num_attention_heads: usize,
}

impl ClipAttention {
    fn new(vs: ShardedVarBuilder, c: &ClipTextConfig) -> Result<Self> {
        let projection_dim = c.projection_dim;
        let num_attention_heads = c.num_attention_heads;
        let k_proj = layers::linear(projection_dim, projection_dim, vs.pp("k_proj"))?;
        let v_proj = layers::linear(projection_dim, projection_dim, vs.pp("v_proj"))?;
        let q_proj = layers::linear(projection_dim, projection_dim, vs.pp("q_proj"))?;
        let out_proj = layers::linear(projection_dim, projection_dim, vs.pp("out_proj"))?;
        let head_dim = projection_dim / num_attention_heads;
        let scale = (head_dim as f64).powf(-0.5);

        Ok(ClipAttention {
            k_proj,
            v_proj,
            q_proj,
            out_proj,
            head_dim,
            scale,
            num_attention_heads,
        })
    }

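    /// Splits `(bsz, seq_len, dim)` into `(bsz, num_heads, seq_len, head_dim)`.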
    fn shape(&self, xs: &Tensor, seq_len: usize, bsz: usize) -> Result<Tensor> {
        xs.reshape((bsz, seq_len, self.num_attention_heads, self.head_dim))?
            .transpose(1, 2)?
            .contiguous()
    }

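    /// Scaled dot-product attention. Scores are accumulated in f32 (with an optional
    /// additive causal mask) and the output is cast back to the input dtype.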
    fn forward(&self, xs: &Tensor, causal_attention_mask: Option<&Tensor>) -> Result<Tensor> {
        let in_dtype = xs.dtype();
        let (bsz, seq_len, projection_dim) = xs.dims3()?;

        let query_states = (self.q_proj.forward(xs)? * self.scale)?;
        let proj_shape = (bsz * self.num_attention_heads, seq_len, self.head_dim);
        let query_states = self
            .shape(&query_states, seq_len, bsz)?
            .reshape(proj_shape)?
            .to_dtype(DType::F32)?;
        let key_states = self
            .shape(&self.k_proj.forward(xs)?, seq_len, bsz)?
            .reshape(proj_shape)?
            .to_dtype(DType::F32)?;
        let value_states = self
            .shape(&self.v_proj.forward(xs)?, seq_len, bsz)?
            .reshape(proj_shape)?
            .to_dtype(DType::F32)?;
        let attn_weights = MatMul.matmul(&query_states, &key_states.transpose(1, 2)?)?;

        let src_len = key_states.dim(1)?;

        let attn_weights = if let Some(causal_attention_mask) = causal_attention_mask {
            attn_weights
                .reshape((bsz, self.num_attention_heads, seq_len, src_len))?
                .broadcast_add(causal_attention_mask)?
                .reshape((bsz * self.num_attention_heads, seq_len, src_len))?
        } else {
            attn_weights
        };

        let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;

        let attn_output = MatMul
            .matmul(&attn_weights, &value_states)?
            .to_dtype(in_dtype)?;
        let attn_output = attn_output
            .reshape((bsz, self.num_attention_heads, seq_len, self.head_dim))?
            .transpose(1, 2)?
            .reshape((bsz, seq_len, projection_dim))?;
        self.out_proj.forward(&attn_output)
    }
}

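/// Two-layer feed-forward block with the configured activation in between.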
#[derive(Clone, Debug)]
struct ClipMlp {
    fc1: candle_nn::Linear,
    fc2: candle_nn::Linear,
    activation: Activation,
}

impl ClipMlp {
    fn new(vs: ShardedVarBuilder, c: &ClipTextConfig) -> Result<Self> {
        let fc1 = layers::linear(c.projection_dim, c.intermediate_size, vs.pp("fc1"))?;
        let fc2 = layers::linear(c.intermediate_size, c.projection_dim, vs.pp("fc2"))?;

        Ok(ClipMlp {
            fc1,
            fc2,
            activation: c.hidden_act,
        })
    }

    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let xs = self.fc1.forward(xs)?;
        self.fc2.forward(&self.activation.forward(&xs)?)
    }
}

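/// Pre-LayerNorm transformer block: attention and MLP sub-layers, each with a residual connection.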
#[derive(Clone, Debug)]
struct ClipEncoderLayer {
    self_attn: ClipAttention,
    layer_norm1: candle_nn::LayerNorm,
    mlp: ClipMlp,
    layer_norm2: candle_nn::LayerNorm,
}

impl ClipEncoderLayer {
    fn new(vs: ShardedVarBuilder, c: &ClipTextConfig) -> Result<Self> {
        let self_attn = ClipAttention::new(vs.pp("self_attn"), c)?;
        let layer_norm1 = layers::layer_norm(c.projection_dim, 1e-5, vs.pp("layer_norm1"))?;
        let mlp = ClipMlp::new(vs.pp("mlp"), c)?;
        let layer_norm2 = layers::layer_norm(c.projection_dim, 1e-5, vs.pp("layer_norm2"))?;

        Ok(ClipEncoderLayer {
            self_attn,
            layer_norm1,
            mlp,
            layer_norm2,
        })
    }

    fn forward(&self, xs: &Tensor, causal_attention_mask: Option<&Tensor>) -> Result<Tensor> {
        let residual = xs;
        let xs = self.layer_norm1.forward(xs)?;
        let xs = self.self_attn.forward(&xs, causal_attention_mask)?;
        let xs = (xs + residual)?;

        let residual = &xs;
        let xs = self.layer_norm2.forward(&xs)?;
        let xs = self.mlp.forward(&xs)?;
        xs + residual
    }
}

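/// Stack of `num_hidden_layers` encoder blocks applied in sequence.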
#[derive(Clone, Debug)]
pub struct ClipEncoder {
    layers: Vec<ClipEncoderLayer>,
}

impl ClipEncoder {
    pub fn new(vs: ShardedVarBuilder, c: &ClipTextConfig) -> Result<Self> {
        let vs = vs.pp("layers");
        let mut layers: Vec<ClipEncoderLayer> = Vec::new();
        for index in 0..c.num_hidden_layers {
            let layer = ClipEncoderLayer::new(vs.pp(index.to_string()), c)?;
            layers.push(layer)
        }
        Ok(ClipEncoder { layers })
    }

    pub fn forward(&self, xs: &Tensor, causal_attention_mask: Option<&Tensor>) -> Result<Tensor> {
        let mut xs = xs.clone();
        for layer in self.layers.iter() {
            xs = layer.forward(&xs, causal_attention_mask)?;
        }
        Ok(xs)
    }
}

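/// CLIP text transformer: token/position embeddings, the encoder stack, and a final layer norm.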
#[derive(Clone, Debug)]
pub struct ClipTextTransformer {
    embeddings: ClipTextEmbeddings,
    encoder: ClipEncoder,
    final_layer_norm: candle_nn::LayerNorm,
}

impl ClipTextTransformer {
    pub fn new(vs: ShardedVarBuilder, c: &ClipTextConfig) -> Result<Self> {
        let embeddings = ClipTextEmbeddings::new(vs.pp("embeddings"), c)?;
        let encoder = ClipEncoder::new(vs.pp("encoder"), c)?;
        let final_layer_norm =
            layers::layer_norm(c.projection_dim, 1e-5, vs.pp("final_layer_norm"))?;
        Ok(ClipTextTransformer {
            embeddings,
            encoder,
            final_layer_norm,
        })
    }

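    /// Builds an additive causal mask of shape `(bsz, 1, seq_len, seq_len)`: positions after the
    /// query index, or after `mask_after`, are set to `f32::MIN`; everything else is 0.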
    fn build_causal_attention_mask(
        bsz: usize,
        seq_len: usize,
        mask_after: usize,
        device: &Device,
    ) -> Result<Tensor> {
        let mask: Vec<_> = (0..seq_len)
            .flat_map(|i| {
                (0..seq_len).map(move |j| {
                    if j > i || j > mask_after {
                        f32::MIN
                    } else {
                        0.
                    }
                })
            })
            .collect();
        let mask = Tensor::from_slice(&mask, (seq_len, seq_len), device)?;
        mask.broadcast_as((bsz, 1, seq_len, seq_len))
    }

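    /// Runs the text transformer with a causal mask that additionally hides positions after `mask_after`.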
    pub fn forward_with_mask(&self, input_ids: &Tensor, mask_after: usize) -> Result<Tensor> {
        let (bsz, seq_len) = input_ids.dims2()?;
        let xs = self.embeddings.forward(input_ids)?;
        let causal_attention_mask =
            Self::build_causal_attention_mask(bsz, seq_len, mask_after, xs.device())?;
        let xs = self.encoder.forward(&xs, Some(&causal_attention_mask))?;
        self.final_layer_norm.forward(&xs)
    }
}

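// The `Module` impl additionally pools the encoder output: for each sequence it takes the hidden
// state at the position of the largest token id, which in CLIP's tokenizer is the end-of-text token.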
impl Module for ClipTextTransformer {
    fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
        let output = self.forward_with_mask(input_ids, usize::MAX)?;
        let sequence_max_indices = input_ids.argmax(D::Minus1)?.to_dtype(DType::I64)?;

        let mut indices = Vec::new();
        for (batch_idx, &seq_idx) in sequence_max_indices.to_vec1::<i64>()?.iter().enumerate() {
            let index = output.i((batch_idx, seq_idx as usize))?.unsqueeze(0)?;
            indices.push(index);
        }
        Tensor::cat(&indices, 0)
    }
}