mistralrs_core/search/
rag.rs

1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use std::{borrow::Cow, cmp::Ordering};
4
5use anyhow::Result;
6use candle_core::{DType, Device, Error as E, Tensor};
7use itertools::Itertools;
8use tokenizers::{InputSequence, PaddingParams, PaddingStrategy, Tokenizer};
9
10use crate::embedding::bert::{BertModel, BertPipeline};
11
12use super::SearchResult;
13
14fn normalize_l2(v: &Tensor) -> Result<Tensor> {
15    Ok(v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)?)
16}
17
18/// Get the indexes of requests most similar to the query. In decreasing order
19pub fn compute_most_similar(
20    device: &Device,
21    query: &str,
22    results: Vec<&SearchResult>,
23    BertPipeline { model, tokenizer }: &mut BertPipeline,
24) -> Result<Vec<usize>> {
25    let normalize_embeddings = false;
26
27    tokenizer.with_padding(Some(PaddingParams {
28        strategy: PaddingStrategy::BatchLongest,
29        ..Default::default()
30    }));
31
32    let mut mean_similarities = Vec::new();
33    for result in results {
34        let mean_content_similarity = {
35            let content = &result.content;
36            let chunks = content
37                .chars()
38                .chunks(4096)
39                .into_iter()
40                .map(|chunk| chunk.collect::<String>())
41                .collect::<Vec<_>>();
42            let sentences = [vec![query.to_string()], chunks].concat();
43            let similarities =
44                compute_similarities(model, tokenizer, device, sentences, normalize_embeddings)?;
45            similarities.iter().sum::<f32>() / similarities.len() as f32
46        };
47
48        let title_similarity = {
49            let title = &result.title;
50            let sentences = vec![query.to_string(), title.to_string()];
51            let similarities =
52                compute_similarities(model, tokenizer, device, sentences, normalize_embeddings)?;
53            similarities.iter().sum::<f32>() / similarities.len() as f32
54        };
55        mean_similarities.push(title_similarity * 2. + mean_content_similarity);
56    }
57
58    let mut indexed: Vec<(usize, f32)> = mean_similarities.iter().cloned().enumerate().collect();
59    indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Less));
60    let ordered_indexes: Vec<usize> = indexed.into_iter().map(|(i, _)| i).collect();
61
62    Ok(ordered_indexes)
63}
64
65fn compute_similarities(
66    model: &BertModel,
67    tokenizer: &Tokenizer,
68    device: &Device,
69    sentences: Vec<String>,
70    normalize_embeddings: bool,
71) -> Result<Vec<f32>> {
72    let n_sentences = sentences.len();
73    let sentences_batched = sentences
74        .iter()
75        .map(|s| InputSequence::Raw(Cow::from(s)))
76        .collect::<Vec<_>>();
77    let tokens = tokenizer
78        .encode_batch(sentences_batched, true)
79        .map_err(E::msg)?;
80    let token_ids = tokens
81        .iter()
82        .map(|tokens| {
83            let tokens = tokens.get_ids().to_vec();
84            Ok(Tensor::new(tokens.as_slice(), device)?)
85        })
86        .collect::<Result<Vec<_>>>()?;
87    let attention_mask = tokens
88        .iter()
89        .map(|tokens| {
90            let tokens = tokens.get_attention_mask().to_vec();
91            Ok(Tensor::new(tokens.as_slice(), device)?)
92        })
93        .collect::<Result<Vec<_>>>()?;
94
95    let token_ids = Tensor::stack(&token_ids, 0)?;
96    let attention_mask = Tensor::stack(&attention_mask, 0)?;
97    let token_type_ids = token_ids.zeros_like()?;
98
99    let embeddings = model
100        .forward(&token_ids, &token_type_ids, Some(&attention_mask))?
101        .to_dtype(DType::F32)?;
102
103    // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
104    let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
105    let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
106    let embeddings = if normalize_embeddings {
107        normalize_l2(&embeddings)?
108    } else {
109        embeddings
110    };
111
112    let query_embedding = embeddings.get(0)?;
113    let mut similarities = vec![];
114    for j in 1..n_sentences {
115        let e_j = embeddings.get(j)?;
116        let sum_ij = (&query_embedding * &e_j)?.sum_all()?.to_scalar::<f32>()?;
117        let sum_i2 = (&query_embedding * &query_embedding)?
118            .sum_all()?
119            .to_scalar::<f32>()?;
120        let sum_j2 = (&e_j * &e_j)?.sum_all()?.to_scalar::<f32>()?;
121        let cosine_similarity = sum_ij / (sum_i2 * sum_j2).sqrt();
122        similarities.push(cosine_similarity)
123    }
124
125    Ok(similarities)
126}