mistralrs_server_core/
chat_completion.rs

1//! ## Chat Completions functionality and route handler.
2
3use std::{ops::Deref, pin::Pin, task::Poll, time::Duration};
4
5use anyhow::{Context, Result};
6use axum::{
7    extract::{Json, State},
8    http::{self},
9    response::{
10        sse::{Event, KeepAlive, KeepAliveStream},
11        IntoResponse, Sse,
12    },
13};
14use either::Either;
15use indexmap::IndexMap;
16use itertools::Itertools;
17use mistralrs_core::{
18    ChatCompletionChunkResponse, ChatCompletionResponse, Constraint, MistralRs, NormalRequest,
19    ReasoningEffort, Request, RequestMessage, Response, SamplingParams,
20};
21use serde_json::Value;
22use tokio::sync::mpsc::{Receiver, Sender};
23
24use crate::{
25    completion_core::{
26        convert_stop_tokens, get_dry_sampling_params, handle_completion_error,
27        BaseCompletionResponder,
28    },
29    handler_core::{
30        base_process_non_streaming_response, create_response_channel, send_request_with_model,
31        BaseJsonModelError, ErrorToResponse, JsonError, ModelErrorMessage,
32    },
33    openai::{
34        ChatCompletionRequest, Grammar, JsonSchemaResponseFormat, MessageInnerContent,
35        ResponseFormat,
36    },
37    streaming::{base_create_streamer, get_keep_alive_interval, BaseStreamer, DoneState},
38    types::{ExtractedMistralRsState, OnChunkCallback, OnDoneCallback, SharedMistralRsState},
39    util::{parse_audio_url, parse_image_url, sanitize_error_message, validate_model_name},
40};
41
/// A callback function that processes streaming response chunks before they are sent to the client.
///
/// This hook allows modification of each chunk in the streaming response, enabling features like
/// content filtering, transformation, or logging. The callback receives a chunk by value and must
/// return a (potentially modified) chunk, which is what gets serialized into the SSE event.
///
/// Ordering within the streamer: the chunk is first passed to the request/response logger, then
/// to this callback, and only the callback's output is stored for the done callback (when chunk
/// storage is enabled) and sent to the client.
///
/// ### Examples
///
/// ```no_run
/// use mistralrs_server_core::chat_completion::ChatCompletionOnChunkCallback;
///
/// let on_chunk: ChatCompletionOnChunkCallback = Box::new(|mut chunk| {
///     // Log the chunk or modify its content
///     println!("Processing chunk: {:?}", chunk);
///     chunk
/// });
/// ```
pub type ChatCompletionOnChunkCallback = OnChunkCallback<ChatCompletionChunkResponse>;
60
/// A callback function that is executed when the streaming response completes.
///
/// This hook receives a slice of the chunks that were streamed during the response, allowing for
/// post-processing, analytics, or cleanup operations after the stream finishes. It is invoked
/// after the final `[DONE]` event has been emitted, just before the stream ends.
///
/// Chunks are only collected while the streamer's `store_chunks` flag is set; otherwise the
/// callback receives whatever (possibly empty) collection the streamer holds.
///
/// ### Examples
///
/// ```no_run
/// use mistralrs_server_core::chat_completion::ChatCompletionOnDoneCallback;
///
/// let on_done: ChatCompletionOnDoneCallback = Box::new(|chunks| {
///     println!("Stream completed with {} chunks", chunks.len());
///     // Process all chunks for analytics
/// });
/// ```
pub type ChatCompletionOnDoneCallback = OnDoneCallback<ChatCompletionChunkResponse>;
77
/// A streaming response handler for chat completions.
///
/// It receives [`Response`] values from the model's response channel and converts them
/// into Server-Sent Events (SSE) for real-time streaming to clients, applying the optional
/// per-chunk and completion callbacks along the way. Its event production is implemented via
/// `futures::Stream` in this module.
pub type ChatCompletionStreamer = BaseStreamer<
    ChatCompletionChunkResponse,
    ChatCompletionOnChunkCallback,
    ChatCompletionOnDoneCallback,
