// mistralrs_core/search/rag.rs
#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use std::{borrow::Cow, cmp::Ordering};
4
5use anyhow::Result;
6use candle_core::{DType, Device, Error as E, Tensor};
7use itertools::Itertools;
8use tokenizers::{InputSequence, PaddingParams, PaddingStrategy, Tokenizer};
9
10use crate::embedding::bert::{BertModel, BertPipeline};
11
12use super::SearchResult;
13
14fn normalize_l2(v: &Tensor) -> Result<Tensor> {
15 Ok(v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)?)
16}
17
18pub fn compute_most_similar(
20 device: &Device,
21 query: &str,
22 results: Vec<&SearchResult>,
23 BertPipeline { model, tokenizer }: &mut BertPipeline,
24) -> Result<Vec<usize>> {
25 let normalize_embeddings = false;
26
27 tokenizer.with_padding(Some(PaddingParams {
28 strategy: PaddingStrategy::BatchLongest,
29 ..Default::default()
30 }));
31
32 let mut mean_similarities = Vec::new();
33 for result in results {
34 let mean_content_similarity = {
35 let content = &result.content;
36 let chunks = content
37 .chars()
38 .chunks(4096)
39 .into_iter()
40 .map(|chunk| chunk.collect::<String>())
41 .collect::<Vec<_>>();
42 let sentences = [vec![query.to_string()], chunks].concat();
43 let similarities =
44 compute_similarities(model, tokenizer, device, sentences, normalize_embeddings)?;
45 similarities.iter().sum::<f32>() / similarities.len() as f32
46 };
47
48 let title_similarity = {
49 let title = &result.title;
50 let sentences = vec![query.to_string(), title.to_string()];
51 let similarities =
52 compute_similarities(model, tokenizer, device, sentences, normalize_embeddings)?;
53 similarities.iter().sum::<f32>() / similarities.len() as f32
54 };
55 mean_similarities.push(title_similarity * 2. + mean_content_similarity);
56 }
57
58 let mut indexed: Vec<(usize, f32)> = mean_similarities.iter().cloned().enumerate().collect();
59 indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Less));
60 let ordered_indexes: Vec<usize> = indexed.into_iter().map(|(i, _)| i).collect();
61
62 Ok(ordered_indexes)
63}
64
65fn compute_similarities(
66 model: &BertModel,
67 tokenizer: &Tokenizer,
68 device: &Device,
69 sentences: Vec<String>,
70 normalize_embeddings: bool,
71) -> Result<Vec<f32>> {
72 let n_sentences = sentences.len();
73 let sentences_batched = sentences
74 .iter()
75 .map(|s| InputSequence::Raw(Cow::from(s)))
76 .collect::<Vec<_>>();
77 let tokens = tokenizer
78 .encode_batch(sentences_batched, true)
79 .map_err(E::msg)?;
80 let token_ids = tokens
81 .iter()
82 .map(|tokens| {
83 let tokens = tokens.get_ids().to_vec();
84 Ok(Tensor::new(tokens.as_slice(), device)?)
85 })
86 .collect::<Result<Vec<_>>>()?;
87 let attention_mask = tokens
88 .iter()
89 .map(|tokens| {
90 let tokens = tokens.get_attention_mask().to_vec();
91 Ok(Tensor::new(tokens.as_slice(), device)?)
92 })
93 .collect::<Result<Vec<_>>>()?;
94
95 let token_ids = Tensor::stack(&token_ids, 0)?;
96 let attention_mask = Tensor::stack(&attention_mask, 0)?;
97 let token_type_ids = token_ids.zeros_like()?;
98
99 let embeddings = model
100 .forward(&token_ids, &token_type_ids, Some(&attention_mask))?
101 .to_dtype(DType::F32)?;
102
103 let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
105 let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
106 let embeddings = if normalize_embeddings {
107 normalize_l2(&embeddings)?
108 } else {
109 embeddings
110 };
111
112 let query_embedding = embeddings.get(0)?;
113 let mut similarities = vec![];
114 for j in 1..n_sentences {
115 let e_j = embeddings.get(j)?;
116 let sum_ij = (&query_embedding * &e_j)?.sum_all()?.to_scalar::<f32>()?;
117 let sum_i2 = (&query_embedding * &query_embedding)?
118 .sum_all()?
119 .to_scalar::<f32>()?;
120 let sum_j2 = (&e_j * &e_j)?.sum_all()?.to_scalar::<f32>()?;
121 let cosine_similarity = sum_ij / (sum_i2 * sum_j2).sqrt();
122 similarities.push(cosine_similarity)
123 }
124
125 Ok(similarities)
126}