mistralrs_core/diffusion_models/clip/text.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use candle_core::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn as nn;
use candle_nn::Module;
use mistralrs_quant::ShardedVarBuilder;
use serde::Deserialize;

use crate::layers::{self, MatMul};

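/// Activation function of the CLIP text encoder MLP. Only `quick_gelu`,
/// i.e. `x * sigmoid(1.702 * x)`, is supported.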
#[derive(Debug, Clone, Copy, Deserialize)]
pub enum Activation {
    #[serde(rename = "quick_gelu")]
    QuickGelu,
}

impl Module for Activation {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        match self {
            Activation::QuickGelu => xs * nn::ops::sigmoid(&(xs * 1.702f64)?)?,
        }
    }
}

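/// Hyperparameters of the CLIP text encoder. Note that in this implementation
/// `projection_dim` is also used as the hidden size of the transformer.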
#[derive(Debug, Clone, Deserialize)]
pub struct ClipTextConfig {
    pub vocab_size: usize,
    pub projection_dim: usize,
    pub hidden_act: Activation,
    pub intermediate_size: usize,
    pub max_position_embeddings: usize,
    pub num_hidden_layers: usize,
    pub num_attention_heads: usize,
}

#[derive(Debug, Clone, Deserialize)]
pub struct ClipConfig {
    pub text_config: ClipTextConfig,
}

// ClipTextEmbeddings mostly based on the existing implementation in the stable diffusion model.
// TODO rewrite to be more similar to https://github.com/huggingface/transformers/blob/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip/modeling_clip.py#L142
#[derive(Clone, Debug)]
struct ClipTextEmbeddings {
    token_embedding: candle_nn::Embedding,
    position_embedding: candle_nn::Embedding,
    position_ids: Tensor,
}

impl ClipTextEmbeddings {
    fn new(vs: ShardedVarBuilder, c: &ClipTextConfig) -> Result<Self> {
        let token_embedding = layers::embedding(
            c.vocab_size,
            c.projection_dim,
            vs.pp("token_embedding"),
            &None,
        )?;
        let position_embedding: nn::Embedding = layers::embedding(
            c.max_position_embeddings,
            c.projection_dim,
            vs.pp("position_embedding"),
            &None,
        )?;
        let position_ids =
            Tensor::arange(0u32, c.max_position_embeddings as u32, vs.device())?.unsqueeze(0)?;
        Ok(ClipTextEmbeddings {
            token_embedding,
            position_embedding,
            position_ids,
        })
    }
}

impl Module for ClipTextEmbeddings {
    fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
        let seq_length = input_ids.dim(D::Minus1)?;
        let inputs_embeds = self.token_embedding.forward(input_ids)?;
        let position_ids = self.position_ids.narrow(1, 0, seq_length)?;
        let position_embedding = self.position_embedding.forward(&position_ids)?;
        inputs_embeds.broadcast_add(&position_embedding)
    }
}

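/// Multi-head self-attention as used in the CLIP text encoder. Queries are
/// pre-scaled by `1 / sqrt(head_dim)` and the attention weights are computed
/// in f32 before being cast back to the input dtype.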
#[derive(Clone, Debug)]
struct ClipAttention {
    k_proj: candle_nn::Linear,
    v_proj: candle_nn::Linear,
    q_proj: candle_nn::Linear,
    out_proj: candle_nn::Linear,
    head_dim: usize,
    scale: f64,
    num_attention_heads: usize,
}

impl ClipAttention {
    fn new(vs: ShardedVarBuilder, c: &ClipTextConfig) -> Result<Self> {
        let projection_dim = c.projection_dim;
        let num_attention_heads = c.num_attention_heads;
        let k_proj = layers::linear(projection_dim, projection_dim, vs.pp("k_proj"))?;
        let v_proj = layers::linear(projection_dim, projection_dim, vs.pp("v_proj"))?;
        let q_proj = layers::linear(projection_dim, projection_dim, vs.pp("q_proj"))?;
        let out_proj = layers::linear(projection_dim, projection_dim, vs.pp("out_proj"))?;
        let head_dim = projection_dim / num_attention_heads;
        let scale = (head_dim as f64).powf(-0.5);

        Ok(ClipAttention {
            k_proj,
            v_proj,
            q_proj,
            out_proj,
            head_dim,
            scale,
            num_attention_heads,
        })
    }

    // (bsz, seq_len, projection_dim) -> (bsz, num_heads, seq_len, head_dim)
    fn shape(&self, xs: &Tensor, seq_len: usize, bsz: usize) -> Result<Tensor> {
        xs.reshape((bsz, seq_len, self.num_attention_heads, self.head_dim))?
            .transpose(1, 2)?
            .contiguous()
    }

    fn forward(&self, xs: &Tensor, causal_attention_mask: Option<&Tensor>) -> Result<Tensor> {
        let in_dtype = xs.dtype();
        let (bsz, seq_len, projection_dim) = xs.dims3()?;

        // Project to per-head q/k/v and compute the attention scores in f32.
        let query_states = (self.q_proj.forward(xs)? * self.scale)?;
        let proj_shape = (bsz * self.num_attention_heads, seq_len, self.head_dim);
        let query_states = self
            .shape(&query_states, seq_len, bsz)?
            .reshape(proj_shape)?
            .to_dtype(DType::F32)?;
        let key_states = self
            .shape(&self.k_proj.forward(xs)?, seq_len, bsz)?
            .reshape(proj_shape)?
            .to_dtype(DType::F32)?;
        let value_states = self
            .shape(&self.v_proj.forward(xs)?, seq_len, bsz)?
            .reshape(proj_shape)?
            .to_dtype(DType::F32)?;
        let attn_weights = MatMul.matmul(&query_states, &key_states.transpose(1, 2)?)?;

        let src_len = key_states.dim(1)?;

        let attn_weights = if let Some(causal_attention_mask) = causal_attention_mask {
            attn_weights
                .reshape((bsz, self.num_attention_heads, seq_len, src_len))?
                .broadcast_add(causal_attention_mask)?
                .reshape((bsz * self.num_attention_heads, seq_len, src_len))?
        } else {
            attn_weights
        };

        let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;

        // Weighted sum of values, then merge the heads back and project out.
        let attn_output = MatMul
            .matmul(&attn_weights, &value_states)?
            .to_dtype(in_dtype)?;
        let attn_output = attn_output
            .reshape((bsz, self.num_attention_heads, seq_len, self.head_dim))?
            .transpose(1, 2)?
            .reshape((bsz, seq_len, projection_dim))?;
        self.out_proj.forward(&attn_output)
    }
}

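/// The two-layer feed-forward block of a CLIP encoder layer
/// (`projection_dim -> intermediate_size -> projection_dim` with QuickGELU).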
#[derive(Clone, Debug)]
struct ClipMlp {
    fc1: candle_nn::Linear,
    fc2: candle_nn::Linear,
    activation: Activation,
}

impl ClipMlp {
    fn new(vs: ShardedVarBuilder, c: &ClipTextConfig) -> Result<Self> {
        let fc1 = layers::linear(c.projection_dim, c.intermediate_size, vs.pp("fc1"))?;
        let fc2 = layers::linear(c.intermediate_size, c.projection_dim, vs.pp("fc2"))?;

        Ok(ClipMlp {
            fc1,
            fc2,
            activation: c.hidden_act,
        })
    }
}

impl ClipMlp {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let xs = self.fc1.forward(xs)?;
        self.fc2.forward(&self.activation.forward(&xs)?)
    }
}

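/// One pre-norm transformer encoder layer: layer norm -> self-attention ->
/// residual add, then layer norm -> MLP -> residual add.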
#[derive(Clone, Debug)]
struct ClipEncoderLayer {
    self_attn: ClipAttention,
    layer_norm1: candle_nn::LayerNorm,
    mlp: ClipMlp,
    layer_norm2: candle_nn::LayerNorm,
}

impl ClipEncoderLayer {
    fn new(vs: ShardedVarBuilder, c: &ClipTextConfig) -> Result<Self> {
        let self_attn = ClipAttention::new(vs.pp("self_attn"), c)?;
        let layer_norm1 = layers::layer_norm(c.projection_dim, 1e-5, vs.pp("layer_norm1"))?;
        let mlp = ClipMlp::new(vs.pp("mlp"), c)?;
        let layer_norm2 = layers::layer_norm(c.projection_dim, 1e-5, vs.pp("layer_norm2"))?;

        Ok(ClipEncoderLayer {
            self_attn,
            layer_norm1,
            mlp,
            layer_norm2,
        })
    }

    fn forward(&self, xs: &Tensor, causal_attention_mask: Option<&Tensor>) -> Result<Tensor> {
        let residual = xs;
        let xs = self.layer_norm1.forward(xs)?;
        let xs = self.self_attn.forward(&xs, causal_attention_mask)?;
        let xs = (xs + residual)?;

        let residual = &xs;
        let xs = self.layer_norm2.forward(&xs)?;
        let xs = self.mlp.forward(&xs)?;
        xs + residual
    }
}

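/// Stack of `num_hidden_layers` encoder layers under the `layers.{i}` weight prefix.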
#[derive(Clone, Debug)]
pub struct ClipEncoder {
    layers: Vec<ClipEncoderLayer>,
}

impl ClipEncoder {
    pub fn new(vs: ShardedVarBuilder, c: &ClipTextConfig) -> Result<Self> {
        let vs = vs.pp("layers");
        let mut layers: Vec<ClipEncoderLayer> = Vec::new();
        for index in 0..c.num_hidden_layers {
            let layer = ClipEncoderLayer::new(vs.pp(index.to_string()), c)?;
            layers.push(layer)
        }
        Ok(ClipEncoder { layers })
    }

    pub fn forward(&self, xs: &Tensor, causal_attention_mask: Option<&Tensor>) -> Result<Tensor> {
        let mut xs = xs.clone();
        for layer in self.layers.iter() {
            xs = layer.forward(&xs, causal_attention_mask)?;
        }
        Ok(xs)
    }
}

/// A CLIP transformer-based text encoder.
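///
/// Weights are loaded from the `embeddings`, `encoder`, and `final_layer_norm`
/// prefixes of the provided `ShardedVarBuilder`.
///
/// A minimal usage sketch (assuming `vb` is a `ShardedVarBuilder` rooted at the
/// text model's weights and `input_ids` is a `(batch, seq_len)` tensor of token
/// ids; the names here are illustrative):
///
/// ```ignore
/// let cfg: ClipTextConfig = serde_json::from_str(&text_config_json)?;
/// let model = ClipTextTransformer::new(vb, &cfg)?;
/// // Per-token hidden states, shape (batch, seq_len, projection_dim).
/// let hidden = model.forward_with_mask(&input_ids, usize::MAX)?;
/// // Pooled embedding at the EOS position, shape (batch, projection_dim).
/// let pooled = model.forward(&input_ids)?;
/// ```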
#[derive(Clone, Debug)]
pub struct ClipTextTransformer {
    embeddings: ClipTextEmbeddings,
    encoder: ClipEncoder,
    final_layer_norm: candle_nn::LayerNorm,
}

impl ClipTextTransformer {
    pub fn new(vs: ShardedVarBuilder, c: &ClipTextConfig) -> Result<Self> {
        let embeddings = ClipTextEmbeddings::new(vs.pp("embeddings"), c)?;
        let encoder = ClipEncoder::new(vs.pp("encoder"), c)?;
        let final_layer_norm =
            layers::layer_norm(c.projection_dim, 1e-5, vs.pp("final_layer_norm"))?;
        Ok(ClipTextTransformer {
            embeddings,
            encoder,
            final_layer_norm,
        })
    }

    // TODO: rewrite to newer version
    //
    // Builds a (bsz, 1, seq_len, seq_len) additive mask: entry (i, j) is
    // f32::MIN when j > i (causal) or j > mask_after, and 0 otherwise.
    fn build_causal_attention_mask(
        bsz: usize,
        seq_len: usize,
        mask_after: usize,
        device: &Device,
    ) -> Result<Tensor> {
        let mask: Vec<_> = (0..seq_len)
            .flat_map(|i| {
                (0..seq_len).map(move |j| {
                    if j > i || j > mask_after {
                        f32::MIN
                    } else {
                        0.
                    }
                })
            })
            .collect();
        let mask = Tensor::from_slice(&mask, (seq_len, seq_len), device)?;
        mask.broadcast_as((bsz, 1, seq_len, seq_len))
    }

    pub fn forward_with_mask(&self, input_ids: &Tensor, mask_after: usize) -> Result<Tensor> {
        let (bsz, seq_len) = input_ids.dims2()?;
        let input_ids = self.embeddings.forward(input_ids)?;
        let causal_attention_mask =
            Self::build_causal_attention_mask(bsz, seq_len, mask_after, input_ids.device())?;
        let input_ids = self
            .encoder
            .forward(&input_ids, Some(&causal_attention_mask))?;
        self.final_layer_norm.forward(&input_ids)
    }
}

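// Pools the final hidden states at the position of the largest token id in
// each sequence, which for the CLIP tokenizer corresponds to the EOS token.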
impl Module for ClipTextTransformer {
    fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
        let output = self.forward_with_mask(input_ids, usize::MAX)?;
        let sequence_max_indices = input_ids.argmax(D::Minus1)?.to_dtype(DType::I64)?;

        let mut indices = Vec::new();
        for (batch_idx, &seq_idx) in sequence_max_indices.to_vec1::<i64>()?.iter().enumerate() {
            let index = output.i((batch_idx, seq_idx as usize))?.unsqueeze(0)?;
            indices.push(index);
        }
        Tensor::cat(&indices, 0)
    }
}