>;
87
88impl futures::Stream for ChatCompletionStreamer {
89    type Item = Result<Event, axum::Error>;
90
91    /// Polls the stream for the next Server-Sent Event.
92    ///
93    /// This method implements the core streaming logic:
94    /// 1. Handles stream completion by sending `[DONE]` and executing callbacks
95    /// 2. Processes incoming model responses and converts them to SSE events
96    /// 3. Applies chunk modifications if a callback is provided
97    /// 4. Stores chunks if completion callback is configured
98    fn poll_next(
99        mut self: Pin<&mut Self>,
100        cx: &mut std::task::Context<'_>,
101    ) -> Poll<Option<Self::Item>> {
102        match self.done_state {
103            DoneState::SendingDone => {
104                // https://platform.openai.com/docs/api-reference/completions/create
105                // If true, returns a stream of events that happen during the Run as server-sent events, terminating when the Run enters a terminal state with a data: [DONE] message.
106                self.done_state = DoneState::Done;
107                return Poll::Ready(Some(Ok(Event::default().data("[DONE]"))));
108            }
109            DoneState::Done => {
110                if let Some(on_done) = &self.on_done {
111                    on_done(&self.chunks);
112                }
113                return Poll::Ready(None);
114            }
115            DoneState::Running => (),
116        }
117
118        match self.rx.poll_recv(cx) {
119            Poll::Ready(Some(resp)) => match resp {
120                Response::ModelError(msg, _) => {
121                    MistralRs::maybe_log_error(
122                        self.state.clone(),
123                        &ModelErrorMessage(msg.to_string()),
124                    );
125                    // Done now, just need to send the [DONE]
126                    self.done_state = DoneState::SendingDone;
127                    Poll::Ready(Some(Ok(Event::default().data(msg))))
128                }
129                Response::ValidationError(e) => Poll::Ready(Some(Ok(
130                    Event::default().data(sanitize_error_message(e.as_ref()))
131                ))),
132                Response::InternalError(e) => {
133                    MistralRs::maybe_log_error(self.state.clone(), &*e);
134                    Poll::Ready(Some(Ok(
135                        Event::default().data(sanitize_error_message(e.as_ref()))
136                    )))
137                }
138                Response::Chunk(mut response) => {
139                    if response.choices.iter().all(|x| x.finish_reason.is_some()) {
140                        self.done_state = DoneState::SendingDone;
141                    }
142                    // Done now, just need to send the [DONE]
143                    MistralRs::maybe_log_response(self.state.clone(), &response);
144
145                    if let Some(on_chunk) = &self.on_chunk {
146                        response = on_chunk(response);
147                    }
148
149                    if self.store_chunks {
150                        self.chunks.push(response.clone());
151                    }
152
153                    Poll::Ready(Some(Event::default().json_data(response)))
154                }
155                Response::Done(_) => unreachable!(),
156                Response::CompletionDone(_) => unreachable!(),
157                Response::CompletionModelError(_, _) => unreachable!(),
158                Response::CompletionChunk(_) => unreachable!(),
159                Response::ImageGeneration(_) => unreachable!(),
160                Response::Speech { .. } => unreachable!(),
161                Response::Raw { .. } => unreachable!(),
162                Response::Embeddings { .. } => unreachable!(),
163            },
164            Poll::Pending | Poll::Ready(None) => Poll::Pending,
165        }
166    }
167}
168
/// Represents the different kinds of chat completion responses the route can produce:
/// an SSE stream (wrapped with keep-alive), a JSON [`ChatCompletionResponse`], or one of the
/// internal / validation / model error variants.
pub type ChatCompletionResponder =
    BaseCompletionResponder<ChatCompletionResponse, KeepAliveStream<ChatCompletionStreamer>>;

