mistralrs_server_core/
chat_completion.rs

1//! ## Chat Completions functionality and route handler.
2
3use std::{env, error::Error, ops::Deref, pin::Pin, task::Poll, time::Duration};
4
5use anyhow::{Context, Result};
6use axum::{
7    extract::{Json, State},
8    http::{self, StatusCode},
9    response::{
10        sse::{Event, KeepAlive},
11        IntoResponse, Sse,
12    },
13};
14use either::Either;
15use indexmap::IndexMap;
16use itertools::Itertools;
17use mistralrs_core::{
18    ChatCompletionChunkResponse, ChatCompletionResponse, Constraint, DrySamplingParams, MistralRs,
19    NormalRequest, Request, RequestMessage, Response, SamplingParams,
20    StopTokens as InternalStopTokens,
21};
22use serde::Serialize;
23use serde_json::Value;
24use tokio::sync::mpsc::{channel, Receiver, Sender};
25
26use crate::{
27    openai::{
28        ChatCompletionRequest, Grammar, JsonSchemaResponseFormat, MessageInnerContent,
29        ResponseFormat, StopTokens,
30    },
31    types::{ExtractedMistralRsState, SharedMistralRsState},
32    util,
33};
34
35/// A callback function that processes streaming response chunks before they are sent to the client.
36///
37/// This hook allows modification of each chunk in the streaming response, enabling features like
38/// content filtering, transformation, or logging. The callback receives a chunk and must return
39/// a (potentially modified) chunk.
40///
41/// ### Examples
42///
43/// ```no_run
44/// use mistralrs_server_core::chat_completion::OnChunkCallback;
45///
46/// let on_chunk: OnChunkCallback = Box::new(|mut chunk| {
47///     // Log the chunk or modify its content
48///     println!("Processing chunk: {:?}", chunk);
49///     chunk
50/// });
51/// ```
pub type OnChunkCallback =
    Box<dyn Fn(ChatCompletionChunkResponse) -> ChatCompletionChunkResponse + Send + Sync>;

/// A callback function that is executed when the streaming response completes.
///
/// This hook receives all chunks that were streamed during the response, allowing for
/// post-processing, analytics, or cleanup operations after the stream finishes.
///
/// ### Examples
///
/// ```no_run
/// use mistralrs_server_core::chat_completion::OnDoneCallback;
///
/// let on_done: OnDoneCallback = Box::new(|chunks| {
///     println!("Stream completed with {} chunks", chunks.len());
///     // Process all chunks for analytics
/// });
/// ```
pub type OnDoneCallback = Box<dyn Fn(&[ChatCompletionChunkResponse]) + Send + Sync>;

/// Default buffer size for the response channel used in streaming operations.
///
/// This constant defines the maximum number of response messages that can be buffered
/// in the channel before backpressure is applied. A larger buffer reduces the likelihood
/// of blocking but uses more memory.
pub const DEFAULT_CHANNEL_BUFFER_SIZE: usize = 10_000;

/// Default keep-alive interval for Server-Sent Events (SSE) streams in milliseconds.
///
/// Can be overridden at runtime via the `KEEP_ALIVE_INTERVAL` environment variable
/// (see [`get_keep_alive_interval`]).
pub const DEFAULT_KEEP_ALIVE_INTERVAL_MS: u64 = 10_000;
81
/// Internal error type for model-related errors with a descriptive message.
///
/// Wraps an error string produced by the underlying model so it can be passed
/// around (and logged) as a `std::error::Error` trait object.
#[derive(Debug)]
struct ModelErrorMessage(String);

impl std::fmt::Display for ModelErrorMessage {
    /// Writes the wrapped message verbatim.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}

impl std::error::Error for ModelErrorMessage {}
94
/// Represents the current state of a streaming response.
///
/// States only advance forward: `Running` -> `SendingDone` -> `Done`.
enum DoneState {
    /// The stream is actively processing and sending response chunks
    Running,
    /// The stream has finished processing and is about to send the `[DONE]` message
    SendingDone,
    /// The stream has completed entirely; the next poll terminates the stream
    Done,
}
104
/// A streaming response handler.
///
/// It processes incoming response chunks from a model and converts them
/// into Server-Sent Events (SSE) format for real-time streaming to clients.
/// Constructed by [`create_chat_streamer`].
pub struct Streamer {
    /// Channel receiver for incoming model responses
    rx: Receiver<Response>,
    /// Current state of the streaming operation
    done_state: DoneState,
    /// Underlying mistral.rs instance (used for logging errors/responses)
    state: SharedMistralRsState,
    /// Whether to store chunks for the completion callback
    store_chunks: bool,
    /// All chunks received during streaming (if `store_chunks` is true)
    chunks: Vec<ChatCompletionChunkResponse>,
    /// Optional callback to process each chunk before sending
    on_chunk: Option<OnChunkCallback>,
    /// Optional callback to execute when streaming completes
    on_done: Option<OnDoneCallback>,
}
125
impl futures::Stream for Streamer {
    type Item = Result<Event, axum::Error>;

