mistralrs_core/embedding/bert.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{embedding, layer_norm, linear, Embedding, LayerNorm, Linear, Module, VarBuilder};
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
use serde::Deserialize;
use tokenizers::Tokenizer;

use crate::{
    engine::BertEmbeddingModel, layers::Activation, utils::log::once_log_info, GLOBAL_HF_CACHE,
};

#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
enum PositionEmbeddingType {
    #[default]
    Absolute,
}

// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/configuration_bert.py#L1
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct Config {
    vocab_size: usize,
    hidden_size: usize,
    num_hidden_layers: usize,
    num_attention_heads: usize,
    intermediate_size: usize,
    pub hidden_act: Activation,
    hidden_dropout_prob: f64,
    max_position_embeddings: usize,
    type_vocab_size: usize,
    initializer_range: f64,
    layer_norm_eps: f64,
    pad_token_id: usize,
    #[serde(default)]
    position_embedding_type: PositionEmbeddingType,
    #[serde(default)]
    use_cache: bool,
    classifier_dropout: Option<f64>,
    model_type: Option<String>,
}
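// For orientation, a BERT-base style `config.json` deserializes into this struct
// with values along these lines (illustrative, not tied to any specific checkpoint):
// { "vocab_size": 30522, "hidden_size": 768, "num_hidden_layers": 12,
//   "num_attention_heads": 12, "intermediate_size": 3072, "hidden_act": "gelu",
//   "max_position_embeddings": 512, "type_vocab_size": 2, "layer_norm_eps": 1e-12,
//   "pad_token_id": 0, "model_type": "bert" }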

// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L180
struct BertEmbeddings {
    word_embeddings: Embedding,
    position_embeddings: Option<Embedding>,
    token_type_embeddings: Embedding,
    layer_norm: LayerNorm,
    span: tracing::Span,
}

impl BertEmbeddings {
    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let word_embeddings = embedding(
            config.vocab_size,
            config.hidden_size,
            vb.pp("word_embeddings"),
        )?;
        let position_embeddings = embedding(
            config.max_position_embeddings,
            config.hidden_size,
            vb.pp("position_embeddings"),
        )?;
        let token_type_embeddings = embedding(
            config.type_vocab_size,
            config.hidden_size,
            vb.pp("token_type_embeddings"),
        )?;
        let layer_norm = layer_norm(
            config.hidden_size,
            config.layer_norm_eps,
            vb.pp("LayerNorm"),
        )?;
        Ok(Self {
            word_embeddings,
            position_embeddings: Some(position_embeddings),
            token_type_embeddings,
            layer_norm,
            span: tracing::span!(tracing::Level::TRACE, "embeddings"),
        })
    }

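    // Sum word, token-type, and (absolute) position embeddings, then LayerNorm.
    // Position ids are regenerated as 0..seq_len on every call rather than cached.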
    fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let (_bsize, seq_len) = input_ids.dims2()?;
        let input_embeddings = self.word_embeddings.forward(input_ids)?;
        let token_type_embeddings = self.token_type_embeddings.forward(token_type_ids)?;
        let mut embeddings = (&input_embeddings + token_type_embeddings)?;
        if let Some(position_embeddings) = &self.position_embeddings {
            // TODO: Proper absolute positions?
            let position_ids = (0..seq_len as u32).collect::<Vec<_>>();
            let position_ids = Tensor::new(&position_ids[..], input_ids.device())?;
            embeddings = embeddings.broadcast_add(&position_embeddings.forward(&position_ids)?)?
        }
        let embeddings = self.layer_norm.forward(&embeddings)?;
        Ok(embeddings)
    }
}

struct BertSelfAttention {
    query: Linear,
    key: Linear,
    value: Linear,
    num_attention_heads: usize,
    attention_head_size: usize,
    span: tracing::Span,
    span_softmax: tracing::Span,
}

impl BertSelfAttention {
    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let attention_head_size = config.hidden_size / config.num_attention_heads;
        let all_head_size = config.num_attention_heads * attention_head_size;
        let hidden_size = config.hidden_size;
        let query = linear(hidden_size, all_head_size, vb.pp("query"))?;
        let value = linear(hidden_size, all_head_size, vb.pp("value"))?;
        let key = linear(hidden_size, all_head_size, vb.pp("key"))?;
        Ok(Self {
            query,
            key,
            value,
            num_attention_heads: config.num_attention_heads,
            attention_head_size,
            span: tracing::span!(tracing::Level::TRACE, "self-attn"),
            span_softmax: tracing::span!(tracing::Level::TRACE, "softmax"),
        })
    }

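    // Reshape (batch, seq, hidden) -> (batch, heads, seq, head_size) so each head
    // attends independently; `contiguous` keeps the layout matmul-friendly.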
    fn transpose_for_scores(&self, xs: &Tensor) -> Result<Tensor> {
        let mut new_x_shape = xs.dims().to_vec();
        new_x_shape.pop();
        new_x_shape.push(self.num_attention_heads);
        new_x_shape.push(self.attention_head_size);
        let xs = xs.reshape(new_x_shape.as_slice())?.transpose(1, 2)?;
        xs.contiguous()
    }

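    // Standard scaled dot-product attention: softmax(QK^T / sqrt(d) + mask) V,
    // where `attention_mask` is additive (0 for kept positions, a large negative
    // value for masked ones; see `get_extended_attention_mask` below).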
    fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let query_layer = self.query.forward(hidden_states)?;
        let key_layer = self.key.forward(hidden_states)?;
        let value_layer = self.value.forward(hidden_states)?;

        let query_layer = self.transpose_for_scores(&query_layer)?;
        let key_layer = self.transpose_for_scores(&key_layer)?;
        let value_layer = self.transpose_for_scores(&value_layer)?;

        let attention_scores = query_layer.matmul(&key_layer.t()?)?;
        let attention_scores = (attention_scores / (self.attention_head_size as f64).sqrt())?;
        let attention_scores = attention_scores.broadcast_add(attention_mask)?;
        let attention_probs = {
            let _enter_sm = self.span_softmax.enter();
            candle_nn::ops::softmax(&attention_scores, candle_core::D::Minus1)?
        };

        let context_layer = attention_probs.matmul(&value_layer)?;
        let context_layer = context_layer.transpose(1, 2)?.contiguous()?;
        let context_layer = context_layer.flatten_from(candle_core::D::Minus2)?;
        Ok(context_layer)
    }
}

struct BertSelfOutput {
    dense: Linear,
    layer_norm: LayerNorm,
    span: tracing::Span,
}

impl BertSelfOutput {
    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let dense = linear(config.hidden_size, config.hidden_size, vb.pp("dense"))?;
        let layer_norm = layer_norm(
            config.hidden_size,
            config.layer_norm_eps,
            vb.pp("LayerNorm"),
        )?;
        Ok(Self {
            dense,
            layer_norm,
            span: tracing::span!(tracing::Level::TRACE, "self-out"),
        })
    }

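    // Post-norm residual connection: LayerNorm(dense(hidden) + residual input).
    // The same pattern recurs in `BertOutput` for the feed-forward sublayer.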
    fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let hidden_states = self.dense.forward(hidden_states)?;
        self.layer_norm.forward(&(hidden_states + input_tensor)?)
    }
}

// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L392
struct BertAttention {
    self_attention: BertSelfAttention,
    self_output: BertSelfOutput,
    span: tracing::Span,
}

impl BertAttention {
    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let self_attention = BertSelfAttention::load(vb.pp("self"), config)?;
        let self_output = BertSelfOutput::load(vb.pp("output"), config)?;
        Ok(Self {
            self_attention,
            self_output,
            span: tracing::span!(tracing::Level::TRACE, "attn"),
        })
    }

    fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let self_outputs = self.self_attention.forward(hidden_states, attention_mask)?;
        let attention_output = self.self_output.forward(&self_outputs, hidden_states)?;
        Ok(attention_output)
    }
}

// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L441
struct BertIntermediate {
    dense: Linear,
    intermediate_act: Activation,
}

impl BertIntermediate {
    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let dense = linear(config.hidden_size, config.intermediate_size, vb.pp("dense"))?;
        Ok(Self {
            dense,
            intermediate_act: config.hidden_act,
        })
    }
}

impl Module for BertIntermediate {
    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
        let hidden_states = self.dense.forward(hidden_states)?;
        let ys = self.intermediate_act.forward(&hidden_states)?;
        Ok(ys)
    }
}

// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L456
struct BertOutput {
    dense: Linear,
    layer_norm: LayerNorm,
    span: tracing::Span,
}

impl BertOutput {
    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let dense = linear(config.intermediate_size, config.hidden_size, vb.pp("dense"))?;
        let layer_norm = layer_norm(
            config.hidden_size,
            config.layer_norm_eps,
            vb.pp("LayerNorm"),
        )?;
        Ok(Self {
            dense,
            layer_norm,
            span: tracing::span!(tracing::Level::TRACE, "out"),
        })
    }

    fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let hidden_states = self.dense.forward(hidden_states)?;
        self.layer_norm.forward(&(hidden_states + input_tensor)?)
    }
}

// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L470
struct BertLayer {
    attention: BertAttention,
    intermediate: BertIntermediate,
    output: BertOutput,
    span: tracing::Span,
}

impl BertLayer {
    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let attention = BertAttention::load(vb.pp("attention"), config)?;
        let intermediate = BertIntermediate::load(vb.pp("intermediate"), config)?;
        let output = BertOutput::load(vb.pp("output"), config)?;
        Ok(Self {
            attention,
            intermediate,
            output,
            span: tracing::span!(tracing::Level::TRACE, "layer"),
        })
    }

    fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let attention_output = self.attention.forward(hidden_states, attention_mask)?;
        // TODO: Support cross-attention?
        // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L523
        // TODO: Support something similar to `apply_chunking_to_forward`?
        let intermediate_output = self.intermediate.forward(&attention_output)?;
        let layer_output = self
            .output
            .forward(&intermediate_output, &attention_output)?;
        Ok(layer_output)
    }
}

// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L556
struct BertEncoder {
    layers: Vec<BertLayer>,
    span: tracing::Span,
}

impl BertEncoder {
    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let layers = (0..config.num_hidden_layers)
            .map(|index| BertLayer::load(vb.pp(format!("layer.{index}")), config))
            .collect::<Result<Vec<_>>>()?;
        let span = tracing::span!(tracing::Level::TRACE, "encoder");
        Ok(BertEncoder { layers, span })
    }

    fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let mut hidden_states = hidden_states.clone();
        // Use a loop rather than a fold as it's easier to modify when adding debug/...
        for layer in self.layers.iter() {
            hidden_states = layer.forward(&hidden_states, attention_mask)?
        }
        Ok(hidden_states)
    }
}

// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L874
pub struct BertModel {
    embeddings: BertEmbeddings,
    encoder: BertEncoder,
    span: tracing::Span,
}

impl BertModel {
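    // Try the flat `embeddings`/`encoder` weight prefixes first; some checkpoints
    // nest tensors under the model type (e.g. `bert.embeddings.*`), so on failure
    // retry with the `model_type` prefix from the config before giving up.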
    pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let (embeddings, encoder) = match (
            BertEmbeddings::load(vb.pp("embeddings"), config),
            BertEncoder::load(vb.pp("encoder"), config),
        ) {
            (Ok(embeddings), Ok(encoder)) => (embeddings, encoder),
            (Err(err), _) | (_, Err(err)) => {
                if let Some(model_type) = &config.model_type {
                    if let (Ok(embeddings), Ok(encoder)) = (
                        BertEmbeddings::load(vb.pp(format!("{model_type}.embeddings")), config),
                        BertEncoder::load(vb.pp(format!("{model_type}.encoder")), config),
                    ) {
                        (embeddings, encoder)
                    } else {
                        return Err(err);
                    }
                } else {
                    return Err(err);
                }
            }
        };
        Ok(Self {
            embeddings,
            encoder,
            span: tracing::span!(tracing::Level::TRACE, "model"),
        })
    }

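    // Convert a (batch, seq) or (batch, seq, seq) 0/1 mask into the additive form
    // the attention layers expect: kept positions become 0 and masked positions a
    // very large negative number (the f32 analogue of torch.finfo(dtype).min), so
    // they vanish after the softmax.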
    fn get_extended_attention_mask(attention_mask: &Tensor, dtype: DType) -> Result<Tensor> {
        let attention_mask = match attention_mask.rank() {
            3 => attention_mask.unsqueeze(1)?,
            2 => attention_mask.unsqueeze(1)?.unsqueeze(1)?,
            _ => candle_core::bail!("Wrong shape for input_ids or attention_mask"),
        };
        let attention_mask = attention_mask.to_dtype(dtype)?;
        // torch.finfo(dtype).min
        (attention_mask.ones_like()? - &attention_mask)?.broadcast_mul(
            &Tensor::try_from(f32::MIN)?
                .to_dtype(dtype)?
                .to_device(attention_mask.device())?,
        )
    }

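    // Returns the last hidden state of shape (batch, seq, hidden). If no attention
    // mask is given, all positions are treated as valid. Pooling (CLS, mean, ...)
    // is left to the caller.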
    pub fn forward(
        &self,
        input_ids: &Tensor,
        token_type_ids: &Tensor,
        attention_mask: Option<&Tensor>,
    ) -> Result<Tensor> {
        let _enter = self.span.enter();
        let embedding_output = self.embeddings.forward(input_ids, token_type_ids)?;
        let attention_mask = match attention_mask {
            Some(attention_mask) => attention_mask.clone(),
            None => input_ids.ones_like()?,
        };
        // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L995
        let attention_mask =
            Self::get_extended_attention_mask(&attention_mask, embedding_output.dtype())?;
        let sequence_output = self.encoder.forward(&embedding_output, &attention_mask)?;
        Ok(sequence_output)
    }
}

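/// Tokenizer + model bundle for producing embeddings.
///
/// A minimal usage sketch; the mean-pooling step is an assumption, not part of
/// this type's API — use whatever pooling your checkpoint was trained with:
/// ```ignore
/// let pipe = BertPipeline::new(BertEmbeddingModel::SnowflakeArcticEmbedL, &Device::Cpu)?;
/// let enc = pipe.tokenizer.encode("hello world", true).map_err(anyhow::Error::msg)?;
/// let ids = Tensor::new(enc.get_ids(), &Device::Cpu)?.unsqueeze(0)?;
/// let type_ids = ids.zeros_like()?;
/// let hidden = pipe.model.forward(&ids, &type_ids, None)?; // (1, seq, hidden)
/// let embedding = hidden.mean(1)?; // (1, hidden), mean-pooled over the sequence
/// ```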
pub struct BertPipeline {
    pub model: BertModel,
    pub tokenizer: Tokenizer,
}

impl BertPipeline {
    pub fn new(model: BertEmbeddingModel, device: &Device) -> anyhow::Result<Self> {
        let model_id = match model {
            BertEmbeddingModel::SnowflakeArcticEmbedL => {
                "Snowflake/snowflake-arctic-embed-l-v2.0".to_string()
            }
            BertEmbeddingModel::Custom(model_id) => model_id,
        };
        once_log_info(format!("Loading embedding model ({model_id})."));

        let repo = Repo::with_revision(model_id, RepoType::Model, "main".to_string());
        let (config_filename, tokenizer_filename, weights_filename) = {
            let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
            let api = ApiBuilder::from_cache(cache)
                .with_progress(true)
                .with_token(None)
                .build()?;
            let api = api.repo(repo);
            let config = api.get("config.json")?;
            let tokenizer = api.get("tokenizer.json")?;
            let weights = api.get("model.safetensors")?;
            (config, tokenizer, weights)
        };
        let config = std::fs::read_to_string(config_filename)?;
        let config: Config = serde_json::from_str(&config)?;
        let tokenizer =
            Tokenizer::from_file(tokenizer_filename).map_err(candle_core::Error::msg)?;

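        // Safety: `from_mmaped_safetensors` memory-maps the weight file, which is
        // unsound if the file is mutated or truncated while mapped; the hub cache
        // file is treated as immutable here.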
        let vb = unsafe {
            VarBuilder::from_mmaped_safetensors(&[weights_filename], DType::F32, device)?
        };
        let model = BertModel::load(vb, &config)?;
        Ok(Self { model, tokenizer })
    }
}