use anyhow::{anyhow, Context, Error as AnyhowError, Result};
use axum::{
    extract::{Json, State},
    http,
    response::IntoResponse,
};
use base64::{prelude::BASE64_STANDARD, Engine};
use futures::future::join_all;
use mistralrs_core::{
    Constraint, MistralRs, NormalRequest, Request, RequestMessage, Response, SamplingParams,
};
use tokio::sync::mpsc::Receiver;

use crate::{
    handler_core::{
        base_process_non_streaming_response, create_response_channel, send_request_with_model,
        ErrorToResponse, JsonError,
    },
    openai::{
        EmbeddingData, EmbeddingEncodingFormat, EmbeddingInput, EmbeddingRequest,
        EmbeddingResponse, EmbeddingUsage, EmbeddingVector,
    },
    types::{ExtractedMistralRsState, SharedMistralRsState},
    util::{sanitize_error_message, validate_model_name},
};

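/// Response variants for the embeddings endpoint: a successful JSON payload,
/// or an internal/validation error that is mapped to an HTTP status code.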
pub enum EmbeddingResponder {
    Json(EmbeddingResponse),
    InternalError(AnyhowError),
    ValidationError(AnyhowError),
}

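/// One embedding vector plus the token counts reported for producing it.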
struct EmbeddingWithUsage {
    embedding: Vec<f32>,
    prompt_tokens: usize,
    total_tokens: usize,
}

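// Success returns 200 with the JSON body; internal errors map to 500 and
// validation errors to 422, with sanitized error messages in both cases.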
impl IntoResponse for EmbeddingResponder {
    fn into_response(self) -> axum::response::Response {
        match self {
            EmbeddingResponder::Json(s) => Json(s).into_response(),
            EmbeddingResponder::InternalError(e) => {
                JsonError::new(sanitize_error_message(e.root_cause()))
                    .to_response(http::StatusCode::INTERNAL_SERVER_ERROR)
            }
            EmbeddingResponder::ValidationError(e) => {
                JsonError::new(sanitize_error_message(e.root_cause()))
                    .to_response(http::StatusCode::UNPROCESSABLE_ENTITY)
            }
        }
    }
}

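/// OpenAI-compatible embeddings handler. Validates the request, fans out one
/// engine request per input (prompt strings or pre-tokenized sequences), and
/// aggregates the vectors and token usage into a single response. A minimal
/// request, for example:
///
/// ```json
/// { "model": "default", "input": "Hello, world!" }
/// ```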
#[utoipa::path(
    post,
    tag = "Mistral.rs",
    path = "/v1/embeddings",
    request_body = EmbeddingRequest,
    responses((status = 200, description = "Embeddings", body = EmbeddingResponse))
)]
pub async fn embeddings(
    State(state): ExtractedMistralRsState,
    Json(oairequest): Json<EmbeddingRequest>,
) -> EmbeddingResponder {
    let repr =
        serde_json::to_string(&oairequest).expect("Serialization of embedding request failed.");
    MistralRs::maybe_log_request(state.clone(), repr);

    if let Err(e) = validate_model_name(&oairequest.model, state.clone()) {
        return validation_error(e);
    }

    if let Some(dimensions) = oairequest.dimensions {
        return validation_error(anyhow!(
            "Custom embedding dimensions ({dimensions}) are not supported."
        ));
    }

    let inputs = match normalize_inputs(oairequest.input) {
        Ok(inputs) => inputs,
        Err(e) => return validation_error(e),
    };

    if inputs.is_empty() {
        return validation_error(anyhow!("input must contain at least one entry."));
    }

    let model_override = if oairequest.model == "default" {
        None
    } else {
        Some(oairequest.model.clone())
    };

    let encoding = oairequest.encoding_format.unwrap_or_default();
    let return_base64 = matches!(encoding, EmbeddingEncodingFormat::Base64);

    let mut data = Vec::with_capacity(inputs.len());
    let mut total_prompt_tokens: usize = 0;
    let mut total_tokens: usize = 0;

    match inputs {
        Inputs::Prompt(prompts) => {
            let futures = prompts.into_iter().map(|prompt| {
                let state = state.clone();
                let model_override = model_override.clone();
                async move {
                    fetch_embedding(
                        state,
                        prompt,
                        model_override.as_deref(),
                        oairequest.truncate_sequence.unwrap_or(false),
                    )
                    .await
                }
            });

            let results = join_all(futures).await;
            for (index, result) in results.into_iter().enumerate() {
                match result {
                    Ok(EmbeddingWithUsage {
                        embedding,
                        prompt_tokens,
                        total_tokens: item_total_tokens,
                    }) => {
                        let embedding = if return_base64 {
                            EmbeddingVector::Base64(encode_embedding_base64(&embedding))
                        } else {
                            EmbeddingVector::Float(embedding)
                        };
                        data.push(EmbeddingData {
                            object: "embedding",
                            embedding,
                            index,
                        });
                        total_prompt_tokens = total_prompt_tokens.saturating_add(prompt_tokens);
                        total_tokens = total_tokens.saturating_add(item_total_tokens);
                    }
                    Err(e) => {
                        MistralRs::maybe_log_error(state.clone(), e.as_ref());
                        return internal_error(e);
                    }
                }
            }
        }
        Inputs::Tokens(batches) => {
            let futures = batches.into_iter().map(|tokens| {
                let state = state.clone();
                let model_override = model_override.clone();
                async move {
                    fetch_embedding_tokens(
                        state,
                        tokens,
                        model_override.as_deref(),
                        oairequest.truncate_sequence.unwrap_or(false),
                    )
                    .await
                }
            });

            let results = join_all(futures).await;
            for (index, result) in results.into_iter().enumerate() {
                match result {
                    Ok(EmbeddingWithUsage {
                        embedding,
                        prompt_tokens,
                        total_tokens: item_total_tokens,
                    }) => {
                        let embedding = if return_base64 {
                            EmbeddingVector::Base64(encode_embedding_base64(&embedding))
                        } else {
                            EmbeddingVector::Float(embedding)
                        };
                        data.push(EmbeddingData {
                            object: "embedding",
                            embedding,
                            index,
                        });
                        total_prompt_tokens = total_prompt_tokens.saturating_add(prompt_tokens);
                        total_tokens = total_tokens.saturating_add(item_total_tokens);
                    }
                    Err(e) => {
                        MistralRs::maybe_log_error(state.clone(), e.as_ref());
                        return internal_error(e);
                    }
                }
            }
        }
    }

    let usage = EmbeddingUsage {
        prompt_tokens: saturating_to_u32(total_prompt_tokens),
        total_tokens: saturating_to_u32(total_tokens),
    };

    let response = EmbeddingResponse {
        object: "list",
        data,
        model: oairequest.model,
        usage,
    };

    MistralRs::maybe_log_response(state.clone(), &response);

    EmbeddingResponder::Json(response)
}

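/// Normalized request input: either a batch of prompt strings or a batch of
/// pre-tokenized sequences.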
enum Inputs {
    Prompt(Vec<String>),
    Tokens(Vec<Vec<u32>>),
}

impl Inputs {
    fn is_empty(&self) -> bool {
        match self {
            Self::Prompt(x) => x.is_empty(),
            Self::Tokens(x) => x.is_empty(),
        }
    }

    fn len(&self) -> usize {
        match self {
            Self::Prompt(x) => x.len(),
            Self::Tokens(x) => x.len(),
        }
    }
}

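/// Flattens the four `EmbeddingInput` wire shapes (single string, list of
/// strings, single token sequence, batch of token sequences) into `Inputs`.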
fn normalize_inputs(input: EmbeddingInput) -> Result<Inputs> {
    match input {
        EmbeddingInput::Single(s) => Ok(Inputs::Prompt(vec![s])),
        EmbeddingInput::Multiple(items) => Ok(Inputs::Prompt(items)),
        EmbeddingInput::Tokens(t) => Ok(Inputs::Tokens(vec![t])),
        EmbeddingInput::TokensBatch(batch) => Ok(Inputs::Tokens(batch)),
    }
}

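/// Builds a deterministic embedding request for a single prompt string,
/// dispatches it to the engine, and awaits the embedding plus token usage.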
async fn fetch_embedding(
    state: SharedMistralRsState,
    prompt: String,
    model_id: Option<&str>,
    truncate_sequence: bool,
) -> Result<EmbeddingWithUsage> {
    let (tx, mut rx) = create_response_channel(Some(1));

    let request = Request::Normal(Box::new(NormalRequest {
        id: state.next_request_id(),
        messages: RequestMessage::Embedding { prompt },
        sampling_params: SamplingParams::deterministic(),
        response: tx,
        return_logprobs: false,
        is_streaming: false,
        suffix: None,
        constraint: Constraint::None,
        tool_choice: None,
        tools: None,
        logits_processors: None,
        return_raw_logits: false,
        web_search_options: None,
        model_id: model_id.map(|m| m.to_string()),
        truncate_sequence,
    }));

    send_request_with_model(&state, request, model_id)
        .await
        .context("Failed to dispatch embedding request")?;

    process_embedding_response(&mut rx, state.clone()).await
}

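/// Same as `fetch_embedding`, but for input that arrives pre-tokenized.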
async fn fetch_embedding_tokens(
    state: SharedMistralRsState,
    tokens: Vec<u32>,
    model_id: Option<&str>,
    truncate_sequence: bool,
) -> Result<EmbeddingWithUsage> {
    let (tx, mut rx) = create_response_channel(Some(1));

    let request = Request::Normal(Box::new(NormalRequest {
        id: state.next_request_id(),
        messages: RequestMessage::EmbeddingTokens { prompt: tokens },
        sampling_params: SamplingParams::deterministic(),
        response: tx,
        return_logprobs: false,
        is_streaming: false,
        suffix: None,
        constraint: Constraint::None,
        tool_choice: None,
        tools: None,
        logits_processors: None,
        return_raw_logits: false,
        web_search_options: None,
        model_id: model_id.map(|m| m.to_string()),
        truncate_sequence,
    }));

    send_request_with_model(&state, request, model_id)
        .await
        .context("Failed to dispatch embedding request")?;

    process_embedding_response(&mut rx, state.clone()).await
}

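/// Awaits the engine's reply on `rx` and converts it into `EmbeddingWithUsage`,
/// treating anything other than `Response::Embeddings` as an error.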
async fn process_embedding_response(
    rx: &mut Receiver<Response>,
    state: SharedMistralRsState,
) -> Result<EmbeddingWithUsage> {
    base_process_non_streaming_response(
        rx,
        state.clone(),
        |_, response| match response {
            Response::Embeddings {
                embeddings,
                prompt_tokens,
                total_tokens,
            } => Ok(EmbeddingWithUsage {
                embedding: embeddings,
                prompt_tokens,
                total_tokens,
            }),
            Response::ValidationError(e) | Response::InternalError(e) => Err(anyhow!(e)),
            Response::ModelError(msg, _) => Err(anyhow!(msg)),
            Response::Done(_)
            | Response::Chunk(_)
            | Response::CompletionDone(_)
            | Response::CompletionChunk(_)
            | Response::CompletionModelError(_, _)
            | Response::ImageGeneration(_)
            | Response::Speech { .. }
            | Response::Raw { .. } => Err(anyhow!(
                "Received unexpected response type from embedding request."
            )),
        },
        |_, err| Err(anyhow!(err)),
    )
    .await
}

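/// Wraps an error as a 422 validation failure.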
fn validation_error<E>(err: E) -> EmbeddingResponder
where
    E: Into<AnyhowError>,
{
    EmbeddingResponder::ValidationError(err.into())
}

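/// Wraps an error as a 500 internal server error.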
fn internal_error<E>(err: E) -> EmbeddingResponder
where
    E: Into<AnyhowError>,
{
    EmbeddingResponder::InternalError(err.into())
}

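/// Serializes each `f32` as little-endian bytes and base64-encodes the buffer,
/// matching the OpenAI `encoding_format: "base64"` wire format. For example,
/// `&[1.0f32]` encodes to `"AACAPw=="`.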
fn encode_embedding_base64(embedding: &[f32]) -> String {
    let mut bytes = Vec::with_capacity(std::mem::size_of_val(embedding));
    for value in embedding {
        bytes.extend_from_slice(&value.to_le_bytes());
    }
    BASE64_STANDARD.encode(bytes)
}

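/// Clamps a `usize` token count to `u32::MAX` for the OpenAI usage fields.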
fn saturating_to_u32(value: usize) -> u32 {
    if value > u32::MAX as usize {
        u32::MAX
    } else {
        value as u32
    }
}