mistralrs_core/embedding/bert.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{embedding, layer_norm, linear, Embedding, LayerNorm, Linear, Module, VarBuilder};
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
use serde::Deserialize;
use tokenizers::Tokenizer;

use crate::{engine::BertEmbeddingModel, layers::Activation, GLOBAL_HF_CACHE};
use mistralrs_quant::log::once_log_info;

#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
enum PositionEmbeddingType {
    #[default]
    Absolute,
}

// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/configuration_bert.py#L1
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct Config {
    vocab_size: usize,
    hidden_size: usize,
    num_hidden_layers: usize,
    num_attention_heads: usize,
    intermediate_size: usize,
    pub hidden_act: Activation,
    hidden_dropout_prob: f64,
    max_position_embeddings: usize,
    type_vocab_size: usize,
    initializer_range: f64,
    layer_norm_eps: f64,
    pad_token_id: usize,
    #[serde(default)]
    position_embedding_type: PositionEmbeddingType,
    #[serde(default)]
    use_cache: bool,
    classifier_dropout: Option<f64>,
    model_type: Option<String>,
}

// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L180
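// Input embeddings: the sum of word, token-type, and (absolute) position embeddings,
// followed by LayerNorm. Dropout from the reference implementation is not applied here.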
struct BertEmbeddings {
    word_embeddings: Embedding,
    position_embeddings: Option<Embedding>,
    token_type_embeddings: Embedding,
    layer_norm: LayerNorm,
    span: tracing::Span,
}

impl BertEmbeddings {
    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let word_embeddings = embedding(
            config.vocab_size,
            config.hidden_size,
            vb.pp("word_embeddings"),
        )?;
        let position_embeddings = embedding(
            config.max_position_embeddings,
            config.hidden_size,
            vb.pp("position_embeddings"),
        )?;
        let token_type_embeddings = embedding(
            config.type_vocab_size,
            config.hidden_size,
            vb.pp("token_type_embeddings"),
        )?;
        let layer_norm = layer_norm(
            config.hidden_size,
            config.layer_norm_eps,
            vb.pp("LayerNorm"),
        )?;
        Ok(Self {
            word_embeddings,
            position_embeddings: Some(position_embeddings),
            token_type_embeddings,
            layer_norm,
            span: tracing::span!(tracing::Level::TRACE, "embeddings"),
        })
    }

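    // Expects `input_ids` and `token_type_ids` of shape [batch, seq_len]; returns
    // normalized embeddings of shape [batch, seq_len, hidden_size].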
    fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let (_bsize, seq_len) = input_ids.dims2()?;
        let input_embeddings = self.word_embeddings.forward(input_ids)?;
        let token_type_embeddings = self.token_type_embeddings.forward(token_type_ids)?;
        let mut embeddings = (&input_embeddings + token_type_embeddings)?;
        if let Some(position_embeddings) = &self.position_embeddings {
            // TODO: Proper absolute positions?
            let position_ids = (0..seq_len as u32).collect::<Vec<_>>();
            let position_ids = Tensor::new(&position_ids[..], input_ids.device())?;
            embeddings = embeddings.broadcast_add(&position_embeddings.forward(&position_ids)?)?
        }
        let embeddings = self.layer_norm.forward(&embeddings)?;
        Ok(embeddings)
    }
}

struct BertSelfAttention {
    query: Linear,
    key: Linear,
    value: Linear,
    num_attention_heads: usize,
    attention_head_size: usize,
    span: tracing::Span,
    span_softmax: tracing::Span,
}

impl BertSelfAttention {
    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let attention_head_size = config.hidden_size / config.num_attention_heads;
        let all_head_size = config.num_attention_heads * attention_head_size;
        let hidden_size = config.hidden_size;
        let query = linear(hidden_size, all_head_size, vb.pp("query"))?;
        let value = linear(hidden_size, all_head_size, vb.pp("value"))?;
        let key = linear(hidden_size, all_head_size, vb.pp("key"))?;
        Ok(Self {
            query,
            key,
            value,
            num_attention_heads: config.num_attention_heads,
            attention_head_size,
            span: tracing::span!(tracing::Level::TRACE, "self-attn"),
            span_softmax: tracing::span!(tracing::Level::TRACE, "softmax"),
        })
    }

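    // Reshapes [batch, seq_len, all_head_size] into [batch, num_heads, seq_len, head_size]
    // so attention can be computed per head with a single batched matmul.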
    fn transpose_for_scores(&self, xs: &Tensor) -> Result<Tensor> {
        let mut new_x_shape = xs.dims().to_vec();
        new_x_shape.pop();
        new_x_shape.push(self.num_attention_heads);
        new_x_shape.push(self.attention_head_size);
        let xs = xs.reshape(new_x_shape.as_slice())?.transpose(1, 2)?;
        xs.contiguous()
    }

    fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let query_layer = self.query.forward(hidden_states)?;
        let key_layer = self.key.forward(hidden_states)?;
        let value_layer = self.value.forward(hidden_states)?;

        let query_layer = self.transpose_for_scores(&query_layer)?;
        let key_layer = self.transpose_for_scores(&key_layer)?;
        let value_layer = self.transpose_for_scores(&value_layer)?;

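        // Scaled dot-product attention: softmax(QK^T / sqrt(head_size) + mask) V, where
        // `attention_mask` is an additive bias (0 for kept positions, large negative for masked).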
        let attention_scores = query_layer.matmul(&key_layer.t()?)?;
        let attention_scores = (attention_scores / (self.attention_head_size as f64).sqrt())?;
        let attention_scores = attention_scores.broadcast_add(attention_mask)?;
        let attention_probs = {
            let _enter_sm = self.span_softmax.enter();
            candle_nn::ops::softmax(&attention_scores, candle_core::D::Minus1)?
        };

        let context_layer = attention_probs.matmul(&value_layer)?;
        let context_layer = context_layer.transpose(1, 2)?.contiguous()?;
        let context_layer = context_layer.flatten_from(candle_core::D::Minus2)?;
        Ok(context_layer)
    }
}

struct BertSelfOutput {
    dense: Linear,
    layer_norm: LayerNorm,
    span: tracing::Span,
}

impl BertSelfOutput {
    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let dense = linear(config.hidden_size, config.hidden_size, vb.pp("dense"))?;
        let layer_norm = layer_norm(
            config.hidden_size,
            config.layer_norm_eps,
            vb.pp("LayerNorm"),
        )?;
        Ok(Self {
            dense,
            layer_norm,
            span: tracing::span!(tracing::Level::TRACE, "self-out"),
        })
    }

    fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let hidden_states = self.dense.forward(hidden_states)?;
        self.layer_norm.forward(&(hidden_states + input_tensor)?)
    }
}

// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L392
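// Self-attention followed by the output projection, residual connection, and LayerNorm.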
struct BertAttention {
    self_attention: BertSelfAttention,
    self_output: BertSelfOutput,
    span: tracing::Span,
}

impl BertAttention {
    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let self_attention = BertSelfAttention::load(vb.pp("self"), config)?;
        let self_output = BertSelfOutput::load(vb.pp("output"), config)?;
        Ok(Self {
            self_attention,
            self_output,
            span: tracing::span!(tracing::Level::TRACE, "attn"),
        })
    }

    fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let self_outputs = self.self_attention.forward(hidden_states, attention_mask)?;
        let attention_output = self.self_output.forward(&self_outputs, hidden_states)?;
        Ok(attention_output)
    }
}

// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L441
struct BertIntermediate {
    dense: Linear,
    intermediate_act: Activation,
}

impl BertIntermediate {
    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let dense = linear(config.hidden_size, config.intermediate_size, vb.pp("dense"))?;
        Ok(Self {
            dense,
            intermediate_act: config.hidden_act,
        })
    }
}

impl Module for BertIntermediate {
    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
        let hidden_states = self.dense.forward(hidden_states)?;
        let ys = self.intermediate_act.forward(&hidden_states)?;
        Ok(ys)
    }
}

// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L456
struct BertOutput {
    dense: Linear,
    layer_norm: LayerNorm,
    span: tracing::Span,
}

impl BertOutput {
    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let dense = linear(config.intermediate_size, config.hidden_size, vb.pp("dense"))?;
        let layer_norm = layer_norm(
            config.hidden_size,
            config.layer_norm_eps,
            vb.pp("LayerNorm"),
        )?;
        Ok(Self {
            dense,
            layer_norm,
            span: tracing::span!(tracing::Level::TRACE, "out"),
        })
    }

    fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let hidden_states = self.dense.forward(hidden_states)?;
        self.layer_norm.forward(&(hidden_states + input_tensor)?)
    }
}

// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L470
struct BertLayer {
    attention: BertAttention,
    intermediate: BertIntermediate,
    output: BertOutput,
    span: tracing::Span,
}

impl BertLayer {
    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let attention = BertAttention::load(vb.pp("attention"), config)?;
        let intermediate = BertIntermediate::load(vb.pp("intermediate"), config)?;
        let output = BertOutput::load(vb.pp("output"), config)?;
        Ok(Self {
            attention,
            intermediate,
            output,
            span: tracing::span!(tracing::Level::TRACE, "layer"),
        })
    }

    fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let attention_output = self.attention.forward(hidden_states, attention_mask)?;
        // TODO: Support cross-attention?
        // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L523
        // TODO: Support something similar to `apply_chunking_to_forward`?
        let intermediate_output = self.intermediate.forward(&attention_output)?;
        let layer_output = self
            .output
            .forward(&intermediate_output, &attention_output)?;
        Ok(layer_output)
    }
}

// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L556
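// A stack of `num_hidden_layers` BertLayer blocks applied sequentially.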
struct BertEncoder {
    layers: Vec<BertLayer>,
    span: tracing::Span,
}

impl BertEncoder {
    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let layers = (0..config.num_hidden_layers)
            .map(|index| BertLayer::load(vb.pp(format!("layer.{index}")), config))
            .collect::<Result<Vec<_>>>()?;
        let span = tracing::span!(tracing::Level::TRACE, "encoder");
        Ok(BertEncoder { layers, span })
    }

    fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let mut hidden_states = hidden_states.clone();
        // Use a loop rather than a fold as it's easier to modify when adding debug/...
        for layer in self.layers.iter() {
            hidden_states = layer.forward(&hidden_states, attention_mask)?
        }
        Ok(hidden_states)
    }
}

// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L874
pub struct BertModel {
    embeddings: BertEmbeddings,
    encoder: BertEncoder,
    span: tracing::Span,
}

impl BertModel {
    pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
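        // Try the flat tensor names first; some checkpoints nest weights under the
        // `model_type` prefix (e.g. "bert.embeddings.*"), so fall back to that on failure.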
        let (embeddings, encoder) = match (
            BertEmbeddings::load(vb.pp("embeddings"), config),
            BertEncoder::load(vb.pp("encoder"), config),
        ) {
            (Ok(embeddings), Ok(encoder)) => (embeddings, encoder),
            (Err(err), _) | (_, Err(err)) => {
                if let Some(model_type) = &config.model_type {
                    if let (Ok(embeddings), Ok(encoder)) = (
                        BertEmbeddings::load(vb.pp(format!("{model_type}.embeddings")), config),
                        BertEncoder::load(vb.pp(format!("{model_type}.encoder")), config),
                    ) {
                        (embeddings, encoder)
                    } else {
                        return Err(err);
                    }
                } else {
                    return Err(err);
                }
            }
        };
        Ok(Self {
            embeddings,
            encoder,
            span: tracing::span!(tracing::Level::TRACE, "model"),
        })
    }

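    // Converts a {0, 1} padding mask into an additive bias broadcastable against the
    // attention scores: 0 where tokens may be attended to and a large negative value
    // (f32::MIN, mirroring `torch.finfo(dtype).min`) where they are masked out.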
    fn get_extended_attention_mask(attention_mask: &Tensor, dtype: DType) -> Result<Tensor> {
        let attention_mask = match attention_mask.rank() {
            3 => attention_mask.unsqueeze(1)?,
            2 => attention_mask.unsqueeze(1)?.unsqueeze(1)?,
            _ => candle_core::bail!("Wrong shape for input_ids or attention_mask"),
        };
        let attention_mask = attention_mask.to_dtype(dtype)?;
        // torch.finfo(dtype).min
        (attention_mask.ones_like()? - &attention_mask)?.broadcast_mul(
            &Tensor::try_from(f32::MIN)?
                .to_dtype(dtype)?
                .to_device(attention_mask.device())?,
        )
    }

    pub fn forward(
        &self,
        input_ids: &Tensor,
        token_type_ids: &Tensor,
        attention_mask: Option<&Tensor>,
    ) -> Result<Tensor> {
        let _enter = self.span.enter();
        let embedding_output = self.embeddings.forward(input_ids, token_type_ids)?;
        let attention_mask = match attention_mask {
            Some(attention_mask) => attention_mask.clone(),
            None => input_ids.ones_like()?,
        };
        // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L995
        let attention_mask =
            Self::get_extended_attention_mask(&attention_mask, embedding_output.dtype())?;
        let sequence_output = self.encoder.forward(&embedding_output, &attention_mask)?;
        Ok(sequence_output)
    }
}

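// A ready-to-use embedding pipeline bundling the model and its tokenizer.
//
// Rough usage sketch (illustrative only, not part of this module's API; `text` is a
// placeholder, mean pooling is just one common pooling choice, and error handling and
// device selection are up to the caller):
//
//     let pipeline = BertPipeline::new(BertEmbeddingModel::SnowflakeArcticEmbedL, &Device::Cpu)?;
//     let enc = pipeline.tokenizer.encode(text, true).map_err(candle_core::Error::msg)?;
//     let ids = Tensor::new(enc.get_ids(), &Device::Cpu)?.unsqueeze(0)?;
//     let type_ids = Tensor::new(enc.get_type_ids(), &Device::Cpu)?.unsqueeze(0)?;
//     let hidden = pipeline.model.forward(&ids, &type_ids, None)?; // [1, seq, hidden]
//     let embedding = hidden.mean(1)?; // mean-pool to a single [1, hidden] vector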
pub struct BertPipeline {
    pub model: BertModel,
    pub tokenizer: Tokenizer,
}

impl BertPipeline {
    pub fn new(model: BertEmbeddingModel, device: &Device) -> anyhow::Result<Self> {
        let model_id = match model {
            BertEmbeddingModel::SnowflakeArcticEmbedL => {
                "Snowflake/snowflake-arctic-embed-l-v2.0".to_string()
            }
            BertEmbeddingModel::Custom(model_id) => model_id,
        };
        once_log_info(format!("Loading embedding model ({model_id})."));

        let repo = Repo::with_revision(model_id, RepoType::Model, "main".to_string());
        let (config_filename, tokenizer_filename, weights_filename) = {
            let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
            let api = ApiBuilder::from_cache(cache)
                .with_progress(true)
                .with_token(None)
                .build()?;
            let api = api.repo(repo);
            let config = api.get("config.json")?;
            let tokenizer = api.get("tokenizer.json")?;
            let weights = api.get("model.safetensors")?;
            (config, tokenizer, weights)
        };
        let config = std::fs::read_to_string(config_filename)?;
        let config: Config = serde_json::from_str(&config)?;
        let tokenizer =
            Tokenizer::from_file(tokenizer_filename).map_err(candle_core::Error::msg)?;

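        // Memory-map the safetensors weights and load them in F32; `unsafe` because the
        // mapped file must not be modified while the VarBuilder borrows it.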
        let vb = unsafe {
            VarBuilder::from_mmaped_safetensors(&[weights_filename], DType::F32, device)?
        };
        let model = BertModel::load(vb, &config)?;
        Ok(Self { model, tokenizer })
    }
}