// JSON error body for model errors, carrying the error message together with the
// `ChatCompletionResponse` attached to the error.
type JsonModelError = BaseJsonModelError<ChatCompletionResponse>;
// Empty impl: relies entirely on the trait's provided methods (e.g. `to_response`).
impl ErrorToResponse for JsonModelError {}
175
176impl IntoResponse for ChatCompletionResponder {
177    /// Converts the chat completion responder into an HTTP response.
178    fn into_response(self) -> axum::response::Response {
179        match self {
180            ChatCompletionResponder::Sse(s) => s.into_response(),
181            ChatCompletionResponder::Json(s) => Json(s).into_response(),
182            ChatCompletionResponder::InternalError(e) => {
183                JsonError::new(sanitize_error_message(e.as_ref()))
184                    .to_response(http::StatusCode::INTERNAL_SERVER_ERROR)
185            }
186            ChatCompletionResponder::ValidationError(e) => {
187                JsonError::new(sanitize_error_message(e.as_ref()))
188                    .to_response(http::StatusCode::UNPROCESSABLE_ENTITY)
189            }
190            ChatCompletionResponder::ModelError(msg, response) => {
191                JsonModelError::new(msg, response)
192                    .to_response(http::StatusCode::INTERNAL_SERVER_ERROR)
193            }
194        }
195    }
196}
197
198/// Parse reasoning_effort string to ReasoningEffort enum
199fn parse_reasoning_effort(effort: &Option<String>) -> Option<ReasoningEffort> {
200    effort
201        .as_ref()
202        .and_then(|e| match e.to_lowercase().as_str() {
203            "low" => Some(ReasoningEffort::Low),
204            "medium" => Some(ReasoningEffort::Medium),
205            "high" => Some(ReasoningEffort::High),
206            _ => None,
207        })
208}
209
210/// Parses and validates a chat completion request.
211///
212/// This function transforms an OpenAI-compatible chat completion request into the
213/// request format used by mistral.rs.
214pub async fn parse_request(
215    oairequest: ChatCompletionRequest,
216    state: SharedMistralRsState,
217    tx: Sender<Response>,
218) -> Result<(Request, bool)> {
219    let repr = serde_json::to_string(&oairequest).expect("Serialization of request failed.");
220    MistralRs::maybe_log_request(state.clone(), repr);
221
222    // Validate that the requested model matches the loaded model
223    validate_model_name(&oairequest.model, state.clone())?;
224
225    // Parse reasoning effort for Harmony-format models
226    let reasoning_effort = parse_reasoning_effort(&oairequest.reasoning_effort);
227
228    let stop_toks = convert_stop_tokens(oairequest.stop_seqs);
229
230    let messages = match oairequest.messages {
231        Either::Left(req_messages) => {
232            let mut messages = Vec::new();
233            let mut image_urls = Vec::new();
234            let mut audio_urls = Vec::new();
235            for message in req_messages {
236                let content = match message.content.as_deref() {
237                    Some(content) => content.clone(),
238                    None => {
239                        // Handle tool call
240                        let calls = message
241                            .tool_calls
242                            .as_ref()
243                            .context(
244                                "No content was provided, expected tool calls to be provided.",
245                            )?
246                            .iter()
247                            .map(|call| &call.function)
248                            .collect::<Vec<_>>();
249
250                        Either::Left(serde_json::to_string(&calls)?)
251                    }
252                };
253
254                match &content {
255                    Either::Left(content) => {
256                        let mut message_map: IndexMap<
257                            String,
258                            Either<String, Vec<IndexMap<String, Value>>>,
259                        > = IndexMap::new();
260                        message_map.insert("role".to_string(), Either::Left(message.role.clone()));
261                        message_map.insert("content".to_string(), Either::Left(content.clone()));
262
263                        // Add tool_calls for assistant messages that have them
264                        if let Some(ref tool_calls) = message.tool_calls {
265                            // Convert tool_calls to Vec<IndexMap<String, Value>> for Jinja template
266                            let tool_calls_vec: Vec<IndexMap<String, Value>> = tool_calls
267                                .iter()
268                                .map(|tc| {
269                                    let mut tc_map = IndexMap::new();
270                                    // Use provided ID or fallback to function name
271                                    let id =
272                                        tc.id.clone().unwrap_or_else(|| tc.function.name.clone());
273                                    tc_map.insert("id".to_string(), Value::String(id));
274                                    tc_map.insert(
275                                        "type".to_string(),
276                                        Value::String("function".to_string()),
277                                    );
278                                    let mut function_map = serde_json::Map::new();
279                                    function_map.insert(
280                                        "name".to_string(),
281                                        Value::String(tc.function.name.clone()),
282                                    );
283                                    function_map.insert(
284                                        "arguments".to_string(),
285                                        Value::String(tc.function.arguments.clone()),
286                                    );
287                                    tc_map.insert(
288                                        "function".to_string(),
289                                        Value::Object(function_map),
290                                    );
291                                    tc_map
292                                })
293                                .collect();
294                            message_map
295                                .insert("tool_calls".to_string(), Either::Right(tool_calls_vec));
296                        }
297
298                        // Add tool_call_id for tool messages
299                        if let Some(ref tool_call_id) = message.tool_call_id {
300                            message_map.insert(
301                                "tool_call_id".to_string(),
302                                Either::Left(tool_call_id.clone()),
303                            );
304                        }
305
306                        // Add name for tool messages
307                        if let Some(ref name) = message.name {
308                            message_map.insert("name".to_string(), Either::Left(name.clone()));
309                        }
310
311                        messages.push(message_map);
312                    }
313                    Either::Right(image_messages) => {
314                        // If there is only one message, it is possible a text message
315                        // found when rig is used as client. In this case, we need to check if
316                        // the message is a text message or an image message.
317                        if image_messages.len() == 1 {
318                            if !image_messages[0].contains_key("text") {
319                                anyhow::bail!("Expected `text` key in input message.");
320                            }
321                            let content = match image_messages[0]["text"].deref() {
322                                Either::Left(left) => left.to_string(),
323                                Either::Right(right) => format!("{right:?}"),
324                            };
325                            let mut message_map: IndexMap<
326                                String,
327                                Either<String, Vec<IndexMap<String, Value>>>,
328                            > = IndexMap::new();
329                            message_map.insert("role".to_string(), Either::Left(message.role));
330                            message_map.insert("content".to_string(), Either::Left(content));
331                            messages.push(message_map);
332                            continue;
333                        }
334                        if message.role != "user" {
335                            anyhow::bail!(
336                                "Role for an image message must be `user`, but it is {}",
337                                message.role
338                            );
339                        }
340
341                        enum ContentPart {
342                            Text { text: String },
343                            Image { image_url: String },
344                            Audio { audio_url: String },
345                        }
346
347                        let mut items = Vec::new();
348                        for image_message in image_messages {
349                            match image_message.get("type") {
350                                Some(MessageInnerContent(Either::Left(x))) if x == "text" => {
351                                    items.push(ContentPart::Text {
352                                        text: image_message
353                                            .get("text").as_ref()
354                                            .context("Text sub-content must have `text` key.")?.as_ref()
355                                            .left().context("Text sub-content `text` key must be a string.")?.clone(),
356                                    });
357                                }
358                                Some(MessageInnerContent(Either::Left(x))) if x == "image_url" => {
359                                    items.push(ContentPart::Image {
360                                        image_url: image_message
361                                            .get("image_url")
362                                            .as_ref()
363                                            .context("Image sub-content must have `image_url` key.")?
364                                            .as_ref()
365                                            .right()
366                                            .context("Image sub-content `image_url` key must be an object.")?
367                                            .get("url")
368                                            .context("Image sub-content `image_url` object must have a `url` key.")?
369                                            .clone(),
370                                    });
371                                }
372                                Some(MessageInnerContent(Either::Left(x))) if x == "audio_url" => {
373                                    items.push(ContentPart::Audio {
374                                        audio_url: image_message
375                                            .get("audio_url")
376                                            .as_ref()
377                                            .context("Audio sub-content must have `audio_url` key.")?
378                                            .as_ref()
379                                            .right()
380                                            .context("Audio sub-content `audio_url` key must be an object.")?
381                                            .get("url")
382                                            .context("Audio sub-content `audio_url` object must have a `url` key.")?
383                                            .clone(),
384                                    });
385                                }
386                                _ => anyhow::bail!("Expected array content sub-content to be of format {{`type`: `text`, `text`: ...}} and {{`type`: `url`, `image_url`: {{`url`: ...}}}}")
387                            }
388                        }
389
390                        let text_content = items
391                            .iter()
392                            .filter_map(|item| match item {
393                                ContentPart::Text { text } => Some(text),
394                                _ => None,
395                            })
396                            .join(" ");
397                        let image_urls_iter = items
398                            .iter()
399                            .filter_map(|item| match item {
400                                ContentPart::Image { image_url } => Some(image_url.clone()),
401                                _ => None,
402                            })
403                            .collect::<Vec<_>>();
404
405                        let audio_urls_iter = items
406                            .iter()
407                            .filter_map(|item| match item {
408                                ContentPart::Audio { audio_url } => Some(audio_url.clone()),
409                                _ => None,
410                            })
411                            .collect::<Vec<_>>();
412
413                        let mut message_map: IndexMap<
414                            String,
415                            Either<String, Vec<IndexMap<String, Value>>>,
416                        > = IndexMap::new();
417                        message_map.insert("role".to_string(), Either::Left(message.role));
418
419                        let mut content_map: Vec<IndexMap<String, Value>> = Vec::new();
420                        for _ in &image_urls_iter {
421                            let mut content_image_map = IndexMap::new();
422                            content_image_map
423                                .insert("type".to_string(), Value::String("image".to_string()));
424                            content_map.push(content_image_map);
425                        }
426                        for _ in &audio_urls_iter {
427                            let mut content_audio_map = IndexMap::new();
428                            content_audio_map
429                                .insert("type".to_string(), Value::String("audio".to_string()));
430                            content_map.push(content_audio_map);
431                        }
432                        {
433                            let mut content_text_map = IndexMap::new();
434                            content_text_map
435                                .insert("type".to_string(), Value::String("text".to_string()));
436                            content_text_map
437                                .insert("text".to_string(), Value::String(text_content));
438                            content_map.push(content_text_map);
439                        }
440
441                        message_map.insert("content".to_string(), Either::Right(content_map));
442                        messages.push(message_map);
443                        image_urls.extend(image_urls_iter);
444                        audio_urls.extend(audio_urls_iter);
445                    }
446                }
447            }
448            if !image_urls.is_empty() || !audio_urls.is_empty() {
449                // Parse images
450                let mut images = Vec::new();
451                for url_unparsed in image_urls {
452                    let image = parse_image_url(&url_unparsed)
453                        .await
454                        .context(format!("Failed to parse image resource: {url_unparsed}"))?;
455                    images.push(image);
456                }
457
458                // Parse audios
459                let mut audios = Vec::new();
460                for url_unparsed in audio_urls {
461                    let audio = parse_audio_url(&url_unparsed)
462                        .await
463                        .context(format!("Failed to parse audio resource: {url_unparsed}"))?;
464                    audios.push(audio);
465                }
466
467                RequestMessage::VisionChat {
468                    messages,
469                    images,
470                    audios,
471                    enable_thinking: oairequest.enable_thinking,
472                    reasoning_effort,
473                }
474            } else {
475                RequestMessage::Chat {
476                    messages,
477                    enable_thinking: oairequest.enable_thinking,
478                    reasoning_effort,
479                }
480            }
481        }
482        Either::Right(prompt) => {
483            let mut messages = Vec::new();
484            let mut message_map: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
485                IndexMap::new();
486            message_map.insert("role".to_string(), Either::Left("user".to_string()));
487            message_map.insert("content".to_string(), Either::Left(prompt));
488            messages.push(message_map);
489            RequestMessage::Chat {
490                messages,
491                enable_thinking: oairequest.enable_thinking,
492                reasoning_effort,
493            }
494        }
495    };
496
497    let dry_params = get_dry_sampling_params(
498        oairequest.dry_multiplier,
499        oairequest.dry_sequence_breakers,
500        oairequest.dry_base,
501        oairequest.dry_allowed_length,
502    )?;
503
504    let is_streaming = oairequest.stream.unwrap_or(false);
505
506    if oairequest.grammar.is_some() && oairequest.response_format.is_some() {
507        anyhow::bail!("Request `grammar` and `response_format` were both provided but are mutually exclusive.")
508    }
509
510    let constraint = match oairequest.grammar {
511        Some(Grammar::Regex(regex)) => Constraint::Regex(regex),
512        Some(Grammar::Lark(lark)) => Constraint::Lark(lark),
513        Some(Grammar::JsonSchema(schema)) => Constraint::JsonSchema(schema),
514        Some(Grammar::Llguidance(llguidance)) => Constraint::Llguidance(llguidance),
515        None => match oairequest.response_format {
516            Some(ResponseFormat::JsonSchema {
517                json_schema: JsonSchemaResponseFormat { name: _, schema },
518            }) => Constraint::JsonSchema(schema),
519            Some(ResponseFormat::Text) => Constraint::None,
520            None => Constraint::None,
521        },
522    };
523
524    Ok((
525        Request::Normal(Box::new(NormalRequest {
526            id: state.next_request_id(),
527            messages,
528            sampling_params: SamplingParams {
529                temperature: oairequest.temperature,
530                top_k: oairequest.top_k,
531                top_p: oairequest.top_p,
532                min_p: oairequest.min_p,
533                top_n_logprobs: oairequest.top_logprobs.unwrap_or(1),
534                frequency_penalty: oairequest.frequency_penalty,
535                presence_penalty: oairequest.presence_penalty,
536                repetition_penalty: oairequest.repetition_penalty,
537                max_len: oairequest.max_tokens,
538                stop_toks,
539                logits_bias: oairequest.logit_bias,
540                n_choices: oairequest.n_choices,
541                dry_params,
542            },
543            response: tx,
544            return_logprobs: oairequest.logprobs,
545            is_streaming,
546            suffix: None,
547            constraint,
548            tool_choice: oairequest.tool_choice,
549            tools: oairequest.tools,
550            logits_processors: None,
551            return_raw_logits: false,
552            web_search_options: oairequest.web_search_options,
553            model_id: if oairequest.model == "default" {
554                None
555            } else {
556                Some(oairequest.model.clone())
557            },
558            truncate_sequence: oairequest.truncate_sequence.unwrap_or(false),
559        })),
560        is_streaming,
561    ))
562}
563
564/// OpenAI-compatible chat completions endpoint handler.
565#[utoipa::path(
566    post,
567    tag = "Mistral.rs",
568    path = "/v1/chat/completions",
569    request_body = ChatCompletionRequest,
570    responses((status = 200, description = "Chat completions"))
571)]
572pub async fn chatcompletions(
573    State(state): ExtractedMistralRsState,
574    Json(oairequest): Json<ChatCompletionRequest>,
575) -> ChatCompletionResponder {
576    let (tx, mut rx) = create_response_channel(None);
577
578    // Extract model_id for routing before parsing
579    let model_id = if oairequest.model == "default" {
580        None
581    } else {
582        Some(oairequest.model.clone())
583    };
584
585    let (request, is_streaming) = match parse_request(oairequest, state.clone(), tx).await {
586        Ok(x) => x,
587        Err(e) => return handle_error(state, e.into()),
588    };
589
590    if let Err(e) = send_request_with_model(&state, request, model_id.as_deref()).await {
591        return handle_error(state, e.into());
592    }
593
594    if is_streaming {
595        ChatCompletionResponder::Sse(create_streamer(rx, state, None, None))
596    } else {
597        process_non_streaming_response(&mut rx, state).await
598    }
599}
600
/// Converts a route or generation error into a [`ChatCompletionResponder`].
///
/// Thin wrapper over the shared `handle_completion_error` helper, specialized to the chat
/// completion responder type. Logging and the choice of error variant happen in that helper
/// (NOTE(review): presumably logged and mapped to an internal-error response — confirm in
/// `completion_core`).
pub fn handle_error(
    state: SharedMistralRsState,
    e: Box<dyn std::error::Error + Send + Sync + 'static>,
) -> ChatCompletionResponder {
    handle_completion_error(state, e)
}
608
609/// Creates a SSE streamer for chat completions with optional callbacks.
610pub fn create_streamer(
611    rx: Receiver<Response>,
612    state: SharedMistralRsState,
613    on_chunk: Option<ChatCompletionOnChunkCallback>,
614    on_done: Option<ChatCompletionOnDoneCallback>,
615) -> Sse<KeepAliveStream<ChatCompletionStreamer>> {
616    let streamer = base_create_streamer(rx, state, on_chunk, on_done);
617    let keep_alive_interval = get_keep_alive_interval();
618
619    Sse::new(streamer)
620        .keep_alive(KeepAlive::new().interval(Duration::from_millis(keep_alive_interval)))
621}
622
/// Process non-streaming chat completion responses.
///
/// Delegates to the shared `base_process_non_streaming_response`, passing
/// [`match_responses`] to convert received [`Response`] values and [`handle_error`] as
/// the error handler (NOTE(review): exact receive/timeout semantics live in `handler_core`
/// — confirm there).
pub async fn process_non_streaming_response(
    rx: &mut Receiver<Response>,
    state: SharedMistralRsState,
) -> ChatCompletionResponder {
    base_process_non_streaming_response(rx, state, match_responses, handle_error).await
}
630
631/// Matches and processes different types of model responses into appropriate chat completion responses.
632pub fn match_responses(state: SharedMistralRsState, response: Response) -> ChatCompletionResponder {
633    match response {
634        Response::InternalError(e) => {
635            MistralRs::maybe_log_error(state, &*e);
636            ChatCompletionResponder::InternalError(e)
637        }
638        Response::ModelError(msg, response) => {
639            MistralRs::maybe_log_error(state.clone(), &ModelErrorMessage(msg.to_string()));
640            MistralRs::maybe_log_response(state, &response);
641            ChatCompletionResponder::ModelError(msg, response)
642        }
643        Response::ValidationError(e) => ChatCompletionResponder::ValidationError(e),
644        Response::Done(response) => {
645            MistralRs::maybe_log_response(state, &response);
646            ChatCompletionResponder::Json(response)
647        }
648        Response::Chunk(_) => unreachable!(),
649        Response::CompletionDone(_) => unreachable!(),
650        Response::CompletionModelError(_, _) => unreachable!(),
651        Response::CompletionChunk(_) => unreachable!(),
652        Response::ImageGeneration(_) => unreachable!(),
653        Response::Speech { .. } => unreachable!(),
654        Response::Raw { .. } => unreachable!(),
655        Response::Embeddings { .. } => unreachable!(),
656    }
657}