mistralrs_server_core/chat_completion.rs

1//! ## Chat Completions functionality and route handler.
2
3use std::{ops::Deref, pin::Pin, task::Poll, time::Duration};
4
5use anyhow::{Context, Result};
6use axum::{
7    extract::{Json, State},
8    http::{self},
9    response::{
10        sse::{Event, KeepAlive, KeepAliveStream},
11        IntoResponse, Sse,
12    },
13};
14use either::Either;
15use indexmap::IndexMap;
16use itertools::Itertools;
17use mistralrs_core::{
18    ChatCompletionChunkResponse, ChatCompletionResponse, Constraint, MistralRs, NormalRequest,
19    Request, RequestMessage, Response, SamplingParams,
20};
21use serde_json::Value;
22use tokio::sync::mpsc::{Receiver, Sender};
23
24use crate::{
25    completion_core::{
26        convert_stop_tokens, get_dry_sampling_params, handle_completion_error,
27        BaseCompletionResponder,
28    },
29    handler_core::{
30        base_process_non_streaming_response, create_response_channel, send_request_with_model,
31        BaseJsonModelError, ErrorToResponse, JsonError, ModelErrorMessage,
32    },
33    openai::{
34        ChatCompletionRequest, Grammar, JsonSchemaResponseFormat, MessageInnerContent,
35        ResponseFormat,
36    },
37    streaming::{base_create_streamer, get_keep_alive_interval, BaseStreamer, DoneState},
38    types::{ExtractedMistralRsState, OnChunkCallback, OnDoneCallback, SharedMistralRsState},
39    util::{parse_audio_url, parse_image_url, sanitize_error_message, validate_model_name},
40};
41
/// A callback function that processes streaming response chunks before they are sent to the client.
///
/// This hook allows modification of each chunk in the streaming response, enabling features like
/// content filtering, transformation, or logging. The callback receives a chunk and must return
/// a (potentially modified) chunk.
///
/// ### Examples
///
/// ```no_run
/// use mistralrs_server_core::chat_completion::ChatCompletionOnChunkCallback;
///
/// let on_chunk: ChatCompletionOnChunkCallback = Box::new(|mut chunk| {
///     // Log the chunk or modify its content
///     println!("Processing chunk: {:?}", chunk);
///     chunk
/// });
/// ```
pub type ChatCompletionOnChunkCallback = OnChunkCallback<ChatCompletionChunkResponse>;

/// A callback function that is executed when the streaming response completes.
///
/// This hook receives all chunks that were streamed during the response, allowing for
/// post-processing, analytics, or cleanup operations after the stream finishes.
///
/// ### Examples
///
/// ```no_run
/// use mistralrs_server_core::chat_completion::ChatCompletionOnDoneCallback;
///
/// let on_done: ChatCompletionOnDoneCallback = Box::new(|chunks| {
///     println!("Stream completed with {} chunks", chunks.len());
///     // Process all chunks for analytics
/// });
/// ```
pub type ChatCompletionOnDoneCallback = OnDoneCallback<ChatCompletionChunkResponse>;

/// A streaming response handler.
///
/// It processes incoming response chunks from a model and converts them
/// into Server-Sent Events (SSE) format for real-time streaming to clients.
/// Instances are constructed via [`create_streamer`], which also wraps them
/// in an SSE keep-alive.
pub type ChatCompletionStreamer = BaseStreamer<
    ChatCompletionChunkResponse,
    ChatCompletionOnChunkCallback,
    ChatCompletionOnDoneCallback,
>;
87
impl futures::Stream for ChatCompletionStreamer {
    type Item = Result<Event, axum::Error>;