    /// Polls the stream for the next Server-Sent Event.
    ///
    /// This method implements the core streaming logic:
    /// 1. Handles stream completion by sending `[DONE]` and executing callbacks
    /// 2. Processes incoming model responses and converts them to SSE events
    /// 3. Applies chunk modifications if a callback is provided
    /// 4. Stores chunks if completion callback is configured
    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        match self.done_state {
            DoneState::SendingDone => {
                // https://platform.openai.com/docs/api-reference/completions/create
                // If true, returns a stream of events that happen during the Run as server-sent events, terminating when the Run enters a terminal state with a data: [DONE] message.
                self.done_state = DoneState::Done;
                return Poll::Ready(Some(Ok(Event::default().data("[DONE]"))));
            }
            DoneState::Done => {
                // The `[DONE]` marker was emitted on the previous poll; run the
                // completion callback (if any) over the collected chunks, then end.
                if let Some(on_done) = &self.on_done {
                    on_done(&self.chunks);
                }
                return Poll::Ready(None);
            }
            DoneState::Running => (),
        }

        match self.rx.poll_recv(cx) {
            Poll::Ready(Some(resp)) => match resp {
                Response::ModelError(msg, _) => {
                    MistralRs::maybe_log_error(
                        self.state.clone(),
                        &ModelErrorMessage(msg.to_string()),
                    );
                    // Done now, just need to send the [DONE]
                    self.done_state = DoneState::SendingDone;
                    Poll::Ready(Some(Ok(Event::default().data(msg))))
                }
                Response::ValidationError(e) => {
                    // Forwarded as plain SSE data; note this does not end the stream.
                    Poll::Ready(Some(Ok(Event::default().data(e.to_string()))))
                }
                Response::InternalError(e) => {
                    MistralRs::maybe_log_error(self.state.clone(), &*e);
                    Poll::Ready(Some(Ok(Event::default().data(e.to_string()))))
                }
                Response::Chunk(mut response) => {
                    // Once every choice reports a finish reason this is the final
                    // chunk; the next poll emits `[DONE]`.
                    if response.choices.iter().all(|x| x.finish_reason.is_some()) {
                        self.done_state = DoneState::SendingDone;
                    }
                    // Done now, just need to send the [DONE]
                    MistralRs::maybe_log_response(self.state.clone(), &response);

                    if let Some(on_chunk) = &self.on_chunk {
                        response = on_chunk(response);
                    }

                    if self.store_chunks {
                        self.chunks.push(response.clone());
                    }

                    Poll::Ready(Some(Event::default().json_data(response)))
                }
                // These variants belong to other endpoints (completions, image
                // generation, speech, raw) and are never routed to this streamer.
                Response::Done(_) => unreachable!(),
                Response::CompletionDone(_) => unreachable!(),
                Response::CompletionModelError(_, _) => unreachable!(),
                Response::CompletionChunk(_) => unreachable!(),
                Response::ImageGeneration(_) => unreachable!(),
                Response::Speech { .. } => unreachable!(),
                Response::Raw { .. } => unreachable!(),
            },
            // NOTE(review): `Poll::Ready(None)` (channel closed) is mapped to
            // `Poll::Pending` without registering a waker, so this task would never
            // be polled again. This appears to rely on the engine always sending a
            // terminal chunk (or error) before dropping the sender — confirm that
            // invariant upstream, otherwise a dropped sender hangs the stream.
            Poll::Pending | Poll::Ready(None) => Poll::Pending,
        }
    }
}
203
/// Represents different types of chat completion responses.
///
/// Converted into an HTTP response (with the appropriate status code) by its
/// `IntoResponse` implementation.
pub enum ChatCompletionResponder {
    /// Server-Sent Events streaming response
    Sse(Sse<Streamer>),
    /// Complete JSON response for non-streaming requests
    Json(ChatCompletionResponse),
    /// Model error with partial response data
    ModelError(String, ChatCompletionResponse),
    /// Internal server error
    InternalError(Box<dyn Error>),
    /// Request validation error
    ValidationError(Box<dyn Error>),
}
217
218/// Trait for converting errors to HTTP responses with appropriate status codes.
219trait ErrorToResponse: Serialize {
220    /// Converts the error to an HTTP response with the specified status code.
221    fn to_response(&self, code: StatusCode) -> axum::response::Response {
222        let mut r = Json(self).into_response();
223        *r.status_mut() = code;
224        r
225    }
226}
227
228/// Standard JSON error response structure.
229#[derive(Serialize)]
230struct JsonError {
231    message: String,
232}
233
234impl JsonError {
235    /// Creates a new JSON error with the specified message.
236    fn new(message: String) -> Self {
237        Self { message }
238    }
239}
240impl ErrorToResponse for JsonError {}
241
242/// JSON error response structure for model errors.
243#[derive(Serialize)]
244struct JsonModelError {
245    message: String,
246    /// Partial response data that was generated before the error occurred
247    partial_response: ChatCompletionResponse,
248}
249
250impl JsonModelError {
251    /// Creates a new JSON model error with message and partial response.
252    fn new(message: String, partial_response: ChatCompletionResponse) -> Self {
253        Self {
254            message,
255            partial_response,
256        }
257    }
258}
259
260impl ErrorToResponse for JsonModelError {}
261
262impl IntoResponse for ChatCompletionResponder {
263    /// Converts the chat completion responder into an HTTP response.
264    fn into_response(self) -> axum::response::Response {
265        match self {
266            ChatCompletionResponder::Sse(s) => s.into_response(),
267            ChatCompletionResponder::Json(s) => Json(s).into_response(),
268            ChatCompletionResponder::InternalError(e) => {
269                JsonError::new(e.to_string()).to_response(http::StatusCode::INTERNAL_SERVER_ERROR)
270            }
271            ChatCompletionResponder::ValidationError(e) => {
272                JsonError::new(e.to_string()).to_response(http::StatusCode::UNPROCESSABLE_ENTITY)
273            }
274            ChatCompletionResponder::ModelError(msg, response) => {
275                JsonModelError::new(msg, response)
276                    .to_response(http::StatusCode::INTERNAL_SERVER_ERROR)
277            }
278        }
279    }
280}
281
/// Parses and validates a chat completion request.
///
/// This function transforms an OpenAI-compatible chat completion request into the
/// request format used by mistral.rs. Returns the internal request (wrapped as
/// `Request::Normal`) together with a flag indicating whether the client asked
/// for a streaming response.
///
/// # Errors
///
/// Fails if a message has neither content nor tool calls, if image-message
/// sub-content is malformed, if an image URL cannot be fetched/parsed, or if
/// `grammar` and `response_format` are both supplied (they are mutually exclusive).
pub async fn parse_request(
    oairequest: ChatCompletionRequest,
    state: SharedMistralRsState,
    tx: Sender<Response>,
) -> Result<(Request, bool)> {
    // Log the raw request (as JSON) before any transformation.
    let repr = serde_json::to_string(&oairequest).expect("Serialization of request failed.");
    MistralRs::maybe_log_request(state.clone(), repr);

    // Normalize OpenAI stop sequences (single string or list) to the internal form.
    let stop_toks = match oairequest.stop_seqs {
        Some(StopTokens::Multi(m)) => Some(InternalStopTokens::Seqs(m)),
        Some(StopTokens::Single(s)) => Some(InternalStopTokens::Seqs(vec![s])),
        None => None,
    };
    // `messages` is either a structured message list (Left) or a bare prompt string (Right).
    let messages = match oairequest.messages {
        Either::Left(req_messages) => {
            let mut messages = Vec::new();
            let mut image_urls = Vec::new();
            for message in req_messages {
                let content = match message.content.as_deref() {
                    Some(content) => content.clone(),
                    None => {
                        // Handle tool call: with no content, the message must carry
                        // tool calls, which are serialized to JSON as the content.
                        let calls = message
                            .tool_calls
                            .as_ref()
                            .context(
                                "No content was provided, expected tool calls to be provided.",
                            )?
                            .iter()
                            .map(|call| &call.function)
                            .collect::<Vec<_>>();

                        Either::Left(serde_json::to_string(&calls)?)
                    }
                };

                match &content {
                    // Plain string content: pass through as a role/content map.
                    Either::Left(content) => {
                        let mut message_map: IndexMap<
                            String,
                            Either<String, Vec<IndexMap<String, Value>>>,
                        > = IndexMap::new();
                        message_map.insert("role".to_string(), Either::Left(message.role));
                        message_map.insert("content".to_string(), Either::Left(content.clone()));
                        messages.push(message_map);
                    }
                    // Array content: may be a lone text part, or mixed text/image parts.
                    Either::Right(image_messages) => {
                        // If there is only one message, it is possible a text message
                        // found when rig is used as client. In this case, we need to check if
                        // the message is a text message or an image message.
                        if image_messages.len() == 1 {
                            if !image_messages[0].contains_key("text") {
                                anyhow::bail!("Expected `text` key in input message.");
                            }
                            // NOTE(review): the Right arm falls back to Debug formatting of the
                            // non-string value — presumably a best-effort stringification; confirm
                            // this is the intended representation for structured `text` values.
                            let content = match image_messages[0]["text"].deref() {
                                Either::Left(left) => left.to_string(),
                                Either::Right(right) => format!("{:?}", right),
                            };
                            let mut message_map: IndexMap<
                                String,
                                Either<String, Vec<IndexMap<String, Value>>>,
                            > = IndexMap::new();
                            message_map.insert("role".to_string(), Either::Left(message.role));
                            message_map.insert("content".to_string(), Either::Left(content));
                            messages.push(message_map);
                            continue;
                        }
                        // Multi-part (image) messages may only come from the user.
                        if message.role != "user" {
                            anyhow::bail!(
                                "Role for an image message must be `user`, but it is {}",
                                message.role
                            );
                        }

                        // Typed view of a content sub-part: either text or an image URL.
                        enum ContentPart {
                            Text { text: String },
                            Image { image_url: String },
                        }

                        // Validate and collect each sub-part by its `type` tag.
                        let mut items = Vec::new();
                        for image_message in image_messages {
                            match image_message.get("type") {
                                Some(MessageInnerContent(Either::Left(x))) if x == "text" => {
                                    items.push(ContentPart::Text {
                                        text: image_message
                                            .get("text").as_ref()
                                            .context("Text sub-content must have `text` key.")?.as_ref()
                                            .left().context("Text sub-content `text` key must be a string.")?.clone(),
                                    });
                                }
                                Some(MessageInnerContent(Either::Left(x))) if x == "image_url" => {
                                    items.push(ContentPart::Image {
                                        image_url: image_message.get("image_url").as_ref()
                                            .context("Image sub-content must have `image_url` key.")?.as_ref()
                                            .right()
                                            .context("Image sub-content `image_url` key must be an object.")?
                                            .get("url")
                                            .context("Image sub-content `image_url` object must have a `url` key.")?.clone()
                                    });
                                }
                                _ => anyhow::bail!("Expected array content sub-content to be of format {{`type`: `text`, `text`: ...}} and {{`type`: `url`, `image_url`: {{`url`: ...}}}}")
                            }
                        }

                        // All text parts are joined into one space-separated string.
                        let text_content = items
                            .iter()
                            .filter_map(|item| match item {
                                ContentPart::Text { text } => Some(text),
                                _ => None,
                            })
                            .join(" ");
                        let image_urls_iter = items
                            .iter()
                            .filter_map(|item| match item {
                                ContentPart::Image { image_url } => Some(image_url.clone()),
                                _ => None,
                            })
                            .collect::<Vec<_>>();

                        let mut message_map: IndexMap<
                            String,
                            Either<String, Vec<IndexMap<String, Value>>>,
                        > = IndexMap::new();
                        message_map.insert("role".to_string(), Either::Left(message.role));

                        // Rebuild the content as: one `image` placeholder per URL,
                        // followed by a single `text` part with the joined text.
                        let mut content_map: Vec<IndexMap<String, Value>> = Vec::new();
                        for _ in &image_urls_iter {
                            let mut content_image_map = IndexMap::new();
                            content_image_map
                                .insert("type".to_string(), Value::String("image".to_string()));
                            content_map.push(content_image_map);
                        }
                        {
                            let mut content_text_map = IndexMap::new();
                            content_text_map
                                .insert("type".to_string(), Value::String("text".to_string()));
                            content_text_map
                                .insert("text".to_string(), Value::String(text_content));
                            content_map.push(content_text_map);
                        }

                        message_map.insert("content".to_string(), Either::Right(content_map));
                        messages.push(message_map);
                        image_urls.extend(image_urls_iter);
                    }
                }
            }
            // Any image URLs collected above are resolved (fetched/decoded) here,
            // turning the request into a vision chat.
            if !image_urls.is_empty() {
                let mut images = Vec::new();
                for url_unparsed in image_urls {
                    let image = util::parse_image_url(&url_unparsed)
                        .await
                        .context(format!("Failed to parse image resource: {}", url_unparsed))?;

                    images.push(image);
                }
                RequestMessage::VisionChat {
                    messages,
                    images,
                    enable_thinking: oairequest.enable_thinking,
                }
            } else {
                RequestMessage::Chat {
                    messages,
                    enable_thinking: oairequest.enable_thinking,
                }
            }
        }
        // Bare prompt string: wrap it as a single user message.
        Either::Right(prompt) => {
            let mut messages = Vec::new();
            let mut message_map: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
                IndexMap::new();
            message_map.insert("role".to_string(), Either::Left("user".to_string()));
            message_map.insert("content".to_string(), Either::Left(prompt));
            messages.push(message_map);
            RequestMessage::Chat {
                messages,
                enable_thinking: oairequest.enable_thinking,
            }
        }
    };

    // DRY sampling is only enabled when a multiplier is given; the remaining
    // fields fall back to library defaults.
    let dry_params = if let Some(dry_multiplier) = oairequest.dry_multiplier {
        Some(DrySamplingParams::new_with_defaults(
            dry_multiplier,
            oairequest.dry_sequence_breakers,
            oairequest.dry_base,
            oairequest.dry_allowed_length,
        )?)
    } else {
        None
    };

    let is_streaming = oairequest.stream.unwrap_or(false);

    if oairequest.grammar.is_some() && oairequest.response_format.is_some() {
        anyhow::bail!("Request `grammar` and `response_format` were both provided but are mutually exclusive.")
    }

    // `grammar` takes one of several constraint languages; otherwise a JSON-schema
    // response format also maps to a constraint.
    let constraint = match oairequest.grammar {
        Some(Grammar::Regex(regex)) => Constraint::Regex(regex),
        Some(Grammar::Lark(lark)) => Constraint::Lark(lark),
        Some(Grammar::JsonSchema(schema)) => Constraint::JsonSchema(schema),
        Some(Grammar::Llguidance(llguidance)) => Constraint::Llguidance(llguidance),
        None => match oairequest.response_format {
            Some(ResponseFormat::JsonSchema {
                json_schema: JsonSchemaResponseFormat { name: _, schema },
            }) => Constraint::JsonSchema(schema),
            Some(ResponseFormat::Text) => Constraint::None,
            None => Constraint::None,
        },
    };

    Ok((
        Request::Normal(Box::new(NormalRequest {
            id: state.next_request_id(),
            messages,
            sampling_params: SamplingParams {
                temperature: oairequest.temperature,
                top_k: oairequest.top_k,
                top_p: oairequest.top_p,
                min_p: oairequest.min_p,
                top_n_logprobs: oairequest.top_logprobs.unwrap_or(1),
                frequency_penalty: oairequest.frequency_penalty,
                presence_penalty: oairequest.presence_penalty,
                max_len: oairequest.max_tokens,
                stop_toks,
                logits_bias: oairequest.logit_bias,
                n_choices: oairequest.n_choices,
                dry_params,
            },
            response: tx,
            return_logprobs: oairequest.logprobs,
            is_streaming,
            suffix: None,
            constraint,
            tool_choice: oairequest.tool_choice,
            tools: oairequest.tools,
            logits_processors: None,
            return_raw_logits: false,
            web_search_options: oairequest.web_search_options,
        })),
        is_streaming,
    ))
}
531
532/// OpenAI-compatible chat completions endpoint handler.
533#[utoipa::path(
534    post,
535    tag = "Mistral.rs",
536    path = "/v1/chat/completions",
537    request_body = ChatCompletionRequest,
538    responses((status = 200, description = "Chat completions"))
539)]
540pub async fn chatcompletions(
541    State(state): ExtractedMistralRsState,
542    Json(oairequest): Json<ChatCompletionRequest>,
543) -> ChatCompletionResponder {
544    let (tx, mut rx) = create_response_channel(None);
545
546    let (request, is_streaming) = match parse_request(oairequest, state.clone(), tx).await {
547        Ok(x) => x,
548        Err(e) => return handle_chat_completion_error(state, e.into()),
549    };
550
551    if let Err(e) = send_request(&state, request).await {
552        return handle_chat_completion_error(state, e.into());
553    }
554
555    if is_streaming {
556        ChatCompletionResponder::Sse(create_chat_streamer(rx, state, None, None))
557    } else {
558        process_non_streaming_chat_response(&mut rx, state).await
559    }
560}
561
562/// Helper function to handle chat completion errors and logging them.
563pub fn handle_chat_completion_error(
564    state: SharedMistralRsState,
565    e: Box<dyn std::error::Error + Send + Sync + 'static>,
566) -> ChatCompletionResponder {
567    let e = anyhow::Error::msg(e.to_string());
568    MistralRs::maybe_log_error(state, &*e);
569    ChatCompletionResponder::InternalError(e.into())
570}
571
572/// Creates a channel for response communication.
573pub fn create_response_channel(
574    buffer_size: Option<usize>,
575) -> (Sender<Response>, Receiver<Response>) {
576    let channel_buffer_size = buffer_size.unwrap_or(DEFAULT_CHANNEL_BUFFER_SIZE);
577
578    channel(channel_buffer_size)
579}
580
581/// Gets the keep-alive interval for SSE streams from environment or default.
582pub fn get_keep_alive_interval() -> u64 {
583    env::var("KEEP_ALIVE_INTERVAL")
584        .map(|val| {
585            val.parse::<u64>().unwrap_or_else(|e| {
586                tracing::warn!("Failed to parse KEEP_ALIVE_INTERVAL: {}. Using default.", e);
587                DEFAULT_KEEP_ALIVE_INTERVAL_MS
588            })
589        })
590        .unwrap_or(DEFAULT_KEEP_ALIVE_INTERVAL_MS)
591}
592
593/// Sends a request to the model processing pipeline.
594pub async fn send_request(state: &SharedMistralRsState, request: Request) -> Result<()> {
595    let sender = state
596        .get_sender()
597        .context("mistral.rs sender not available.")?;
598
599    sender.send(request).await.map_err(|e| e.into())
600}
601
602/// Creates a SSE streamer for chat completions with optional callbacks.
603pub fn create_chat_streamer(
604    rx: Receiver<Response>,
605    state: SharedMistralRsState,
606    on_chunk: Option<OnChunkCallback>,
607    on_done: Option<OnDoneCallback>,
608) -> Sse<Streamer> {
609    let store_chunks = on_done.is_some();
610
611    let streamer = Streamer {
612        rx,
613        done_state: DoneState::Running,
614        store_chunks,
615        state,
616        chunks: Vec::new(),
617        on_chunk,
618        on_done,
619    };
620
621    let keep_alive_interval = get_keep_alive_interval();
622
623    Sse::new(streamer)
624        .keep_alive(KeepAlive::new().interval(Duration::from_millis(keep_alive_interval)))
625}
626
627/// Processes non-streaming chat completion responses.
628pub async fn process_non_streaming_chat_response(
629    rx: &mut Receiver<Response>,
630    state: SharedMistralRsState,
631) -> ChatCompletionResponder {
632    let response = match rx.recv().await {
633        Some(response) => response,
634        None => {
635            let e = anyhow::Error::msg("No response received from the model.");
636            return handle_chat_completion_error(state, e.into());
637        }
638    };
639
640    match_responses(state, response)
641}
642
643/// Matches and processes different types of model responses into appropriate chat completion responses.
644pub fn match_responses(state: SharedMistralRsState, response: Response) -> ChatCompletionResponder {
645    match response {
646        Response::InternalError(e) => {
647            MistralRs::maybe_log_error(state, &*e);
648            ChatCompletionResponder::InternalError(e)
649        }
650        Response::ModelError(msg, response) => {
651            MistralRs::maybe_log_error(state.clone(), &ModelErrorMessage(msg.to_string()));
652            MistralRs::maybe_log_response(state, &response);
653            ChatCompletionResponder::ModelError(msg, response)
654        }
655        Response::ValidationError(e) => ChatCompletionResponder::ValidationError(e),
656        Response::Done(response) => {
657            MistralRs::maybe_log_response(state, &response);
658            ChatCompletionResponder::Json(response)
659        }
660        Response::Chunk(_) => unreachable!(),
661        Response::CompletionDone(_) => unreachable!(),
662        Response::CompletionModelError(_, _) => unreachable!(),
663        Response::CompletionChunk(_) => unreachable!(),
664        Response::ImageGeneration(_) => unreachable!(),
665        Response::Speech { .. } => unreachable!(),
666        Response::Raw { .. } => unreachable!(),
667    }
668}