#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{embedding, layer_norm, linear, Embedding, LayerNorm, Linear, Module, VarBuilder};
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
use serde::Deserialize;
use tokenizers::Tokenizer;

use crate::{
    engine::BertEmbeddingModel, layers::Activation, utils::log::once_log_info, GLOBAL_HF_CACHE,
};

#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
enum PositionEmbeddingType {
    #[default]
    Absolute,
}

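/// Model configuration, deserialized from the checkpoint's `config.json`.
/// A minimal sketch of the expected file shape (values here are the usual
/// BERT-base defaults, shown for illustration only, not taken from any
/// particular checkpoint):
///
/// ```json
/// {
///   "vocab_size": 30522,
///   "hidden_size": 768,
///   "num_hidden_layers": 12,
///   "num_attention_heads": 12,
///   "intermediate_size": 3072,
///   "hidden_act": "gelu",
///   "hidden_dropout_prob": 0.1,
///   "max_position_embeddings": 512,
///   "type_vocab_size": 2,
///   "initializer_range": 0.02,
///   "layer_norm_eps": 1e-12,
///   "pad_token_id": 0,
///   "model_type": "bert"
/// }
/// ```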
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct Config {
    vocab_size: usize,
    hidden_size: usize,
    num_hidden_layers: usize,
    num_attention_heads: usize,
    intermediate_size: usize,
    pub hidden_act: Activation,
    hidden_dropout_prob: f64,
    max_position_embeddings: usize,
    type_vocab_size: usize,
    initializer_range: f64,
    layer_norm_eps: f64,
    pad_token_id: usize,
    #[serde(default)]
    position_embedding_type: PositionEmbeddingType,
    #[serde(default)]
    use_cache: bool,
    classifier_dropout: Option<f64>,
    model_type: Option<String>,
}

struct BertEmbeddings {
    word_embeddings: Embedding,
    position_embeddings: Option<Embedding>,
    token_type_embeddings: Embedding,
    layer_norm: LayerNorm,
    span: tracing::Span,
}

impl BertEmbeddings {
    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let word_embeddings = embedding(
            config.vocab_size,
            config.hidden_size,
            vb.pp("word_embeddings"),
        )?;
        let position_embeddings = embedding(
            config.max_position_embeddings,
            config.hidden_size,
            vb.pp("position_embeddings"),
        )?;
        let token_type_embeddings = embedding(
            config.type_vocab_size,
            config.hidden_size,
            vb.pp("token_type_embeddings"),
        )?;
        let layer_norm = layer_norm(
            config.hidden_size,
            config.layer_norm_eps,
            vb.pp("LayerNorm"),
        )?;
        Ok(Self {
            word_embeddings,
            position_embeddings: Some(position_embeddings),
            token_type_embeddings,
            layer_norm,
            span: tracing::span!(tracing::Level::TRACE, "embeddings"),
        })
    }

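    /// Sum the word, token-type, and (absolute) position embeddings, then
    /// apply LayerNorm. Position ids are generated as `0..seq_len` on the
    /// fly, so callers only supply `input_ids` and `token_type_ids`, both of
    /// shape `[batch, seq_len]`.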
    fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let (_bsize, seq_len) = input_ids.dims2()?;
        let input_embeddings = self.word_embeddings.forward(input_ids)?;
        let token_type_embeddings = self.token_type_embeddings.forward(token_type_ids)?;
        let mut embeddings = (&input_embeddings + token_type_embeddings)?;
        if let Some(position_embeddings) = &self.position_embeddings {
            let position_ids = (0..seq_len as u32).collect::<Vec<_>>();
            let position_ids = Tensor::new(&position_ids[..], input_ids.device())?;
            embeddings = embeddings.broadcast_add(&position_embeddings.forward(&position_ids)?)?
        }
        let embeddings = self.layer_norm.forward(&embeddings)?;
        Ok(embeddings)
    }
}

struct BertSelfAttention {
    query: Linear,
    key: Linear,
    value: Linear,
    num_attention_heads: usize,
    attention_head_size: usize,
    span: tracing::Span,
    span_softmax: tracing::Span,
}

impl BertSelfAttention {
    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let attention_head_size = config.hidden_size / config.num_attention_heads;
        let all_head_size = config.num_attention_heads * attention_head_size;
        let hidden_size = config.hidden_size;
        let query = linear(hidden_size, all_head_size, vb.pp("query"))?;
        let value = linear(hidden_size, all_head_size, vb.pp("value"))?;
        let key = linear(hidden_size, all_head_size, vb.pp("key"))?;
        Ok(Self {
            query,
            key,
            value,
            num_attention_heads: config.num_attention_heads,
            attention_head_size,
            span: tracing::span!(tracing::Level::TRACE, "self-attn"),
            span_softmax: tracing::span!(tracing::Level::TRACE, "softmax"),
        })
    }

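    /// Split the last dimension into attention heads: `[batch, seq_len,
    /// hidden]` is reshaped to `[batch, seq_len, num_heads, head_size]` and
    /// transposed to `[batch, num_heads, seq_len, head_size]`, so each head
    /// attends independently.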
    fn transpose_for_scores(&self, xs: &Tensor) -> Result<Tensor> {
        let mut new_x_shape = xs.dims().to_vec();
        new_x_shape.pop();
        new_x_shape.push(self.num_attention_heads);
        new_x_shape.push(self.attention_head_size);
        let xs = xs.reshape(new_x_shape.as_slice())?.transpose(1, 2)?;
        xs.contiguous()
    }

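    /// Standard scaled dot-product attention:
    /// `softmax(Q K^T / sqrt(head_size) + mask) V`, where `mask` is the
    /// additive bias produced by `BertModel::get_extended_attention_mask`
    /// (0 for tokens to keep, a large negative value for padding).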
    fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let query_layer = self.query.forward(hidden_states)?;
        let key_layer = self.key.forward(hidden_states)?;
        let value_layer = self.value.forward(hidden_states)?;

        let query_layer = self.transpose_for_scores(&query_layer)?;
        let key_layer = self.transpose_for_scores(&key_layer)?;
        let value_layer = self.transpose_for_scores(&value_layer)?;

        let attention_scores = query_layer.matmul(&key_layer.t()?)?;
        let attention_scores = (attention_scores / (self.attention_head_size as f64).sqrt())?;
        let attention_scores = attention_scores.broadcast_add(attention_mask)?;
        let attention_probs = {
            let _enter_sm = self.span_softmax.enter();
            candle_nn::ops::softmax(&attention_scores, candle_core::D::Minus1)?
        };

        let context_layer = attention_probs.matmul(&value_layer)?;
        let context_layer = context_layer.transpose(1, 2)?.contiguous()?;
        let context_layer = context_layer.flatten_from(candle_core::D::Minus2)?;
        Ok(context_layer)
    }
}

struct BertSelfOutput {
    dense: Linear,
    layer_norm: LayerNorm,
    span: tracing::Span,
}

impl BertSelfOutput {
    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let dense = linear(config.hidden_size, config.hidden_size, vb.pp("dense"))?;
        let layer_norm = layer_norm(
            config.hidden_size,
            config.layer_norm_eps,
            vb.pp("LayerNorm"),
        )?;
        Ok(Self {
            dense,
            layer_norm,
            span: tracing::span!(tracing::Level::TRACE, "self-out"),
        })
    }

    fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let hidden_states = self.dense.forward(hidden_states)?;
        self.layer_norm.forward(&(hidden_states + input_tensor)?)
    }
}

struct BertAttention {
    self_attention: BertSelfAttention,
    self_output: BertSelfOutput,
    span: tracing::Span,
}

impl BertAttention {
    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let self_attention = BertSelfAttention::load(vb.pp("self"), config)?;
        let self_output = BertSelfOutput::load(vb.pp("output"), config)?;
        Ok(Self {
            self_attention,
            self_output,
            span: tracing::span!(tracing::Level::TRACE, "attn"),
        })
    }

    fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let self_outputs = self.self_attention.forward(hidden_states, attention_mask)?;
        let attention_output = self.self_output.forward(&self_outputs, hidden_states)?;
        Ok(attention_output)
    }
}

struct BertIntermediate {
    dense: Linear,
    intermediate_act: Activation,
}

impl BertIntermediate {
    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let dense = linear(config.hidden_size, config.intermediate_size, vb.pp("dense"))?;
        Ok(Self {
            dense,
            intermediate_act: config.hidden_act,
        })
    }
}

impl Module for BertIntermediate {
    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
        let hidden_states = self.dense.forward(hidden_states)?;
        let ys = self.intermediate_act.forward(&hidden_states)?;
        Ok(ys)
    }
}

struct BertOutput {
    dense: Linear,
    layer_norm: LayerNorm,
    span: tracing::Span,
}

impl BertOutput {
    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let dense = linear(config.intermediate_size, config.hidden_size, vb.pp("dense"))?;
        let layer_norm = layer_norm(
            config.hidden_size,
            config.layer_norm_eps,
            vb.pp("LayerNorm"),
        )?;
        Ok(Self {
            dense,
            layer_norm,
            span: tracing::span!(tracing::Level::TRACE, "out"),
        })
    }

    fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let hidden_states = self.dense.forward(hidden_states)?;
        self.layer_norm.forward(&(hidden_states + input_tensor)?)
    }
}

struct BertLayer {
    attention: BertAttention,
    intermediate: BertIntermediate,
    output: BertOutput,
    span: tracing::Span,
}

impl BertLayer {
    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let attention = BertAttention::load(vb.pp("attention"), config)?;
        let intermediate = BertIntermediate::load(vb.pp("intermediate"), config)?;
        let output = BertOutput::load(vb.pp("output"), config)?;
        Ok(Self {
            attention,
            intermediate,
            output,
            span: tracing::span!(tracing::Level::TRACE, "layer"),
        })
    }

    fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let attention_output = self.attention.forward(hidden_states, attention_mask)?;
        let intermediate_output = self.intermediate.forward(&attention_output)?;
        let layer_output = self
            .output
            .forward(&intermediate_output, &attention_output)?;
        Ok(layer_output)
    }
}

struct BertEncoder {
    layers: Vec<BertLayer>,
    span: tracing::Span,
}

impl BertEncoder {
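    /// Load all transformer layers; weight names follow the Hugging Face
    /// checkpoint layout (`layer.0`, `layer.1`, ...).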
    fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let layers = (0..config.num_hidden_layers)
            .map(|index| BertLayer::load(vb.pp(format!("layer.{index}")), config))
            .collect::<Result<Vec<_>>>()?;
        let span = tracing::span!(tracing::Level::TRACE, "encoder");
        Ok(BertEncoder { layers, span })
    }

    fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let mut hidden_states = hidden_states.clone();
        for layer in self.layers.iter() {
            hidden_states = layer.forward(&hidden_states, attention_mask)?
        }
        Ok(hidden_states)
    }
}

pub struct BertModel {
    embeddings: BertEmbeddings,
    encoder: BertEncoder,
    span: tracing::Span,
}

impl BertModel {
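    /// Load the model weights. Some checkpoints nest their tensors under a
    /// `model_type` prefix (e.g. `bert.embeddings.*` instead of
    /// `embeddings.*`), so if the unprefixed load fails and `config.json`
    /// provides a `model_type`, loading is retried with that prefix before
    /// the original error is returned.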
    pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let (embeddings, encoder) = match (
            BertEmbeddings::load(vb.pp("embeddings"), config),
            BertEncoder::load(vb.pp("encoder"), config),
        ) {
            (Ok(embeddings), Ok(encoder)) => (embeddings, encoder),
            (Err(err), _) | (_, Err(err)) => {
                if let Some(model_type) = &config.model_type {
                    if let (Ok(embeddings), Ok(encoder)) = (
                        BertEmbeddings::load(vb.pp(format!("{model_type}.embeddings")), config),
                        BertEncoder::load(vb.pp(format!("{model_type}.encoder")), config),
                    ) {
                        (embeddings, encoder)
                    } else {
                        return Err(err);
                    }
                } else {
                    return Err(err);
                }
            }
        };
        Ok(Self {
            embeddings,
            encoder,
            span: tracing::span!(tracing::Level::TRACE, "model"),
        })
    }

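    /// Convert a 0/1 attention mask into the additive bias expected by the
    /// attention layers: positions to keep become `0`, padding positions
    /// become `f32::MIN`, and the mask is unsqueezed so it broadcasts over
    /// the `[batch, num_heads, seq_len, seq_len]` attention scores. For
    /// example, a row `[1, 1, 0]` becomes `[0.0, 0.0, f32::MIN]`.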
    fn get_extended_attention_mask(attention_mask: &Tensor, dtype: DType) -> Result<Tensor> {
        let attention_mask = match attention_mask.rank() {
            3 => attention_mask.unsqueeze(1)?,
            2 => attention_mask.unsqueeze(1)?.unsqueeze(1)?,
            _ => candle_core::bail!("Wrong shape for input_ids or attention_mask"),
        };
        let attention_mask = attention_mask.to_dtype(dtype)?;
        (attention_mask.ones_like()? - &attention_mask)?.broadcast_mul(
            &Tensor::try_from(f32::MIN)?
                .to_dtype(dtype)?
                .to_device(attention_mask.device())?,
        )
    }

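    /// Run the full model: embeddings followed by the encoder stack.
    /// `input_ids` and `token_type_ids` are `[batch, seq_len]` token-id
    /// tensors; if `attention_mask` is `None`, every position is attended
    /// to. Returns the final hidden states of shape `[batch, seq_len,
    /// hidden]`.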
    pub fn forward(
        &self,
        input_ids: &Tensor,
        token_type_ids: &Tensor,
        attention_mask: Option<&Tensor>,
    ) -> Result<Tensor> {
        let _enter = self.span.enter();
        let embedding_output = self.embeddings.forward(input_ids, token_type_ids)?;
        let attention_mask = match attention_mask {
            Some(attention_mask) => attention_mask.clone(),
            None => input_ids.ones_like()?,
        };
        let attention_mask =
            Self::get_extended_attention_mask(&attention_mask, embedding_output.dtype())?;
        let sequence_output = self.encoder.forward(&embedding_output, &attention_mask)?;
        Ok(sequence_output)
    }
}

pub struct BertPipeline {
    pub model: BertModel,
    pub tokenizer: Tokenizer,
}

impl BertPipeline {
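    /// Resolve the Hugging Face repo for the requested embedding model,
    /// fetch `config.json`, `tokenizer.json`, and `model.safetensors` (or
    /// read them from the local cache), then build the model on `device`.
    ///
    /// A minimal usage sketch (not compiled as a doctest; assumes network
    /// access or a warm HF cache):
    ///
    /// ```ignore
    /// let pipeline = BertPipeline::new(
    ///     BertEmbeddingModel::SnowflakeArcticEmbedL,
    ///     &candle_core::Device::Cpu,
    /// )?;
    /// let encoding = pipeline
    ///     .tokenizer
    ///     .encode("hello world", true)
    ///     .map_err(anyhow::Error::msg)?;
    /// ```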
    pub fn new(model: BertEmbeddingModel, device: &Device) -> anyhow::Result<Self> {
        let model_id = match model {
            BertEmbeddingModel::SnowflakeArcticEmbedL => {
                "Snowflake/snowflake-arctic-embed-l-v2.0".to_string()
            }
            BertEmbeddingModel::Custom(model_id) => model_id,
        };
        once_log_info(format!("Loading embedding model ({model_id})."));

        let repo = Repo::with_revision(model_id, RepoType::Model, "main".to_string());
        let (config_filename, tokenizer_filename, weights_filename) = {
            let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
            let api = ApiBuilder::from_cache(cache)
                .with_progress(true)
                .with_token(None)
                .build()?;
            let api = api.repo(repo);
            let config = api.get("config.json")?;
            let tokenizer = api.get("tokenizer.json")?;
            let weights = api.get("model.safetensors")?;
            (config, tokenizer, weights)
        };
        let config = std::fs::read_to_string(config_filename)?;
        let config: Config = serde_json::from_str(&config)?;
        let tokenizer =
            Tokenizer::from_file(tokenizer_filename).map_err(candle_core::Error::msg)?;

        // SAFETY: the safetensors file is memory-mapped, which is unsafe
        // because the underlying file must not be modified while the
        // mapping is alive.
        let vb = unsafe {
            VarBuilder::from_mmaped_safetensors(&[weights_filename], DType::F32, device)?
        };
        let model = BertModel::load(vb, &config)?;
        Ok(Self { model, tokenizer })
    }
}