#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{embedding, layer_norm, linear, Embedding, LayerNorm, Linear, Module, VarBuilder};
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
use serde::Deserialize;
use tokenizers::Tokenizer;

use crate::{engine::BertEmbeddingModel, layers::Activation, GLOBAL_HF_CACHE};
use mistralrs_quant::log::once_log_info;

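/// Position-embedding flavor from the Hugging Face `config.json`. Only
/// learned absolute position embeddings are supported here.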
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
enum PositionEmbeddingType {
    #[default]
    Absolute,
}

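/// The subset of a Hugging Face BERT `config.json` used by this
/// implementation. Some fields (e.g. the dropout probabilities and
/// `initializer_range`) are deserialized but never read, since this model
/// only runs inference.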
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct Config {
    vocab_size: usize,
    hidden_size: usize,
    num_hidden_layers: usize,
    num_attention_heads: usize,
    intermediate_size: usize,
    pub hidden_act: Activation,
    hidden_dropout_prob: f64,
    max_position_embeddings: usize,
    type_vocab_size: usize,
    initializer_range: f64,
    layer_norm_eps: f64,
    pad_token_id: usize,
    #[serde(default)]
    position_embedding_type: PositionEmbeddingType,
    #[serde(default)]
    use_cache: bool,
    classifier_dropout: Option<f64>,
    model_type: Option<String>,
}

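/// BERT input embeddings: the sum of word, position, and token-type
/// embeddings, followed by LayerNorm.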
struct BertEmbeddings {
    word_embeddings: Embedding,
    position_embeddings: Option<Embedding>,
    token_type_embeddings: Embedding,
    layer_norm: LayerNorm,
    span: tracing::Span,
}

impl BertEmbeddings {
    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let word_embeddings = embedding(
            config.vocab_size,
            config.hidden_size,
            vb.pp("word_embeddings"),
        )?;
        let position_embeddings = embedding(
            config.max_position_embeddings,
            config.hidden_size,
            vb.pp("position_embeddings"),
        )?;
        let token_type_embeddings = embedding(
            config.type_vocab_size,
            config.hidden_size,
            vb.pp("token_type_embeddings"),
        )?;
        let layer_norm = layer_norm(
            config.hidden_size,
            config.layer_norm_eps,
            vb.pp("LayerNorm"),
        )?;
        Ok(Self {
            word_embeddings,
            position_embeddings: Some(position_embeddings),
            token_type_embeddings,
            layer_norm,
            span: tracing::span!(tracing::Level::TRACE, "embeddings"),
        })
    }

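    /// Embed `input_ids` and `token_type_ids`, add absolute position
    /// embeddings for positions `0..seq_len`, and apply LayerNorm.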
    fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let (_bsize, seq_len) = input_ids.dims2()?;
        let input_embeddings = self.word_embeddings.forward(input_ids)?;
        let token_type_embeddings = self.token_type_embeddings.forward(token_type_ids)?;
        let mut embeddings = (&input_embeddings + token_type_embeddings)?;
        if let Some(position_embeddings) = &self.position_embeddings {
            let position_ids = (0..seq_len as u32).collect::<Vec<_>>();
            let position_ids = Tensor::new(&position_ids[..], input_ids.device())?;
            embeddings = embeddings.broadcast_add(&position_embeddings.forward(&position_ids)?)?
        }
        let embeddings = self.layer_norm.forward(&embeddings)?;
        Ok(embeddings)
    }
}

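/// Multi-head scaled dot-product self-attention (the query/key/value
/// projections of one encoder layer).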
struct BertSelfAttention {
    query: Linear,
    key: Linear,
    value: Linear,
    num_attention_heads: usize,
    attention_head_size: usize,
    span: tracing::Span,
    span_softmax: tracing::Span,
}

impl BertSelfAttention {
    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let attention_head_size = config.hidden_size / config.num_attention_heads;
        let all_head_size = config.num_attention_heads * attention_head_size;
        let hidden_size = config.hidden_size;
        let query = linear(hidden_size, all_head_size, vb.pp("query"))?;
        let value = linear(hidden_size, all_head_size, vb.pp("value"))?;
        let key = linear(hidden_size, all_head_size, vb.pp("key"))?;
        Ok(Self {
            query,
            key,
            value,
            num_attention_heads: config.num_attention_heads,
            attention_head_size,
            span: tracing::span!(tracing::Level::TRACE, "self-attn"),
            span_softmax: tracing::span!(tracing::Level::TRACE, "softmax"),
        })
    }

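    /// Split the last dimension into heads: reshapes `(batch, seq, hidden)`
    /// to `(batch, heads, seq, head_size)` so each head attends
    /// independently.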
    fn transpose_for_scores(&self, xs: &Tensor) -> Result<Tensor> {
        let mut new_x_shape = xs.dims().to_vec();
        new_x_shape.pop();
        new_x_shape.push(self.num_attention_heads);
        new_x_shape.push(self.attention_head_size);
        let xs = xs.reshape(new_x_shape.as_slice())?.transpose(1, 2)?;
        xs.contiguous()
    }

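    /// Computes `softmax(Q K^T / sqrt(head_size) + mask) V` per head, then
    /// merges the heads back into a `(batch, seq, hidden)` tensor.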
    fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let query_layer = self.query.forward(hidden_states)?;
        let key_layer = self.key.forward(hidden_states)?;
        let value_layer = self.value.forward(hidden_states)?;

        let query_layer = self.transpose_for_scores(&query_layer)?;
        let key_layer = self.transpose_for_scores(&key_layer)?;
        let value_layer = self.transpose_for_scores(&value_layer)?;

        let attention_scores = query_layer.matmul(&key_layer.t()?)?;
        let attention_scores = (attention_scores / (self.attention_head_size as f64).sqrt())?;
        let attention_scores = attention_scores.broadcast_add(attention_mask)?;
        let attention_probs = {
            let _enter_sm = self.span_softmax.enter();
            candle_nn::ops::softmax(&attention_scores, candle_core::D::Minus1)?
        };

        let context_layer = attention_probs.matmul(&value_layer)?;
        let context_layer = context_layer.transpose(1, 2)?.contiguous()?;
        let context_layer = context_layer.flatten_from(candle_core::D::Minus2)?;
        Ok(context_layer)
    }
}

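/// Output projection of the attention block: a dense layer plus residual
/// connection and LayerNorm.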
struct BertSelfOutput {
    dense: Linear,
    layer_norm: LayerNorm,
    span: tracing::Span,
}

impl BertSelfOutput {
    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let dense = linear(config.hidden_size, config.hidden_size, vb.pp("dense"))?;
        let layer_norm = layer_norm(
            config.hidden_size,
            config.layer_norm_eps,
            vb.pp("LayerNorm"),
        )?;
        Ok(Self {
            dense,
            layer_norm,
            span: tracing::span!(tracing::Level::TRACE, "self-out"),
        })
    }

    fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let hidden_states = self.dense.forward(hidden_states)?;
        self.layer_norm.forward(&(hidden_states + input_tensor)?)
    }
}

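/// A full attention block: self-attention (`self` in the checkpoint)
/// followed by its output projection (`output`).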
struct BertAttention {
    self_attention: BertSelfAttention,
    self_output: BertSelfOutput,
    span: tracing::Span,
}

impl BertAttention {
    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let self_attention = BertSelfAttention::load(vb.pp("self"), config)?;
        let self_output = BertSelfOutput::load(vb.pp("output"), config)?;
        Ok(Self {
            self_attention,
            self_output,
            span: tracing::span!(tracing::Level::TRACE, "attn"),
        })
    }

    fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let self_outputs = self.self_attention.forward(hidden_states, attention_mask)?;
        let attention_output = self.self_output.forward(&self_outputs, hidden_states)?;
        Ok(attention_output)
    }
}

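/// Feed-forward expansion from `hidden_size` to `intermediate_size`,
/// followed by the configured activation.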
struct BertIntermediate {
    dense: Linear,
    intermediate_act: Activation,
}

impl BertIntermediate {
    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let dense = linear(config.hidden_size, config.intermediate_size, vb.pp("dense"))?;
        Ok(Self {
            dense,
            intermediate_act: config.hidden_act,
        })
    }
}

impl Module for BertIntermediate {
    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
        let hidden_states = self.dense.forward(hidden_states)?;
        let ys = self.intermediate_act.forward(&hidden_states)?;
        Ok(ys)
    }
}

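/// Feed-forward contraction from `intermediate_size` back to `hidden_size`,
/// with a residual connection and LayerNorm.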
struct BertOutput {
    dense: Linear,
    layer_norm: LayerNorm,
    span: tracing::Span,
}

impl BertOutput {
    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let dense = linear(config.intermediate_size, config.hidden_size, vb.pp("dense"))?;
        let layer_norm = layer_norm(
            config.hidden_size,
            config.layer_norm_eps,
            vb.pp("LayerNorm"),
        )?;
        Ok(Self {
            dense,
            layer_norm,
            span: tracing::span!(tracing::Level::TRACE, "out"),
        })
    }

    fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let hidden_states = self.dense.forward(hidden_states)?;
        self.layer_norm.forward(&(hidden_states + input_tensor)?)
    }
}

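/// One transformer encoder layer: the attention block followed by the
/// feed-forward block.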
struct BertLayer {
    attention: BertAttention,
    intermediate: BertIntermediate,
    output: BertOutput,
    span: tracing::Span,
}

impl BertLayer {
    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let attention = BertAttention::load(vb.pp("attention"), config)?;
        let intermediate = BertIntermediate::load(vb.pp("intermediate"), config)?;
        let output = BertOutput::load(vb.pp("output"), config)?;
        Ok(Self {
            attention,
            intermediate,
            output,
            span: tracing::span!(tracing::Level::TRACE, "layer"),
        })
    }

    fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let attention_output = self.attention.forward(hidden_states, attention_mask)?;
        let intermediate_output = self.intermediate.forward(&attention_output)?;
        let layer_output = self
            .output
            .forward(&intermediate_output, &attention_output)?;
        Ok(layer_output)
    }
}

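/// The stack of `num_hidden_layers` encoder layers, applied sequentially.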
struct BertEncoder {
    layers: Vec<BertLayer>,
    span: tracing::Span,
}

impl BertEncoder {
    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let layers = (0..config.num_hidden_layers)
            .map(|index| BertLayer::load(vb.pp(format!("layer.{index}")), config))
            .collect::<Result<Vec<_>>>()?;
        let span = tracing::span!(tracing::Level::TRACE, "encoder");
        Ok(BertEncoder { layers, span })
    }

    fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let mut hidden_states = hidden_states.clone();
        for layer in self.layers.iter() {
            hidden_states = layer.forward(&hidden_states, attention_mask)?
        }
        Ok(hidden_states)
    }
}

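/// The BERT encoder (embeddings plus layer stack, no pooler). `load` first
/// tries the flat `embeddings`/`encoder` layout and falls back to a
/// `{model_type}.`-prefixed layout (e.g. `bert.embeddings`) for checkpoints
/// that nest their weights.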
pub struct BertModel {
    embeddings: BertEmbeddings,
    encoder: BertEncoder,
    span: tracing::Span,
}

impl BertModel {
    pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let (embeddings, encoder) = match (
            BertEmbeddings::load(vb.pp("embeddings"), config),
            BertEncoder::load(vb.pp("encoder"), config),
        ) {
            (Ok(embeddings), Ok(encoder)) => (embeddings, encoder),
            (Err(err), _) | (_, Err(err)) => {
                if let Some(model_type) = &config.model_type {
                    if let (Ok(embeddings), Ok(encoder)) = (
                        BertEmbeddings::load(vb.pp(format!("{model_type}.embeddings")), config),
                        BertEncoder::load(vb.pp(format!("{model_type}.encoder")), config),
                    ) {
                        (embeddings, encoder)
                    } else {
                        return Err(err);
                    }
                } else {
                    return Err(err);
                }
            }
        };
        Ok(Self {
            embeddings,
            encoder,
            span: tracing::span!(tracing::Level::TRACE, "model"),
        })
    }

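    /// Convert a `{0, 1}` padding mask into an additive bias broadcastable
    /// over heads and query positions: kept positions map to `0` and masked
    /// positions to a very large negative value, so they vanish after the
    /// softmax.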
    fn get_extended_attention_mask(attention_mask: &Tensor, dtype: DType) -> Result<Tensor> {
        let attention_mask = match attention_mask.rank() {
            3 => attention_mask.unsqueeze(1)?,
            2 => attention_mask.unsqueeze(1)?.unsqueeze(1)?,
            _ => candle_core::bail!("Wrong shape for input_ids or attention_mask"),
        };
        let attention_mask = attention_mask.to_dtype(dtype)?;
        (attention_mask.ones_like()? - &attention_mask)?.broadcast_mul(
            &Tensor::try_from(f32::MIN)?
                .to_dtype(dtype)?
                .to_device(attention_mask.device())?,
        )
    }

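    /// Run the encoder and return the final hidden states. When
    /// `attention_mask` is `None`, every position is treated as valid.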
    pub fn forward(
        &self,
        input_ids: &Tensor,
        token_type_ids: &Tensor,
        attention_mask: Option<&Tensor>,
    ) -> Result<Tensor> {
        let _enter = self.span.enter();
        let embedding_output = self.embeddings.forward(input_ids, token_type_ids)?;
        let attention_mask = match attention_mask {
            Some(attention_mask) => attention_mask.clone(),
            None => input_ids.ones_like()?,
        };
        let attention_mask =
            Self::get_extended_attention_mask(&attention_mask, embedding_output.dtype())?;
        let sequence_output = self.encoder.forward(&embedding_output, &attention_mask)?;
        Ok(sequence_output)
    }
}

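/// An embedding model together with its tokenizer, fetched from the Hugging
/// Face Hub.
///
/// A minimal usage sketch (hypothetical caller code, error handling elided):
/// ```ignore
/// let pipeline = BertPipeline::new(BertEmbeddingModel::SnowflakeArcticEmbedL, &Device::Cpu)?;
/// let encoding = pipeline
///     .tokenizer
///     .encode("some text to embed", true)
///     .map_err(anyhow::Error::msg)?;
/// let ids = Tensor::new(encoding.get_ids(), &Device::Cpu)?.unsqueeze(0)?;
/// let type_ids = Tensor::new(encoding.get_type_ids(), &Device::Cpu)?.unsqueeze(0)?;
/// let hidden_states = pipeline.model.forward(&ids, &type_ids, None)?;
/// ```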
pub struct BertPipeline {
    pub model: BertModel,
    pub tokenizer: Tokenizer,
}

impl BertPipeline {
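    /// Download `config.json`, `tokenizer.json`, and `model.safetensors` for
    /// the selected model (via the shared HF cache when available) and load
    /// the weights as `F32` on `device`.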
    pub fn new(model: BertEmbeddingModel, device: &Device) -> anyhow::Result<Self> {
        let model_id = match model {
            BertEmbeddingModel::SnowflakeArcticEmbedL => {
                "Snowflake/snowflake-arctic-embed-l-v2.0".to_string()
            }
            BertEmbeddingModel::Custom(model_id) => model_id,
        };
        once_log_info(format!("Loading embedding model ({model_id})."));

        let repo = Repo::with_revision(model_id, RepoType::Model, "main".to_string());
        let (config_filename, tokenizer_filename, weights_filename) = {
            let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
            let api = ApiBuilder::from_cache(cache)
                .with_progress(true)
                .with_token(None)
                .build()?;
            let api = api.repo(repo);
            let config = api.get("config.json")?;
            let tokenizer = api.get("tokenizer.json")?;
            let weights = api.get("model.safetensors")?;
            (config, tokenizer, weights)
        };
        let config = std::fs::read_to_string(config_filename)?;
        let config: Config = serde_json::from_str(&config)?;
        let tokenizer =
            Tokenizer::from_file(tokenizer_filename).map_err(candle_core::Error::msg)?;

        let vb = unsafe {
            VarBuilder::from_mmaped_safetensors(&[weights_filename], DType::F32, device)?
        };
        let model = BertModel::load(vb, &config)?;
        Ok(Self { model, tokenizer })
    }
}