    /// Polls the stream for the next Server-Sent Event.
    ///
    /// This method implements the core streaming logic:
    /// 1. Handles stream completion by sending `[DONE]` and executing callbacks
    /// 2. Processes incoming model responses and converts them to SSE events
    /// 3. Applies chunk modifications if a callback is provided
    /// 4. Stores chunks if `store_chunks` is set
    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        // Drive the small termination state machine before touching the channel.
        match self.done_state {
            DoneState::SendingDone => {
                // https://platform.openai.com/docs/api-reference/completions/create
                // The OpenAI streaming protocol terminates a stream with a
                // literal `data: [DONE]` message.
                self.done_state = DoneState::Done;
                return Poll::Ready(Some(Ok(Event::default().data("[DONE]"))));
            }
            DoneState::Done => {
                // Terminal state: run the completion hook over the collected
                // chunks, then end the stream.
                if let Some(on_done) = &self.on_done {
                    on_done(&self.chunks);
                }
                return Poll::Ready(None);
            }
            DoneState::Running => (),
        }

        match self.rx.poll_recv(cx) {
            Poll::Ready(Some(resp)) => match resp {
                Response::ModelError(msg, _) => {
                    MistralRs::maybe_log_error(
                        self.state.clone(),
                        &ModelErrorMessage(msg.to_string()),
                    );
                    // Emit the error text now; the terminating [DONE] is sent
                    // on the next poll.
                    self.done_state = DoneState::SendingDone;
                    Poll::Ready(Some(Ok(Event::default().data(msg))))
                }
                // Validation/internal errors are surfaced as plain data events
                // (sanitized) without terminating the state machine here.
                Response::ValidationError(e) => Poll::Ready(Some(Ok(
                    Event::default().data(sanitize_error_message(e.as_ref()))
                ))),
                Response::InternalError(e) => {
                    MistralRs::maybe_log_error(self.state.clone(), &*e);
                    Poll::Ready(Some(Ok(
                        Event::default().data(sanitize_error_message(e.as_ref()))
                    )))
                }
                Response::Chunk(mut response) => {
                    // Once every choice reports a finish reason, generation is
                    // complete; schedule [DONE] for the next poll.
                    if response.choices.iter().all(|x| x.finish_reason.is_some()) {
                        self.done_state = DoneState::SendingDone;
                    }
                    MistralRs::maybe_log_response(self.state.clone(), &response);

                    // Allow the caller-supplied hook to rewrite the chunk
                    // before it is serialized.
                    if let Some(on_chunk) = &self.on_chunk {
                        response = on_chunk(response);
                    }

                    if self.store_chunks {
                        self.chunks.push(response.clone());
                    }

                    Poll::Ready(Some(Event::default().json_data(response)))
                }
                // A streaming chat request can only produce the variants above.
                Response::Done(_) => unreachable!(),
                Response::CompletionDone(_) => unreachable!(),
                Response::CompletionModelError(_, _) => unreachable!(),
                Response::CompletionChunk(_) => unreachable!(),
                Response::ImageGeneration(_) => unreachable!(),
                Response::Speech { .. } => unreachable!(),
                Response::Raw { .. } => unreachable!(),
                Response::Embeddings { .. } => unreachable!(),
            },
            // NOTE(review): `Poll::Ready(None)` (sender dropped) is folded into
            // `Poll::Pending`. After channel closure no waker will fire, so if
            // the core ever drops the sender before a finishing chunk/error,
            // this stream stalls until the client disconnects — confirm the
            // core always sends a terminal message first.
            Poll::Pending | Poll::Ready(None) => Poll::Pending,
        }
    }
}
168
/// Represents different types of chat completion responses.
///
/// Either a keep-alive-wrapped SSE stream (streaming requests) or a JSON /
/// error payload (non-streaming requests).
pub type ChatCompletionResponder =
    BaseCompletionResponder<ChatCompletionResponse, KeepAliveStream<ChatCompletionStreamer>>;

