mistralrs_server_core/embeddings.rs

//! OpenAI-compatible embeddings endpoint.
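//!
//! Accepts the OpenAI `POST /v1/embeddings` shape: a single string, an array
//! of strings, a token array, or a batch of token arrays (see
//! `EmbeddingInput`), returning vectors as floats or base64-encoded bytes.
//! An illustrative request body (values are examples only):
//!
//! ```json
//! {"model": "default", "input": ["hello", "world"], "encoding_format": "base64"}
//! ```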

use anyhow::{anyhow, Context, Error as AnyhowError, Result};
use axum::{
    extract::{Json, State},
    http,
    response::IntoResponse,
};
use base64::{prelude::BASE64_STANDARD, Engine};
use futures::future::join_all;
use mistralrs_core::{
    Constraint, MistralRs, NormalRequest, Request, RequestMessage, Response, SamplingParams,
};
use tokio::sync::mpsc::Receiver;

use crate::{
    handler_core::{
        base_process_non_streaming_response, create_response_channel, send_request_with_model,
        ErrorToResponse, JsonError,
    },
    openai::{
        EmbeddingData, EmbeddingEncodingFormat, EmbeddingInput, EmbeddingRequest,
        EmbeddingResponse, EmbeddingUsage, EmbeddingVector,
    },
    types::{ExtractedMistralRsState, SharedMistralRsState},
    util::{sanitize_error_message, validate_model_name},
};

/// Represents different types of embeddings responses.
pub enum EmbeddingResponder {
    /// Successful result, serialized as a JSON body.
    Json(EmbeddingResponse),
    /// Server-side failure, mapped to 500 Internal Server Error.
    InternalError(AnyhowError),
    /// Invalid request, mapped to 422 Unprocessable Entity.
    ValidationError(AnyhowError),
}

/// One embedding vector plus the token usage accrued while computing it.
struct EmbeddingWithUsage {
    embedding: Vec<f32>,
    prompt_tokens: usize,
    total_tokens: usize,
}

impl IntoResponse for EmbeddingResponder {
    fn into_response(self) -> axum::response::Response {
        match self {
            EmbeddingResponder::Json(s) => Json(s).into_response(),
            EmbeddingResponder::InternalError(e) => {
                JsonError::new(sanitize_error_message(e.root_cause()))
                    .to_response(http::StatusCode::INTERNAL_SERVER_ERROR)
            }
            EmbeddingResponder::ValidationError(e) => {
                JsonError::new(sanitize_error_message(e.root_cause()))
                    .to_response(http::StatusCode::UNPROCESSABLE_ENTITY)
            }
        }
    }
}

#[utoipa::path(
    post,
    tag = "Mistral.rs",
    path = "/v1/embeddings",
    request_body = EmbeddingRequest,
    responses((status = 200, description = "Embeddings", body = EmbeddingResponse))
)]
pub async fn embeddings(
    State(state): ExtractedMistralRsState,
    Json(oairequest): Json<EmbeddingRequest>,
) -> EmbeddingResponder {
    let repr =
        serde_json::to_string(&oairequest).expect("Serialization of embedding request failed.");
    MistralRs::maybe_log_request(state.clone(), repr);

    if let Err(e) = validate_model_name(&oairequest.model, state.clone()) {
        return validation_error(e);
    }

    if let Some(dimensions) = oairequest.dimensions {
        return validation_error(anyhow!(
            "Custom embedding dimensions ({dimensions}) are not supported."
        ));
    }

    let inputs = match normalize_inputs(oairequest.input) {
        Ok(inputs) => inputs,
        Err(e) => return validation_error(e),
    };

    if inputs.is_empty() {
        return validation_error(anyhow!("input must contain at least one entry."));
    }

    // The reserved model name "default" means "no per-request override".
    let model_override = if oairequest.model == "default" {
        None
    } else {
        Some(oairequest.model.clone())
    };

    let encoding = oairequest.encoding_format.unwrap_or_default();
    let return_base64 = matches!(encoding, EmbeddingEncodingFormat::Base64);

    let mut data = Vec::with_capacity(inputs.len());
    let mut total_prompt_tokens: usize = 0;
    let mut total_tokens: usize = 0;

    // Fan out one engine request per input and await them concurrently; the
    // first failure aborts the whole batch.
    match inputs {
        Inputs::Prompt(prompts) => {
            let futures = prompts.into_iter().map(|prompt| {
                let state = state.clone();
                let model_override = model_override.clone();
                async move {
                    fetch_embedding(
                        state,
                        prompt,
                        model_override.as_deref(),
                        oairequest.truncate_sequence.unwrap_or(false),
                    )
                    .await
                }
            });

            let results = join_all(futures).await;
            for (index, result) in results.into_iter().enumerate() {
                match result {
                    Ok(EmbeddingWithUsage {
                        embedding,
                        prompt_tokens,
                        total_tokens: item_total_tokens,
                    }) => {
                        let embedding = if return_base64 {
                            EmbeddingVector::Base64(encode_embedding_base64(&embedding))
                        } else {
                            EmbeddingVector::Float(embedding)
                        };
                        data.push(EmbeddingData {
                            object: "embedding",
                            embedding,
                            index,
                        });
                        total_prompt_tokens = total_prompt_tokens.saturating_add(prompt_tokens);
                        total_tokens = total_tokens.saturating_add(item_total_tokens);
                    }
                    Err(e) => {
                        MistralRs::maybe_log_error(state.clone(), e.as_ref());
                        return internal_error(e);
                    }
                }
            }
        }
        Inputs::Tokens(batches) => {
            let futures = batches.into_iter().map(|tokens| {
                let state = state.clone();
                let model_override = model_override.clone();
                async move {
                    fetch_embedding_tokens(
                        state,
                        tokens,
                        model_override.as_deref(),
                        oairequest.truncate_sequence.unwrap_or(false),
                    )
                    .await
                }
            });

            let results = join_all(futures).await;
            for (index, result) in results.into_iter().enumerate() {
                match result {
                    Ok(EmbeddingWithUsage {
                        embedding,
                        prompt_tokens,
                        total_tokens: item_total_tokens,
                    }) => {
                        let embedding = if return_base64 {
                            EmbeddingVector::Base64(encode_embedding_base64(&embedding))
                        } else {
                            EmbeddingVector::Float(embedding)
                        };
                        data.push(EmbeddingData {
                            object: "embedding",
                            embedding,
                            index,
                        });
                        total_prompt_tokens = total_prompt_tokens.saturating_add(prompt_tokens);
                        total_tokens = total_tokens.saturating_add(item_total_tokens);
                    }
                    Err(e) => {
                        MistralRs::maybe_log_error(state.clone(), e.as_ref());
                        return internal_error(e);
                    }
                }
            }
        }
    }

    let usage = EmbeddingUsage {
        prompt_tokens: saturating_to_u32(total_prompt_tokens),
        total_tokens: saturating_to_u32(total_tokens),
    };

    let response = EmbeddingResponse {
        object: "list",
        data,
        model: oairequest.model,
        usage,
    };

    MistralRs::maybe_log_response(state.clone(), &response);

    EmbeddingResponder::Json(response)
}
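
// An illustrative invocation of the handler above (host, port, and payload
// are examples only, not defaults shipped with the server):
//
//   curl http://localhost:1234/v1/embeddings \
//     -H "Content-Type: application/json" \
//     -d '{"model": "default", "input": "hello world", "encoding_format": "float"}'
//
// With `"encoding_format": "base64"`, each `data[i].embedding` is instead the
// little-endian f32 bytes of the vector, base64-encoded (see
// `encode_embedding_base64` below).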

enum Inputs {
    Prompt(Vec<String>),
    Tokens(Vec<Vec<u32>>),
}

impl Inputs {
    fn is_empty(&self) -> bool {
        match self {
            Self::Prompt(x) => x.is_empty(),
            Self::Tokens(x) => x.is_empty(),
        }
    }

    fn len(&self) -> usize {
        match self {
            Self::Prompt(x) => x.len(),
            Self::Tokens(x) => x.len(),
        }
    }
}

/// Flattens the four `EmbeddingInput` variants into either a prompt batch or
/// a token batch. Currently infallible; the `Result` leaves room for future
/// validation.
fn normalize_inputs(input: EmbeddingInput) -> Result<Inputs> {
    match input {
        EmbeddingInput::Single(s) => Ok(Inputs::Prompt(vec![s])),
        EmbeddingInput::Multiple(items) => Ok(Inputs::Prompt(items)),
        EmbeddingInput::Tokens(t) => Ok(Inputs::Tokens(vec![t])),
        EmbeddingInput::TokensBatch(batch) => Ok(Inputs::Tokens(batch)),
    }
}
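
// A quick sketch of the mapping performed above, in terms of the
// `EmbeddingInput` variants:
//
//   Single("hi")             -> Inputs::Prompt(vec!["hi"])
//   Multiple(["a", "b"])     -> Inputs::Prompt(vec!["a", "b"])
//   Tokens([1, 2])           -> Inputs::Tokens(vec![vec![1, 2]])
//   TokensBatch([[1], [2]])  -> Inputs::Tokens(vec![vec![1], vec![2]])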

/// Requests an embedding for a single text prompt from the engine.
async fn fetch_embedding(
    state: SharedMistralRsState,
    prompt: String,
    model_id: Option<&str>,
    truncate_sequence: bool,
) -> Result<EmbeddingWithUsage> {
    let (tx, mut rx) = create_response_channel(Some(1));

    let request = Request::Normal(Box::new(NormalRequest {
        id: state.next_request_id(),
        messages: RequestMessage::Embedding { prompt },
        sampling_params: SamplingParams::deterministic(),
        response: tx,
        return_logprobs: false,
        is_streaming: false,
        suffix: None,
        constraint: Constraint::None,
        tool_choice: None,
        tools: None,
        logits_processors: None,
        return_raw_logits: false,
        web_search_options: None,
        model_id: model_id.map(|m| m.to_string()),
        truncate_sequence,
    }));

    send_request_with_model(&state, request, model_id)
        .await
        .context("Failed to dispatch embedding request")?;

    process_embedding_response(&mut rx, state.clone()).await
}

/// Requests an embedding for a single pre-tokenized prompt from the engine.
async fn fetch_embedding_tokens(
    state: SharedMistralRsState,
    tokens: Vec<u32>,
    model_id: Option<&str>,
    truncate_sequence: bool,
) -> Result<EmbeddingWithUsage> {
    let (tx, mut rx) = create_response_channel(Some(1));

    let request = Request::Normal(Box::new(NormalRequest {
        id: state.next_request_id(),
        messages: RequestMessage::EmbeddingTokens { prompt: tokens },
        sampling_params: SamplingParams::deterministic(),
        response: tx,
        return_logprobs: false,
        is_streaming: false,
        suffix: None,
        constraint: Constraint::None,
        tool_choice: None,
        tools: None,
        logits_processors: None,
        return_raw_logits: false,
        web_search_options: None,
        model_id: model_id.map(|m| m.to_string()),
        truncate_sequence,
    }));

    send_request_with_model(&state, request, model_id)
        .await
        .context("Failed to dispatch embedding request")?;

    process_embedding_response(&mut rx, state.clone()).await
}
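
// The two fetchers above differ only in the `RequestMessage` they construct.
// A shared helper along these lines is one way to collapse the duplication;
// this is a sketch, not wired into the handlers, and kept `dead_code` so it
// compiles without warnings:
#[allow(dead_code)]
async fn fetch_embedding_with_message(
    state: SharedMistralRsState,
    messages: RequestMessage,
    model_id: Option<&str>,
    truncate_sequence: bool,
) -> Result<EmbeddingWithUsage> {
    let (tx, mut rx) = create_response_channel(Some(1));

    // Identical request construction to the fetchers above, parameterized
    // over the message variant (`Embedding` or `EmbeddingTokens`).
    let request = Request::Normal(Box::new(NormalRequest {
        id: state.next_request_id(),
        messages,
        sampling_params: SamplingParams::deterministic(),
        response: tx,
        return_logprobs: false,
        is_streaming: false,
        suffix: None,
        constraint: Constraint::None,
        tool_choice: None,
        tools: None,
        logits_processors: None,
        return_raw_logits: false,
        web_search_options: None,
        model_id: model_id.map(|m| m.to_string()),
        truncate_sequence,
    }));

    send_request_with_model(&state, request, model_id)
        .await
        .context("Failed to dispatch embedding request")?;

    process_embedding_response(&mut rx, state.clone()).await
}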

/// Awaits the engine's response and extracts the embedding vector and token
/// usage, mapping every other response variant to an error.
async fn process_embedding_response(
    rx: &mut Receiver<Response>,
    state: SharedMistralRsState,
) -> Result<EmbeddingWithUsage> {
    base_process_non_streaming_response(
        rx,
        state.clone(),
        |_, response| match response {
            Response::Embeddings {
                embeddings,
                prompt_tokens,
                total_tokens,
            } => Ok(EmbeddingWithUsage {
                embedding: embeddings,
                prompt_tokens,
                total_tokens,
            }),
            Response::ValidationError(e) | Response::InternalError(e) => Err(anyhow!(e)),
            Response::ModelError(msg, _) => Err(anyhow!(msg)),
            Response::Done(_)
            | Response::Chunk(_)
            | Response::CompletionDone(_)
            | Response::CompletionChunk(_)
            | Response::CompletionModelError(_, _)
            | Response::ImageGeneration(_)
            | Response::Speech { .. }
            | Response::Raw { .. } => Err(anyhow!(
                "Received unexpected response type from embedding request."
            )),
        },
        |_, err| Err(anyhow!(err)),
    )
    .await
}

fn validation_error<E>(err: E) -> EmbeddingResponder
where
    E: Into<AnyhowError>,
{
    EmbeddingResponder::ValidationError(err.into())
}

fn internal_error<E>(err: E) -> EmbeddingResponder
where
    E: Into<AnyhowError>,
{
    EmbeddingResponder::InternalError(err.into())
}

/// Encodes an embedding as its little-endian `f32` bytes in standard base64.
fn encode_embedding_base64(embedding: &[f32]) -> String {
    let mut bytes = Vec::with_capacity(std::mem::size_of_val(embedding));
    for value in embedding {
        bytes.extend_from_slice(&value.to_le_bytes());
    }
    BASE64_STANDARD.encode(bytes)
}

/// Clamps a `usize` token count into the `u32` fields of `EmbeddingUsage`.
fn saturating_to_u32(value: usize) -> u32 {
    u32::try_from(value).unwrap_or(u32::MAX)
}
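
// A minimal test sketch for the pure helpers in this file. It assumes only
// the items defined above (and that `EmbeddingInput::Single` wraps a
// `String`, as the match in `normalize_inputs` suggests).
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn base64_round_trips_little_endian_f32_bytes() {
        let embedding = [1.0f32, -0.5, 0.25];
        let encoded = encode_embedding_base64(&embedding);
        // Decode back into f32s the same way a client would.
        let bytes = BASE64_STANDARD.decode(encoded).unwrap();
        let decoded: Vec<f32> = bytes
            .chunks_exact(4)
            .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
            .collect();
        assert_eq!(decoded, embedding);
    }

    #[test]
    fn token_counts_saturate_instead_of_wrapping() {
        assert_eq!(saturating_to_u32(7), 7);
        assert_eq!(saturating_to_u32(usize::MAX), u32::MAX);
    }

    #[test]
    fn normalize_inputs_wraps_a_single_prompt() {
        match normalize_inputs(EmbeddingInput::Single("hi".to_string())).unwrap() {
            Inputs::Prompt(prompts) => assert_eq!(prompts, vec!["hi".to_string()]),
            Inputs::Tokens(_) => panic!("expected a prompt input"),
        }
    }
}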