mistralrs_server_core/
chat_completion.rs

1//! ## Chat Completions functionality and route handler.
2
3use std::{ops::Deref, pin::Pin, task::Poll, time::Duration};
4
5use anyhow::{Context, Result};
6use axum::{
7    extract::{Json, State},
8    http::{self},
9    response::{
10        sse::{Event, KeepAlive},
11        IntoResponse, Sse,
12    },
13};
14use either::Either;
15use indexmap::IndexMap;
16use itertools::Itertools;
17use mistralrs_core::{
18    ChatCompletionChunkResponse, ChatCompletionResponse, Constraint, MistralRs, NormalRequest,
19    Request, RequestMessage, Response, SamplingParams,
20};
21use serde_json::Value;
22use tokio::sync::mpsc::{Receiver, Sender};
23
24use crate::{
25    completion_core::{
26        convert_stop_tokens, get_dry_sampling_params, handle_completion_error,
27        BaseCompletionResponder,
28    },
29    handler_core::{
30        base_process_non_streaming_response, create_response_channel, send_request_with_model,
31        BaseJsonModelError, ErrorToResponse, JsonError, ModelErrorMessage,
32    },
33    openai::{
34        ChatCompletionRequest, Grammar, JsonSchemaResponseFormat, MessageInnerContent,
35        ResponseFormat,
36    },
37    streaming::{base_create_streamer, get_keep_alive_interval, BaseStreamer, DoneState},
38    types::{ExtractedMistralRsState, OnChunkCallback, OnDoneCallback, SharedMistralRsState},
39    util::{parse_audio_url, parse_image_url, sanitize_error_message, validate_model_name},
40};
41
42/// A callback function that processes streaming response chunks before they are sent to the client.
43///
44/// This hook allows modification of each chunk in the streaming response, enabling features like
45/// content filtering, transformation, or logging. The callback receives a chunk and must return
46/// a (potentially modified) chunk.
47///
48/// ### Examples
49///
50/// ```no_run
51/// use mistralrs_server_core::chat_completion::ChatCompletionOnChunkCallback;
52///
53/// let on_chunk: ChatCompletionOnChunkCallback = Box::new(|mut chunk| {
54///     // Log the chunk or modify its content
55///     println!("Processing chunk: {:?}", chunk);
56///     chunk
57/// });
58/// ```
59pub type ChatCompletionOnChunkCallback = OnChunkCallback<ChatCompletionChunkResponse>;
60
61/// A callback function that is executed when the streaming response completes.
62///
63/// This hook receives all chunks that were streamed during the response, allowing for
64/// post-processing, analytics, or cleanup operations after the stream finishes.
65///
66/// ### Examples
67///
68/// ```no_run
69/// use mistralrs_server_core::chat_completion::ChatCompletionOnDoneCallback;
70///
71/// let on_done: ChatCompletionOnDoneCallback = Box::new(|chunks| {
72///     println!("Stream completed with {} chunks", chunks.len());
73///     // Process all chunks for analytics
74/// });
75/// ```
76pub type ChatCompletionOnDoneCallback = OnDoneCallback<ChatCompletionChunkResponse>;
77
78/// A streaming response handler.
79///
80/// It processes incoming response chunks from a model and converts them
81/// into Server-Sent Events (SSE) format for real-time streaming to clients.
82pub type ChatCompletionStreamer = BaseStreamer<
83    ChatCompletionChunkResponse,
84    ChatCompletionOnChunkCallback,
85    ChatCompletionOnDoneCallback,
86>;
87
88impl futures::Stream for ChatCompletionStreamer {
89    type Item = Result<Event, axum::Error>;
90
91    /// Polls the stream for the next Server-Sent Event.
92    ///
93    /// This method implements the core streaming logic:
94    /// 1. Handles stream completion by sending `[DONE]` and executing callbacks
95    /// 2. Processes incoming model responses and converts them to SSE events
96    /// 3. Applies chunk modifications if a callback is provided
97    /// 4. Stores chunks if completion callback is configured
98    fn poll_next(
99        mut self: Pin<&mut Self>,
100        cx: &mut std::task::Context<'_>,
101    ) -> Poll<Option<Self::Item>> {
102        match self.done_state {
103            DoneState::SendingDone => {
104                // https://platform.openai.com/docs/api-reference/completions/create
105                // If true, returns a stream of events that happen during the Run as server-sent events, terminating when the Run enters a terminal state with a data: [DONE] message.
106                self.done_state = DoneState::Done;
107                return Poll::Ready(Some(Ok(Event::default().data("[DONE]"))));
108            }
109            DoneState::Done => {
110                if let Some(on_done) = &self.on_done {
111                    on_done(&self.chunks);
112                }
113                return Poll::Ready(None);
114            }
115            DoneState::Running => (),
116        }
117
118        match self.rx.poll_recv(cx) {
119            Poll::Ready(Some(resp)) => match resp {
120                Response::ModelError(msg, _) => {
121                    MistralRs::maybe_log_error(
122                        self.state.clone(),
123                        &ModelErrorMessage(msg.to_string()),
124                    );
125                    // Done now, just need to send the [DONE]
126                    self.done_state = DoneState::SendingDone;
127                    Poll::Ready(Some(Ok(Event::default().data(msg))))
128                }
129                Response::ValidationError(e) => Poll::Ready(Some(Ok(
130                    Event::default().data(sanitize_error_message(e.as_ref()))
131                ))),
132                Response::InternalError(e) => {
133                    MistralRs::maybe_log_error(self.state.clone(), &*e);
134                    Poll::Ready(Some(Ok(
135                        Event::default().data(sanitize_error_message(e.as_ref()))
136                    )))
137                }
138                Response::Chunk(mut response) => {
139                    if response.choices.iter().all(|x| x.finish_reason.is_some()) {
140                        self.done_state = DoneState::SendingDone;
141                    }
142                    // Done now, just need to send the [DONE]
143                    MistralRs::maybe_log_response(self.state.clone(), &response);
144
145                    if let Some(on_chunk) = &self.on_chunk {
146                        response = on_chunk(response);
147                    }
148
149                    if self.store_chunks {
150                        self.chunks.push(response.clone());
151                    }
152
153                    Poll::Ready(Some(Event::default().json_data(response)))
154                }
155                Response::Done(_) => unreachable!(),
156                Response::CompletionDone(_) => unreachable!(),
157                Response::CompletionModelError(_, _) => unreachable!(),
158                Response::CompletionChunk(_) => unreachable!(),
159                Response::ImageGeneration(_) => unreachable!(),
160                Response::Speech { .. } => unreachable!(),
161                Response::Raw { .. } => unreachable!(),
162            },
163            Poll::Pending | Poll::Ready(None) => Poll::Pending,
164        }
165    }
166}
167
168/// Represents different types of chat completion responses.
169pub type ChatCompletionResponder =
170    BaseCompletionResponder<ChatCompletionResponse, ChatCompletionStreamer>;
171
172type JsonModelError = BaseJsonModelError<ChatCompletionResponse>;
173impl ErrorToResponse for JsonModelError {}
174
175impl IntoResponse for ChatCompletionResponder {
176    /// Converts the chat completion responder into an HTTP response.
177    fn into_response(self) -> axum::response::Response {
178        match self {
179            ChatCompletionResponder::Sse(s) => s.into_response(),
180            ChatCompletionResponder::Json(s) => Json(s).into_response(),
181            ChatCompletionResponder::InternalError(e) => {
182                JsonError::new(sanitize_error_message(e.as_ref()))
183                    .to_response(http::StatusCode::INTERNAL_SERVER_ERROR)
184            }
185            ChatCompletionResponder::ValidationError(e) => {
186                JsonError::new(sanitize_error_message(e.as_ref()))
187                    .to_response(http::StatusCode::UNPROCESSABLE_ENTITY)
188            }
189            ChatCompletionResponder::ModelError(msg, response) => {
190                JsonModelError::new(msg, response)
191                    .to_response(http::StatusCode::INTERNAL_SERVER_ERROR)
192            }
193        }
194    }
195}
196
197/// Parses and validates a chat completion request.
198///
199/// This function transforms an OpenAI-compatible chat completion request into the
200/// request format used by mistral.rs.
201pub async fn parse_request(
202    oairequest: ChatCompletionRequest,
203    state: SharedMistralRsState,
204    tx: Sender<Response>,
205) -> Result<(Request, bool)> {
206    let repr = serde_json::to_string(&oairequest).expect("Serialization of request failed.");
207    MistralRs::maybe_log_request(state.clone(), repr);
208
209    // Validate that the requested model matches the loaded model
210    validate_model_name(&oairequest.model, state.clone())?;
211
212    let stop_toks = convert_stop_tokens(oairequest.stop_seqs);
213
214    let messages = match oairequest.messages {
215        Either::Left(req_messages) => {
216            let mut messages = Vec::new();
217            let mut image_urls = Vec::new();
218            let mut audio_urls = Vec::new();
219            for message in req_messages {
220                let content = match message.content.as_deref() {
221                    Some(content) => content.clone(),
222                    None => {
223                        // Handle tool call
224                        let calls = message
225                            .tool_calls
226                            .as_ref()
227                            .context(
228                                "No content was provided, expected tool calls to be provided.",
229                            )?
230                            .iter()
231                            .map(|call| &call.function)
232                            .collect::<Vec<_>>();
233
234                        Either::Left(serde_json::to_string(&calls)?)
235                    }
236                };
237
238                match &content {
239                    Either::Left(content) => {
240                        let mut message_map: IndexMap<
241                            String,
242                            Either<String, Vec<IndexMap<String, Value>>>,
243                        > = IndexMap::new();
244                        message_map.insert("role".to_string(), Either::Left(message.role));
245                        message_map.insert("content".to_string(), Either::Left(content.clone()));
246                        messages.push(message_map);
247                    }
248                    Either::Right(image_messages) => {
249                        // If there is only one message, it is possible a text message
250                        // found when rig is used as client. In this case, we need to check if
251                        // the message is a text message or an image message.
252                        if image_messages.len() == 1 {
253                            if !image_messages[0].contains_key("text") {
254                                anyhow::bail!("Expected `text` key in input message.");
255                            }
256                            let content = match image_messages[0]["text"].deref() {
257                                Either::Left(left) => left.to_string(),
258                                Either::Right(right) => format!("{right:?}"),
259                            };
260                            let mut message_map: IndexMap<
261                                String,
262                                Either<String, Vec<IndexMap<String, Value>>>,
263                            > = IndexMap::new();
264                            message_map.insert("role".to_string(), Either::Left(message.role));
265                            message_map.insert("content".to_string(), Either::Left(content));
266                            messages.push(message_map);
267                            continue;
268                        }
269                        if message.role != "user" {
270                            anyhow::bail!(
271                                "Role for an image message must be `user`, but it is {}",
272                                message.role
273                            );
274                        }
275
276                        enum ContentPart {
277                            Text { text: String },
278                            Image { image_url: String },
279                            Audio { audio_url: String },
280                        }
281
282                        let mut items = Vec::new();
283                        for image_message in image_messages {
284                            match image_message.get("type") {
285                                Some(MessageInnerContent(Either::Left(x))) if x == "text" => {
286                                    items.push(ContentPart::Text {
287                                        text: image_message
288                                            .get("text").as_ref()
289                                            .context("Text sub-content must have `text` key.")?.as_ref()
290                                            .left().context("Text sub-content `text` key must be a string.")?.clone(),
291                                    });
292                                }
293                                Some(MessageInnerContent(Either::Left(x))) if x == "image_url" => {
294                                    items.push(ContentPart::Image {
295                                        image_url: image_message
296                                            .get("image_url")
297                                            .as_ref()
298                                            .context("Image sub-content must have `image_url` key.")?
299                                            .as_ref()
300                                            .right()
301                                            .context("Image sub-content `image_url` key must be an object.")?
302                                            .get("url")
303                                            .context("Image sub-content `image_url` object must have a `url` key.")?
304                                            .clone(),
305                                    });
306                                }
307                                Some(MessageInnerContent(Either::Left(x))) if x == "audio_url" => {
308                                    items.push(ContentPart::Audio {
309                                        audio_url: image_message
310                                            .get("audio_url")
311                                            .as_ref()
312                                            .context("Audio sub-content must have `audio_url` key.")?
313                                            .as_ref()
314                                            .right()
315                                            .context("Audio sub-content `audio_url` key must be an object.")?
316                                            .get("url")
317                                            .context("Audio sub-content `audio_url` object must have a `url` key.")?
318                                            .clone(),
319                                    });
320                                }
321                                _ => anyhow::bail!("Expected array content sub-content to be of format {{`type`: `text`, `text`: ...}} and {{`type`: `url`, `image_url`: {{`url`: ...}}}}")
322                            }
323                        }
324
325                        let text_content = items
326                            .iter()
327                            .filter_map(|item| match item {
328                                ContentPart::Text { text } => Some(text),
329                                _ => None,
330                            })
331                            .join(" ");
332                        let image_urls_iter = items
333                            .iter()
334                            .filter_map(|item| match item {
335                                ContentPart::Image { image_url } => Some(image_url.clone()),
336                                _ => None,
337                            })
338                            .collect::<Vec<_>>();
339
340                        let audio_urls_iter = items
341                            .iter()
342                            .filter_map(|item| match item {
343                                ContentPart::Audio { audio_url } => Some(audio_url.clone()),
344                                _ => None,
345                            })
346                            .collect::<Vec<_>>();
347
348                        let mut message_map: IndexMap<
349                            String,
350                            Either<String, Vec<IndexMap<String, Value>>>,
351                        > = IndexMap::new();
352                        message_map.insert("role".to_string(), Either::Left(message.role));
353
354                        let mut content_map: Vec<IndexMap<String, Value>> = Vec::new();
355                        for _ in &image_urls_iter {
356                            let mut content_image_map = IndexMap::new();
357                            content_image_map
358                                .insert("type".to_string(), Value::String("image".to_string()));
359                            content_map.push(content_image_map);
360                        }
361                        for _ in &audio_urls_iter {
362                            let mut content_audio_map = IndexMap::new();
363                            content_audio_map
364                                .insert("type".to_string(), Value::String("audio".to_string()));
365                            content_map.push(content_audio_map);
366                        }
367                        {
368                            let mut content_text_map = IndexMap::new();
369                            content_text_map
370                                .insert("type".to_string(), Value::String("text".to_string()));
371                            content_text_map
372                                .insert("text".to_string(), Value::String(text_content));
373                            content_map.push(content_text_map);
374                        }
375
376                        message_map.insert("content".to_string(), Either::Right(content_map));
377                        messages.push(message_map);
378                        image_urls.extend(image_urls_iter);
379                        audio_urls.extend(audio_urls_iter);
380                    }
381                }
382            }
383            if !image_urls.is_empty() || !audio_urls.is_empty() {
384                // Parse images
385                let mut images = Vec::new();
386                for url_unparsed in image_urls {
387                    let image = parse_image_url(&url_unparsed)
388                        .await
389                        .context(format!("Failed to parse image resource: {url_unparsed}"))?;
390                    images.push(image);
391                }
392
393                // Parse audios
394                let mut audios = Vec::new();
395                for url_unparsed in audio_urls {
396                    let audio = parse_audio_url(&url_unparsed)
397                        .await
398                        .context(format!("Failed to parse audio resource: {url_unparsed}"))?;
399                    audios.push(audio);
400                }
401
402                RequestMessage::VisionChat {
403                    messages,
404                    images,
405                    audios,
406                    enable_thinking: oairequest.enable_thinking,
407                }
408            } else {
409                RequestMessage::Chat {
410                    messages,
411                    enable_thinking: oairequest.enable_thinking,
412                }
413            }
414        }
415        Either::Right(prompt) => {
416            let mut messages = Vec::new();
417            let mut message_map: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
418                IndexMap::new();
419            message_map.insert("role".to_string(), Either::Left("user".to_string()));
420            message_map.insert("content".to_string(), Either::Left(prompt));
421            messages.push(message_map);
422            RequestMessage::Chat {
423                messages,
424                enable_thinking: oairequest.enable_thinking,
425            }
426        }
427    };
428
429    let dry_params = get_dry_sampling_params(
430        oairequest.dry_multiplier,
431        oairequest.dry_sequence_breakers,
432        oairequest.dry_base,
433        oairequest.dry_allowed_length,
434    )?;
435
436    let is_streaming = oairequest.stream.unwrap_or(false);
437
438    if oairequest.grammar.is_some() && oairequest.response_format.is_some() {
439        anyhow::bail!("Request `grammar` and `response_format` were both provided but are mutually exclusive.")
440    }
441
442    let constraint = match oairequest.grammar {
443        Some(Grammar::Regex(regex)) => Constraint::Regex(regex),
444        Some(Grammar::Lark(lark)) => Constraint::Lark(lark),
445        Some(Grammar::JsonSchema(schema)) => Constraint::JsonSchema(schema),
446        Some(Grammar::Llguidance(llguidance)) => Constraint::Llguidance(llguidance),
447        None => match oairequest.response_format {
448            Some(ResponseFormat::JsonSchema {
449                json_schema: JsonSchemaResponseFormat { name: _, schema },
450            }) => Constraint::JsonSchema(schema),
451            Some(ResponseFormat::Text) => Constraint::None,
452            None => Constraint::None,
453        },
454    };
455
456    Ok((
457        Request::Normal(Box::new(NormalRequest {
458            id: state.next_request_id(),
459            messages,
460            sampling_params: SamplingParams {
461                temperature: oairequest.temperature,
462                top_k: oairequest.top_k,
463                top_p: oairequest.top_p,
464                min_p: oairequest.min_p,
465                top_n_logprobs: oairequest.top_logprobs.unwrap_or(1),
466                frequency_penalty: oairequest.frequency_penalty,
467                presence_penalty: oairequest.presence_penalty,
468                max_len: oairequest.max_tokens,
469                stop_toks,
470                logits_bias: oairequest.logit_bias,
471                n_choices: oairequest.n_choices,
472                dry_params,
473            },
474            response: tx,
475            return_logprobs: oairequest.logprobs,
476            is_streaming,
477            suffix: None,
478            constraint,
479            tool_choice: oairequest.tool_choice,
480            tools: oairequest.tools,
481            logits_processors: None,
482            return_raw_logits: false,
483            web_search_options: oairequest.web_search_options,
484            model_id: if oairequest.model == "default" {
485                None
486            } else {
487                Some(oairequest.model.clone())
488            },
489        })),
490        is_streaming,
491    ))
492}
493
494/// OpenAI-compatible chat completions endpoint handler.
495#[utoipa::path(
496    post,
497    tag = "Mistral.rs",
498    path = "/v1/chat/completions",
499    request_body = ChatCompletionRequest,
500    responses((status = 200, description = "Chat completions"))
501)]
502pub async fn chatcompletions(
503    State(state): ExtractedMistralRsState,
504    Json(oairequest): Json<ChatCompletionRequest>,
505) -> ChatCompletionResponder {
506    let (tx, mut rx) = create_response_channel(None);
507
508    // Extract model_id for routing before parsing
509    let model_id = if oairequest.model == "default" {
510        None
511    } else {
512        Some(oairequest.model.clone())
513    };
514
515    let (request, is_streaming) = match parse_request(oairequest, state.clone(), tx).await {
516        Ok(x) => x,
517        Err(e) => return handle_error(state, e.into()),
518    };
519
520    if let Err(e) = send_request_with_model(&state, request, model_id.as_deref()).await {
521        return handle_error(state, e.into());
522    }
523
524    if is_streaming {
525        ChatCompletionResponder::Sse(create_streamer(rx, state, None, None))
526    } else {
527        process_non_streaming_response(&mut rx, state).await
528    }
529}
530
531/// Handle route / generation errors and logging them.
532pub fn handle_error(
533    state: SharedMistralRsState,
534    e: Box<dyn std::error::Error + Send + Sync + 'static>,
535) -> ChatCompletionResponder {
536    handle_completion_error(state, e)
537}
538
539/// Creates a SSE streamer for chat completions with optional callbacks.
540pub fn create_streamer(
541    rx: Receiver<Response>,
542    state: SharedMistralRsState,
543    on_chunk: Option<ChatCompletionOnChunkCallback>,
544    on_done: Option<ChatCompletionOnDoneCallback>,
545) -> Sse<ChatCompletionStreamer> {
546    let streamer = base_create_streamer(rx, state, on_chunk, on_done);
547    let keep_alive_interval = get_keep_alive_interval();
548
549    Sse::new(streamer)
550        .keep_alive(KeepAlive::new().interval(Duration::from_millis(keep_alive_interval)))
551}
552
553/// Process non-streaming chat completion responses.
554pub async fn process_non_streaming_response(
555    rx: &mut Receiver<Response>,
556    state: SharedMistralRsState,
557) -> ChatCompletionResponder {
558    base_process_non_streaming_response(rx, state, match_responses, handle_error).await
559}
560
561/// Matches and processes different types of model responses into appropriate chat completion responses.
562pub fn match_responses(state: SharedMistralRsState, response: Response) -> ChatCompletionResponder {
563    match response {
564        Response::InternalError(e) => {
565            MistralRs::maybe_log_error(state, &*e);
566            ChatCompletionResponder::InternalError(e)
567        }
568        Response::ModelError(msg, response) => {
569            MistralRs::maybe_log_error(state.clone(), &ModelErrorMessage(msg.to_string()));
570            MistralRs::maybe_log_response(state, &response);
571            ChatCompletionResponder::ModelError(msg, response)
572        }
573        Response::ValidationError(e) => ChatCompletionResponder::ValidationError(e),
574        Response::Done(response) => {
575            MistralRs::maybe_log_response(state, &response);
576            ChatCompletionResponder::Json(response)
577        }
578        Response::Chunk(_) => unreachable!(),
579        Response::CompletionDone(_) => unreachable!(),
580        Response::CompletionModelError(_, _) => unreachable!(),
581        Response::CompletionChunk(_) => unreachable!(),
582        Response::ImageGeneration(_) => unreachable!(),
583        Response::Speech { .. } => unreachable!(),
584        Response::Raw { .. } => unreachable!(),
585    }
586}