// JSON error body that also carries the partial model response produced
// before the failure.
type JsonModelError = BaseJsonModelError<ChatCompletionResponse>;
impl ErrorToResponse for JsonModelError {}
175
176impl IntoResponse for ChatCompletionResponder {
177    /// Converts the chat completion responder into an HTTP response.
178    fn into_response(self) -> axum::response::Response {
179        match self {
180            ChatCompletionResponder::Sse(s) => s.into_response(),
181            ChatCompletionResponder::Json(s) => Json(s).into_response(),
182            ChatCompletionResponder::InternalError(e) => {
183                JsonError::new(sanitize_error_message(e.as_ref()))
184                    .to_response(http::StatusCode::INTERNAL_SERVER_ERROR)
185            }
186            ChatCompletionResponder::ValidationError(e) => {
187                JsonError::new(sanitize_error_message(e.as_ref()))
188                    .to_response(http::StatusCode::UNPROCESSABLE_ENTITY)
189            }
190            ChatCompletionResponder::ModelError(msg, response) => {
191                JsonModelError::new(msg, response)
192                    .to_response(http::StatusCode::INTERNAL_SERVER_ERROR)
193            }
194        }
195    }
196}
197
/// Parses and validates a chat completion request.
///
/// This function transforms an OpenAI-compatible chat completion request into the
/// request format used by mistral.rs, returning the internal [`Request`] together
/// with a flag indicating whether the client requested a streaming response.
///
/// ### Errors
///
/// Fails when the model name does not validate, a message carries neither
/// `content` nor `tool_calls`, multimodal sub-content is malformed, `grammar`
/// and `response_format` are both supplied, or a referenced image/audio
/// resource cannot be fetched/decoded.
pub async fn parse_request(
    oairequest: ChatCompletionRequest,
    state: SharedMistralRsState,
    tx: Sender<Response>,
) -> Result<(Request, bool)> {
    let repr = serde_json::to_string(&oairequest).expect("Serialization of request failed.");
    MistralRs::maybe_log_request(state.clone(), repr);

    // Validate that the requested model matches the loaded model
    validate_model_name(&oairequest.model, state.clone())?;

    let stop_toks = convert_stop_tokens(oairequest.stop_seqs);

    let messages = match oairequest.messages {
        // Structured `messages` array (the usual OpenAI shape).
        Either::Left(req_messages) => {
            let mut messages = Vec::new();
            let mut image_urls = Vec::new();
            let mut audio_urls = Vec::new();
            for message in req_messages {
                let content = match message.content.as_deref() {
                    Some(content) => content.clone(),
                    None => {
                        // Handle tool call: a message without `content` must
                        // carry `tool_calls`; the calls are serialized to JSON
                        // and used as the message text.
                        let calls = message
                            .tool_calls
                            .as_ref()
                            .context(
                                "No content was provided, expected tool calls to be provided.",
                            )?
                            .iter()
                            .map(|call| &call.function)
                            .collect::<Vec<_>>();

                        Either::Left(serde_json::to_string(&calls)?)
                    }
                };

                match &content {
                    // Plain text content: emit a {role, content} map. The
                    // insertion order of keys is preserved by IndexMap.
                    Either::Left(content) => {
                        let mut message_map: IndexMap<
                            String,
                            Either<String, Vec<IndexMap<String, Value>>>,
                        > = IndexMap::new();
                        message_map.insert("role".to_string(), Either::Left(message.role));
                        message_map.insert("content".to_string(), Either::Left(content.clone()));
                        messages.push(message_map);
                    }
                    // Array-of-parts content (text / image_url / audio_url).
                    Either::Right(image_messages) => {
                        // If there is only one message, it is possible a text message
                        // found when rig is used as client. In this case, we need to check if
                        // the message is a text message or an image message.
                        if image_messages.len() == 1 {
                            if !image_messages[0].contains_key("text") {
                                anyhow::bail!("Expected `text` key in input message.");
                            }
                            let content = match image_messages[0]["text"].deref() {
                                Either::Left(left) => left.to_string(),
                                Either::Right(right) => format!("{right:?}"),
                            };
                            let mut message_map: IndexMap<
                                String,
                                Either<String, Vec<IndexMap<String, Value>>>,
                            > = IndexMap::new();
                            message_map.insert("role".to_string(), Either::Left(message.role));
                            message_map.insert("content".to_string(), Either::Left(content));
                            messages.push(message_map);
                            continue;
                        }
                        // Multi-part (multimodal) messages are only accepted
                        // from the `user` role.
                        if message.role != "user" {
                            anyhow::bail!(
                                "Role for an image message must be `user`, but it is {}",
                                message.role
                            );
                        }

                        // Typed view of one sub-part of an array-content message.
                        enum ContentPart {
                            Text { text: String },
                            Image { image_url: String },
                            Audio { audio_url: String },
                        }

                        // First pass: classify every sub-part by its `type` tag.
                        let mut items = Vec::new();
                        for image_message in image_messages {
                            match image_message.get("type") {
                                Some(MessageInnerContent(Either::Left(x))) if x == "text" => {
                                    items.push(ContentPart::Text {
                                        text: image_message
                                            .get("text").as_ref()
                                            .context("Text sub-content must have `text` key.")?.as_ref()
                                            .left().context("Text sub-content `text` key must be a string.")?.clone(),
                                    });
                                }
                                Some(MessageInnerContent(Either::Left(x))) if x == "image_url" => {
                                    items.push(ContentPart::Image {
                                        image_url: image_message
                                            .get("image_url")
                                            .as_ref()
                                            .context("Image sub-content must have `image_url` key.")?
                                            .as_ref()
                                            .right()
                                            .context("Image sub-content `image_url` key must be an object.")?
                                            .get("url")
                                            .context("Image sub-content `image_url` object must have a `url` key.")?
                                            .clone(),
                                    });
                                }
                                Some(MessageInnerContent(Either::Left(x))) if x == "audio_url" => {
                                    items.push(ContentPart::Audio {
                                        audio_url: image_message
                                            .get("audio_url")
                                            .as_ref()
                                            .context("Audio sub-content must have `audio_url` key.")?
                                            .as_ref()
                                            .right()
                                            .context("Audio sub-content `audio_url` key must be an object.")?
                                            .get("url")
                                            .context("Audio sub-content `audio_url` object must have a `url` key.")?
                                            .clone(),
                                    });
                                }
                                _ => anyhow::bail!("Expected array content sub-content to be of format {{`type`: `text`, `text`: ...}} and {{`type`: `url`, `image_url`: {{`url`: ...}}}}")
                            }
                        }

                        // Second pass: all text parts are joined into a single
                        // text segment; media URLs are collected for download.
                        let text_content = items
                            .iter()
                            .filter_map(|item| match item {
                                ContentPart::Text { text } => Some(text),
                                _ => None,
                            })
                            .join(" ");
                        let image_urls_iter = items
                            .iter()
                            .filter_map(|item| match item {
                                ContentPart::Image { image_url } => Some(image_url.clone()),
                                _ => None,
                            })
                            .collect::<Vec<_>>();

                        let audio_urls_iter = items
                            .iter()
                            .filter_map(|item| match item {
                                ContentPart::Audio { audio_url } => Some(audio_url.clone()),
                                _ => None,
                            })
                            .collect::<Vec<_>>();

                        let mut message_map: IndexMap<
                            String,
                            Either<String, Vec<IndexMap<String, Value>>>,
                        > = IndexMap::new();
                        message_map.insert("role".to_string(), Either::Left(message.role));

                        // Rebuild the content array as media placeholders
                        // (type-only entries) followed by the merged text part.
                        // The actual media bytes travel separately below.
                        let mut content_map: Vec<IndexMap<String, Value>> = Vec::new();
                        for _ in &image_urls_iter {
                            let mut content_image_map = IndexMap::new();
                            content_image_map
                                .insert("type".to_string(), Value::String("image".to_string()));
                            content_map.push(content_image_map);
                        }
                        for _ in &audio_urls_iter {
                            let mut content_audio_map = IndexMap::new();
                            content_audio_map
                                .insert("type".to_string(), Value::String("audio".to_string()));
                            content_map.push(content_audio_map);
                        }
                        {
                            let mut content_text_map = IndexMap::new();
                            content_text_map
                                .insert("type".to_string(), Value::String("text".to_string()));
                            content_text_map
                                .insert("text".to_string(), Value::String(text_content));
                            content_map.push(content_text_map);
                        }

                        message_map.insert("content".to_string(), Either::Right(content_map));
                        messages.push(message_map);
                        image_urls.extend(image_urls_iter);
                        audio_urls.extend(audio_urls_iter);
                    }
                }
            }
            // Any media URL upgrades the request to a vision chat; the URLs
            // are resolved (fetched/decoded) here, before enqueueing.
            if !image_urls.is_empty() || !audio_urls.is_empty() {
                // Parse images
                let mut images = Vec::new();
                for url_unparsed in image_urls {
                    let image = parse_image_url(&url_unparsed)
                        .await
                        .context(format!("Failed to parse image resource: {url_unparsed}"))?;
                    images.push(image);
                }

                // Parse audios
                let mut audios = Vec::new();
                for url_unparsed in audio_urls {
                    let audio = parse_audio_url(&url_unparsed)
                        .await
                        .context(format!("Failed to parse audio resource: {url_unparsed}"))?;
                    audios.push(audio);
                }

                RequestMessage::VisionChat {
                    messages,
                    images,
                    audios,
                    enable_thinking: oairequest.enable_thinking,
                }
            } else {
                RequestMessage::Chat {
                    messages,
                    enable_thinking: oairequest.enable_thinking,
                }
            }
        }
        // Bare string prompt: wrap it as a single user message.
        Either::Right(prompt) => {
            let mut messages = Vec::new();
            let mut message_map: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
                IndexMap::new();
            message_map.insert("role".to_string(), Either::Left("user".to_string()));
            message_map.insert("content".to_string(), Either::Left(prompt));
            messages.push(message_map);
            RequestMessage::Chat {
                messages,
                enable_thinking: oairequest.enable_thinking,
            }
        }
    };

    let dry_params = get_dry_sampling_params(
        oairequest.dry_multiplier,
        oairequest.dry_sequence_breakers,
        oairequest.dry_base,
        oairequest.dry_allowed_length,
    )?;

    let is_streaming = oairequest.stream.unwrap_or(false);

    // `grammar` and `response_format` both define output constraints and
    // cannot be combined.
    if oairequest.grammar.is_some() && oairequest.response_format.is_some() {
        anyhow::bail!("Request `grammar` and `response_format` were both provided but are mutually exclusive.")
    }

    // `grammar` takes priority; otherwise fall back to `response_format`
    // (only the json_schema form constrains output).
    let constraint = match oairequest.grammar {
        Some(Grammar::Regex(regex)) => Constraint::Regex(regex),
        Some(Grammar::Lark(lark)) => Constraint::Lark(lark),
        Some(Grammar::JsonSchema(schema)) => Constraint::JsonSchema(schema),
        Some(Grammar::Llguidance(llguidance)) => Constraint::Llguidance(llguidance),
        None => match oairequest.response_format {
            Some(ResponseFormat::JsonSchema {
                json_schema: JsonSchemaResponseFormat { name: _, schema },
            }) => Constraint::JsonSchema(schema),
            Some(ResponseFormat::Text) => Constraint::None,
            None => Constraint::None,
        },
    };

    Ok((
        Request::Normal(Box::new(NormalRequest {
            id: state.next_request_id(),
            messages,
            sampling_params: SamplingParams {
                temperature: oairequest.temperature,
                top_k: oairequest.top_k,
                top_p: oairequest.top_p,
                min_p: oairequest.min_p,
                // Defaults to returning 1 logprob per position when requested.
                top_n_logprobs: oairequest.top_logprobs.unwrap_or(1),
                frequency_penalty: oairequest.frequency_penalty,
                presence_penalty: oairequest.presence_penalty,
                repetition_penalty: oairequest.repetition_penalty,
                max_len: oairequest.max_tokens,
                stop_toks,
                logits_bias: oairequest.logit_bias,
                n_choices: oairequest.n_choices,
                dry_params,
            },
            response: tx,
            return_logprobs: oairequest.logprobs,
            is_streaming,
            suffix: None,
            constraint,
            tool_choice: oairequest.tool_choice,
            tools: oairequest.tools,
            logits_processors: None,
            return_raw_logits: false,
            web_search_options: oairequest.web_search_options,
            // "default" means "route to the default model" (no explicit id).
            model_id: if oairequest.model == "default" {
                None
            } else {
                Some(oairequest.model.clone())
            },
            truncate_sequence: oairequest.truncate_sequence.unwrap_or(false),
        })),
        is_streaming,
    ))
}
496
497/// OpenAI-compatible chat completions endpoint handler.
498#[utoipa::path(
499    post,
500    tag = "Mistral.rs",
501    path = "/v1/chat/completions",
502    request_body = ChatCompletionRequest,
503    responses((status = 200, description = "Chat completions"))
504)]
505pub async fn chatcompletions(
506    State(state): ExtractedMistralRsState,
507    Json(oairequest): Json<ChatCompletionRequest>,
508) -> ChatCompletionResponder {
509    let (tx, mut rx) = create_response_channel(None);
510
511    // Extract model_id for routing before parsing
512    let model_id = if oairequest.model == "default" {
513        None
514    } else {
515        Some(oairequest.model.clone())
516    };
517
518    let (request, is_streaming) = match parse_request(oairequest, state.clone(), tx).await {
519        Ok(x) => x,
520        Err(e) => return handle_error(state, e.into()),
521    };
522
523    if let Err(e) = send_request_with_model(&state, request, model_id.as_deref()).await {
524        return handle_error(state, e.into());
525    }
526
527    if is_streaming {
528        ChatCompletionResponder::Sse(create_streamer(rx, state, None, None))
529    } else {
530        process_non_streaming_response(&mut rx, state).await
531    }
532}
533
/// Handle route / generation errors and logging them.
///
/// Thin wrapper over [`handle_completion_error`] from `completion_core`,
/// so chat and other completion routes share one error-to-responder path.
pub fn handle_error(
    state: SharedMistralRsState,
    e: Box<dyn std::error::Error + Send + Sync + 'static>,
) -> ChatCompletionResponder {
    handle_completion_error(state, e)
}
541
542/// Creates a SSE streamer for chat completions with optional callbacks.
543pub fn create_streamer(
544    rx: Receiver<Response>,
545    state: SharedMistralRsState,
546    on_chunk: Option<ChatCompletionOnChunkCallback>,
547    on_done: Option<ChatCompletionOnDoneCallback>,
548) -> Sse<KeepAliveStream<ChatCompletionStreamer>> {
549    let streamer = base_create_streamer(rx, state, on_chunk, on_done);
550    let keep_alive_interval = get_keep_alive_interval();
551
552    Sse::new(streamer)
553        .keep_alive(KeepAlive::new().interval(Duration::from_millis(keep_alive_interval)))
554}
555
/// Process non-streaming chat completion responses.
///
/// Awaits the single terminal response on `rx` and converts it via
/// [`match_responses`], routing any failure through [`handle_error`].
pub async fn process_non_streaming_response(
    rx: &mut Receiver<Response>,
    state: SharedMistralRsState,
) -> ChatCompletionResponder {
    base_process_non_streaming_response(rx, state, match_responses, handle_error).await
}
563
564/// Matches and processes different types of model responses into appropriate chat completion responses.
565pub fn match_responses(state: SharedMistralRsState, response: Response) -> ChatCompletionResponder {
566    match response {
567        Response::InternalError(e) => {
568            MistralRs::maybe_log_error(state, &*e);
569            ChatCompletionResponder::InternalError(e)
570        }
571        Response::ModelError(msg, response) => {
572            MistralRs::maybe_log_error(state.clone(), &ModelErrorMessage(msg.to_string()));
573            MistralRs::maybe_log_response(state, &response);
574            ChatCompletionResponder::ModelError(msg, response)
575        }
576        Response::ValidationError(e) => ChatCompletionResponder::ValidationError(e),
577        Response::Done(response) => {
578            MistralRs::maybe_log_response(state, &response);
579            ChatCompletionResponder::Json(response)
580        }
581        Response::Chunk(_) => unreachable!(),
582        Response::CompletionDone(_) => unreachable!(),
583        Response::CompletionModelError(_, _) => unreachable!(),
584        Response::CompletionChunk(_) => unreachable!(),
585        Response::ImageGeneration(_) => unreachable!(),
586        Response::Speech { .. } => unreachable!(),
587        Response::Raw { .. } => unreachable!(),
588        Response::Embeddings { .. } => unreachable!(),
589    }
590}