mistralrs_server_core/chat_completion.rs

//! ## Chat Completions functionality and route handler.

use std::{ops::Deref, pin::Pin, task::Poll, time::Duration};

use anyhow::{Context, Result};
use axum::{
    extract::{Json, State},
    http::{self},
    response::{
        sse::{Event, KeepAlive},
        IntoResponse, Sse,
    },
};
use either::Either;
use indexmap::IndexMap;
use itertools::Itertools;
use mistralrs_core::{
    ChatCompletionChunkResponse, ChatCompletionResponse, Constraint, MistralRs, NormalRequest,
    Request, RequestMessage, Response, SamplingParams,
};
use serde_json::Value;
use tokio::sync::mpsc::{Receiver, Sender};

use crate::{
    completion_core::{
        convert_stop_tokens, get_dry_sampling_params, handle_completion_error,
        BaseCompletionResponder,
    },
    handler_core::{
        base_process_non_streaming_response, create_response_channel, send_request_with_model,
        BaseJsonModelError, ErrorToResponse, JsonError, ModelErrorMessage,
    },
    openai::{
        ChatCompletionRequest, Grammar, JsonSchemaResponseFormat, MessageInnerContent,
        ResponseFormat,
    },
    streaming::{base_create_streamer, get_keep_alive_interval, BaseStreamer, DoneState},
    types::{ExtractedMistralRsState, OnChunkCallback, OnDoneCallback, SharedMistralRsState},
    util::{parse_audio_url, parse_image_url, validate_model_name},
};

/// A callback function that processes streaming response chunks before they are sent to the client.
///
/// This hook allows modification of each chunk in the streaming response, enabling features like
/// content filtering, transformation, or logging. The callback receives a chunk and must return
/// a (potentially modified) chunk.
///
/// ### Examples
///
/// ```no_run
/// use mistralrs_server_core::chat_completion::ChatCompletionOnChunkCallback;
///
/// let on_chunk: ChatCompletionOnChunkCallback = Box::new(|mut chunk| {
///     // Log the chunk or modify its content
///     println!("Processing chunk: {:?}", chunk);
///     chunk
/// });
/// ```
pub type ChatCompletionOnChunkCallback = OnChunkCallback<ChatCompletionChunkResponse>;

/// A callback function that is executed when the streaming response completes.
///
/// This hook receives all chunks that were streamed during the response, allowing for
/// post-processing, analytics, or cleanup operations after the stream finishes.
///
/// ### Examples
///
/// ```no_run
/// use mistralrs_server_core::chat_completion::ChatCompletionOnDoneCallback;
///
/// let on_done: ChatCompletionOnDoneCallback = Box::new(|chunks| {
///     println!("Stream completed with {} chunks", chunks.len());
///     // Process all chunks for analytics
/// });
/// ```
pub type ChatCompletionOnDoneCallback = OnDoneCallback<ChatCompletionChunkResponse>;

/// A streaming response handler.
///
/// It processes incoming response chunks from a model and converts them
/// into Server-Sent Events (SSE) format for real-time streaming to clients.
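///
/// ### Examples
///
/// A sketch of how the streamer is typically obtained and returned from a handler
/// (assumes `rx` and `state` were obtained from `create_response_channel` and the
/// shared server state):
///
/// ```ignore
/// use mistralrs_server_core::chat_completion::{create_streamer, ChatCompletionResponder};
///
/// // `create_streamer` builds the streamer and wraps it in an SSE response with keep-alive.
/// let sse = create_streamer(rx, state, None, None);
/// let responder = ChatCompletionResponder::Sse(sse);
/// ```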
pub type ChatCompletionStreamer = BaseStreamer<
    ChatCompletionChunkResponse,
    ChatCompletionOnChunkCallback,
    ChatCompletionOnDoneCallback,
>;

impl futures::Stream for ChatCompletionStreamer {
    type Item = Result<Event, axum::Error>;

    /// Polls the stream for the next Server-Sent Event.
    ///
    /// This method implements the core streaming logic:
    /// 1. Handles stream completion by sending `[DONE]` and executing callbacks
    /// 2. Processes incoming model responses and converts them to SSE events
    /// 3. Applies chunk modifications if a callback is provided
    /// 4. Stores chunks if a completion callback is configured
    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        match self.done_state {
            DoneState::SendingDone => {
                // https://platform.openai.com/docs/api-reference/completions/create
                // If true, returns a stream of events that happen during the Run as server-sent events, terminating when the Run enters a terminal state with a data: [DONE] message.
                self.done_state = DoneState::Done;
                return Poll::Ready(Some(Ok(Event::default().data("[DONE]"))));
            }
            DoneState::Done => {
                if let Some(on_done) = &self.on_done {
                    on_done(&self.chunks);
                }
                return Poll::Ready(None);
            }
            DoneState::Running => (),
        }

        match self.rx.poll_recv(cx) {
            Poll::Ready(Some(resp)) => match resp {
                Response::ModelError(msg, _) => {
                    MistralRs::maybe_log_error(
                        self.state.clone(),
                        &ModelErrorMessage(msg.to_string()),
                    );
                    // Done now, just need to send the [DONE]
                    self.done_state = DoneState::SendingDone;
                    Poll::Ready(Some(Ok(Event::default().data(msg))))
                }
                Response::ValidationError(e) => {
                    Poll::Ready(Some(Ok(Event::default().data(e.to_string()))))
                }
                Response::InternalError(e) => {
                    MistralRs::maybe_log_error(self.state.clone(), &*e);
                    Poll::Ready(Some(Ok(Event::default().data(e.to_string()))))
                }
                Response::Chunk(mut response) => {
                    if response.choices.iter().all(|x| x.finish_reason.is_some()) {
                        // All choices have finished, so only the final [DONE] event remains to be sent.
                        self.done_state = DoneState::SendingDone;
                    }
                    MistralRs::maybe_log_response(self.state.clone(), &response);

                    if let Some(on_chunk) = &self.on_chunk {
                        response = on_chunk(response);
                    }

                    if self.store_chunks {
                        self.chunks.push(response.clone());
                    }

                    Poll::Ready(Some(Event::default().json_data(response)))
                }
                Response::Done(_) => unreachable!(),
                Response::CompletionDone(_) => unreachable!(),
                Response::CompletionModelError(_, _) => unreachable!(),
                Response::CompletionChunk(_) => unreachable!(),
                Response::ImageGeneration(_) => unreachable!(),
                Response::Speech { .. } => unreachable!(),
                Response::Raw { .. } => unreachable!(),
            },
            Poll::Pending | Poll::Ready(None) => Poll::Pending,
        }
    }
}

/// Represents different types of chat completion responses.
pub type ChatCompletionResponder =
    BaseCompletionResponder<ChatCompletionResponse, ChatCompletionStreamer>;

type JsonModelError = BaseJsonModelError<ChatCompletionResponse>;
impl ErrorToResponse for JsonModelError {}

impl IntoResponse for ChatCompletionResponder {
    /// Converts the chat completion responder into an HTTP response.
    fn into_response(self) -> axum::response::Response {
        match self {
            ChatCompletionResponder::Sse(s) => s.into_response(),
            ChatCompletionResponder::Json(s) => Json(s).into_response(),
            ChatCompletionResponder::InternalError(e) => {
                JsonError::new(e.to_string()).to_response(http::StatusCode::INTERNAL_SERVER_ERROR)
            }
            ChatCompletionResponder::ValidationError(e) => {
                JsonError::new(e.to_string()).to_response(http::StatusCode::UNPROCESSABLE_ENTITY)
            }
            ChatCompletionResponder::ModelError(msg, response) => {
                JsonModelError::new(msg, response)
                    .to_response(http::StatusCode::INTERNAL_SERVER_ERROR)
            }
        }
    }
}

/// Parses and validates a chat completion request.
///
/// This function transforms an OpenAI-compatible chat completion request into the
/// request format used by mistral.rs.
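///
/// ### Examples
///
/// A sketch of typical usage from a custom handler, mirroring `chatcompletions` below
/// (assumes `oairequest` was deserialized from the request body and `state` is the
/// shared server state):
///
/// ```ignore
/// // Channel over which the core sends `Response`s back to this handler.
/// let (tx, mut rx) = create_response_channel(None);
/// let (request, is_streaming) = parse_request(oairequest, state.clone(), tx).await?;
/// ```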
pub async fn parse_request(
    oairequest: ChatCompletionRequest,
    state: SharedMistralRsState,
    tx: Sender<Response>,
) -> Result<(Request, bool)> {
    let repr = serde_json::to_string(&oairequest).expect("Serialization of request failed.");
    MistralRs::maybe_log_request(state.clone(), repr);

    // Validate that the requested model matches the loaded model
    validate_model_name(&oairequest.model, state.clone())?;

    let stop_toks = convert_stop_tokens(oairequest.stop_seqs);

    let messages = match oairequest.messages {
        Either::Left(req_messages) => {
            let mut messages = Vec::new();
            let mut image_urls = Vec::new();
            let mut audio_urls = Vec::new();
            for message in req_messages {
                let content = match message.content.as_deref() {
                    Some(content) => content.clone(),
                    None => {
                        // Handle tool call
                        let calls = message
                            .tool_calls
                            .as_ref()
                            .context(
                                "No content was provided, expected tool calls to be provided.",
                            )?
                            .iter()
                            .map(|call| &call.function)
                            .collect::<Vec<_>>();

                        Either::Left(serde_json::to_string(&calls)?)
                    }
                };

                match &content {
                    Either::Left(content) => {
                        let mut message_map: IndexMap<
                            String,
                            Either<String, Vec<IndexMap<String, Value>>>,
                        > = IndexMap::new();
                        message_map.insert("role".to_string(), Either::Left(message.role));
                        message_map.insert("content".to_string(), Either::Left(content.clone()));
                        messages.push(message_map);
                    }
                    Either::Right(image_messages) => {
                        // If there is only one sub-message, it may be a plain text message
                        // (this occurs when `rig` is used as a client). In that case, check
                        // whether it is a text message or an image message.
                        if image_messages.len() == 1 {
                            if !image_messages[0].contains_key("text") {
                                anyhow::bail!("Expected `text` key in input message.");
                            }
                            let content = match image_messages[0]["text"].deref() {
                                Either::Left(left) => left.to_string(),
                                Either::Right(right) => format!("{right:?}"),
                            };
                            let mut message_map: IndexMap<
                                String,
                                Either<String, Vec<IndexMap<String, Value>>>,
                            > = IndexMap::new();
                            message_map.insert("role".to_string(), Either::Left(message.role));
                            message_map.insert("content".to_string(), Either::Left(content));
                            messages.push(message_map);
                            continue;
                        }
                        if message.role != "user" {
                            anyhow::bail!(
                                "Role for an image message must be `user`, but it is {}",
                                message.role
                            );
                        }

                        enum ContentPart {
                            Text { text: String },
                            Image { image_url: String },
                            Audio { audio_url: String },
                        }

                        let mut items = Vec::new();
                        for image_message in image_messages {
                            match image_message.get("type") {
                                Some(MessageInnerContent(Either::Left(x))) if x == "text" => {
                                    items.push(ContentPart::Text {
                                        text: image_message
                                            .get("text")
                                            .as_ref()
                                            .context("Text sub-content must have `text` key.")?
                                            .as_ref()
                                            .left()
                                            .context("Text sub-content `text` key must be a string.")?
                                            .clone(),
                                    });
                                }
                                Some(MessageInnerContent(Either::Left(x))) if x == "image_url" => {
                                    items.push(ContentPart::Image {
                                        image_url: image_message
                                            .get("image_url")
                                            .as_ref()
                                            .context("Image sub-content must have `image_url` key.")?
                                            .as_ref()
                                            .right()
                                            .context("Image sub-content `image_url` key must be an object.")?
                                            .get("url")
                                            .context("Image sub-content `image_url` object must have a `url` key.")?
                                            .clone(),
                                    });
                                }
                                Some(MessageInnerContent(Either::Left(x))) if x == "audio_url" => {
                                    items.push(ContentPart::Audio {
                                        audio_url: image_message
                                            .get("audio_url")
                                            .as_ref()
                                            .context("Audio sub-content must have `audio_url` key.")?
                                            .as_ref()
                                            .right()
                                            .context("Audio sub-content `audio_url` key must be an object.")?
                                            .get("url")
                                            .context("Audio sub-content `audio_url` object must have a `url` key.")?
                                            .clone(),
                                    });
                                }
                                _ => anyhow::bail!("Expected array sub-content to be of the form {{`type`: `text`, `text`: ...}}, {{`type`: `image_url`, `image_url`: {{`url`: ...}}}}, or {{`type`: `audio_url`, `audio_url`: {{`url`: ...}}}}")
                            }
                        }

                        let text_content = items
                            .iter()
                            .filter_map(|item| match item {
                                ContentPart::Text { text } => Some(text),
                                _ => None,
                            })
                            .join(" ");
                        let image_urls_iter = items
                            .iter()
                            .filter_map(|item| match item {
                                ContentPart::Image { image_url } => Some(image_url.clone()),
                                _ => None,
                            })
                            .collect::<Vec<_>>();

                        let audio_urls_iter = items
                            .iter()
                            .filter_map(|item| match item {
                                ContentPart::Audio { audio_url } => Some(audio_url.clone()),
                                _ => None,
                            })
                            .collect::<Vec<_>>();

                        let mut message_map: IndexMap<
                            String,
                            Either<String, Vec<IndexMap<String, Value>>>,
                        > = IndexMap::new();
                        message_map.insert("role".to_string(), Either::Left(message.role));

                        let mut content_map: Vec<IndexMap<String, Value>> = Vec::new();
                        for _ in &image_urls_iter {
                            let mut content_image_map = IndexMap::new();
                            content_image_map
                                .insert("type".to_string(), Value::String("image".to_string()));
                            content_map.push(content_image_map);
                        }
                        for _ in &audio_urls_iter {
                            let mut content_audio_map = IndexMap::new();
                            content_audio_map
                                .insert("type".to_string(), Value::String("audio".to_string()));
                            content_map.push(content_audio_map);
                        }
                        {
                            let mut content_text_map = IndexMap::new();
                            content_text_map
                                .insert("type".to_string(), Value::String("text".to_string()));
                            content_text_map
                                .insert("text".to_string(), Value::String(text_content));
                            content_map.push(content_text_map);
                        }

                        message_map.insert("content".to_string(), Either::Right(content_map));
                        messages.push(message_map);
                        image_urls.extend(image_urls_iter);
                        audio_urls.extend(audio_urls_iter);
                    }
                }
            }
            if !image_urls.is_empty() || !audio_urls.is_empty() {
                // Parse images
                let mut images = Vec::new();
                for url_unparsed in image_urls {
                    let image = parse_image_url(&url_unparsed)
                        .await
                        .context(format!("Failed to parse image resource: {url_unparsed}"))?;
                    images.push(image);
                }

                // Parse audios
                let mut audios = Vec::new();
                for url_unparsed in audio_urls {
                    let audio = parse_audio_url(&url_unparsed)
                        .await
                        .context(format!("Failed to parse audio resource: {url_unparsed}"))?;
                    audios.push(audio);
                }

                RequestMessage::VisionChat {
                    messages,
                    images,
                    audios,
                    enable_thinking: oairequest.enable_thinking,
                }
            } else {
                RequestMessage::Chat {
                    messages,
                    enable_thinking: oairequest.enable_thinking,
                }
            }
        }
        Either::Right(prompt) => {
            let mut messages = Vec::new();
            let mut message_map: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
                IndexMap::new();
            message_map.insert("role".to_string(), Either::Left("user".to_string()));
            message_map.insert("content".to_string(), Either::Left(prompt));
            messages.push(message_map);
            RequestMessage::Chat {
                messages,
                enable_thinking: oairequest.enable_thinking,
            }
        }
    };

    let dry_params = get_dry_sampling_params(
        oairequest.dry_multiplier,
        oairequest.dry_sequence_breakers,
        oairequest.dry_base,
        oairequest.dry_allowed_length,
    )?;

    let is_streaming = oairequest.stream.unwrap_or(false);

    if oairequest.grammar.is_some() && oairequest.response_format.is_some() {
        anyhow::bail!("Request `grammar` and `response_format` were both provided but are mutually exclusive.")
    }

    let constraint = match oairequest.grammar {
        Some(Grammar::Regex(regex)) => Constraint::Regex(regex),
        Some(Grammar::Lark(lark)) => Constraint::Lark(lark),
        Some(Grammar::JsonSchema(schema)) => Constraint::JsonSchema(schema),
        Some(Grammar::Llguidance(llguidance)) => Constraint::Llguidance(llguidance),
        None => match oairequest.response_format {
            Some(ResponseFormat::JsonSchema {
                json_schema: JsonSchemaResponseFormat { name: _, schema },
            }) => Constraint::JsonSchema(schema),
            Some(ResponseFormat::Text) => Constraint::None,
            None => Constraint::None,
        },
    };

    Ok((
        Request::Normal(Box::new(NormalRequest {
            id: state.next_request_id(),
            messages,
            sampling_params: SamplingParams {
                temperature: oairequest.temperature,
                top_k: oairequest.top_k,
                top_p: oairequest.top_p,
                min_p: oairequest.min_p,
                top_n_logprobs: oairequest.top_logprobs.unwrap_or(1),
                frequency_penalty: oairequest.frequency_penalty,
                presence_penalty: oairequest.presence_penalty,
                max_len: oairequest.max_tokens,
                stop_toks,
                logits_bias: oairequest.logit_bias,
                n_choices: oairequest.n_choices,
                dry_params,
            },
            response: tx,
            return_logprobs: oairequest.logprobs,
            is_streaming,
            suffix: None,
            constraint,
            tool_choice: oairequest.tool_choice,
            tools: oairequest.tools,
            logits_processors: None,
            return_raw_logits: false,
            web_search_options: oairequest.web_search_options,
            model_id: if oairequest.model == "default" {
                None
            } else {
                Some(oairequest.model.clone())
            },
        })),
        is_streaming,
    ))
}

/// OpenAI-compatible chat completions endpoint handler.
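///
/// ### Examples
///
/// A sketch of mounting this handler on an axum router (assumes `state` is an
/// already-constructed `SharedMistralRsState`):
///
/// ```ignore
/// use axum::{routing::post, Router};
/// use mistralrs_server_core::chat_completion::chatcompletions;
///
/// let app = Router::new()
///     .route("/v1/chat/completions", post(chatcompletions))
///     .with_state(state);
/// ```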
#[utoipa::path(
    post,
    tag = "Mistral.rs",
    path = "/v1/chat/completions",
    request_body = ChatCompletionRequest,
    responses((status = 200, description = "Chat completions"))
)]
pub async fn chatcompletions(
    State(state): ExtractedMistralRsState,
    Json(oairequest): Json<ChatCompletionRequest>,
) -> ChatCompletionResponder {
    let (tx, mut rx) = create_response_channel(None);

    // Extract model_id for routing before parsing
    let model_id = if oairequest.model == "default" {
        None
    } else {
        Some(oairequest.model.clone())
    };

    let (request, is_streaming) = match parse_request(oairequest, state.clone(), tx).await {
        Ok(x) => x,
        Err(e) => return handle_error(state, e.into()),
    };

    if let Err(e) = send_request_with_model(&state, request, model_id.as_deref()).await {
        return handle_error(state, e.into());
    }

    if is_streaming {
        ChatCompletionResponder::Sse(create_streamer(rx, state, None, None))
    } else {
        process_non_streaming_response(&mut rx, state).await
    }
}

/// Handles route / generation errors and logs them.
pub fn handle_error(
    state: SharedMistralRsState,
    e: Box<dyn std::error::Error + Send + Sync + 'static>,
) -> ChatCompletionResponder {
    handle_completion_error(state, e)
}

/// Creates an SSE streamer for chat completions with optional callbacks.
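///
/// ### Examples
///
/// A sketch of attaching an `on_chunk` hook (assumes `rx` and `state` are already
/// available, e.g. from `create_response_channel` and the shared server state):
///
/// ```ignore
/// let on_chunk: ChatCompletionOnChunkCallback = Box::new(|chunk| {
///     // Inspect (or modify) each chunk before it is sent to the client.
///     println!("chunk: {chunk:?}");
///     chunk
/// });
/// let sse = create_streamer(rx, state, Some(on_chunk), None);
/// ```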
pub fn create_streamer(
    rx: Receiver<Response>,
    state: SharedMistralRsState,
    on_chunk: Option<ChatCompletionOnChunkCallback>,
    on_done: Option<ChatCompletionOnDoneCallback>,
) -> Sse<ChatCompletionStreamer> {
    let streamer = base_create_streamer(rx, state, on_chunk, on_done);
    let keep_alive_interval = get_keep_alive_interval();

    Sse::new(streamer)
        .keep_alive(KeepAlive::new().interval(Duration::from_millis(keep_alive_interval)))
}

/// Processes non-streaming chat completion responses.
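///
/// ### Examples
///
/// A sketch of handling a non-streaming request (assumes the request has already been
/// sent and `rx` is the receiving half of the channel passed to `parse_request`):
///
/// ```ignore
/// let responder = process_non_streaming_response(&mut rx, state).await;
/// ```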
pub async fn process_non_streaming_response(
    rx: &mut Receiver<Response>,
    state: SharedMistralRsState,
) -> ChatCompletionResponder {
    base_process_non_streaming_response(rx, state, match_responses, handle_error).await
}

/// Matches and processes different types of model responses into appropriate chat completion responses.
pub fn match_responses(state: SharedMistralRsState, response: Response) -> ChatCompletionResponder {
    match response {
        Response::InternalError(e) => {
            MistralRs::maybe_log_error(state, &*e);
            ChatCompletionResponder::InternalError(e)
        }
        Response::ModelError(msg, response) => {
            MistralRs::maybe_log_error(state.clone(), &ModelErrorMessage(msg.to_string()));
            MistralRs::maybe_log_response(state, &response);
            ChatCompletionResponder::ModelError(msg, response)
        }
        Response::ValidationError(e) => ChatCompletionResponder::ValidationError(e),
        Response::Done(response) => {
            MistralRs::maybe_log_response(state, &response);
            ChatCompletionResponder::Json(response)
        }
        Response::Chunk(_) => unreachable!(),
        Response::CompletionDone(_) => unreachable!(),
        Response::CompletionModelError(_, _) => unreachable!(),
        Response::CompletionChunk(_) => unreachable!(),
        Response::ImageGeneration(_) => unreachable!(),
        Response::Speech { .. } => unreachable!(),
        Response::Raw { .. } => unreachable!(),
    }
}