mistralrs_server/
chat_completion.rs

1use serde_json::Value;
2use std::{env, error::Error, ops::Deref, pin::Pin, sync::Arc, task::Poll, time::Duration};
3use tokio::sync::mpsc::{channel, Receiver, Sender};
4
5use crate::{
6    openai::{
7        ChatCompletionRequest, Grammar, JsonSchemaResponseFormat, MessageInnerContent,
8        ResponseFormat, StopTokens,
9    },
10    util,
11};
12use anyhow::Context;
13use anyhow::Result;
14use axum::{
15    extract::{Json, State},
16    http::{self, StatusCode},
17    response::{
18        sse::{Event, KeepAlive},
19        IntoResponse, Sse,
20    },
21};
22use either::Either;
23use indexmap::IndexMap;
24use itertools::Itertools;
25use mistralrs_core::{
26    ChatCompletionResponse, Constraint, DrySamplingParams, MistralRs, NormalRequest, Request,
27    RequestMessage, Response, SamplingParams, StopTokens as InternalStopTokens,
28};
29use serde::Serialize;
30
/// Wraps a model error message string so it can be handed to APIs that expect a
/// `std::error::Error` (here: the error-logging hook).
#[derive(Debug)]
struct ModelErrorMessage(String);

impl std::error::Error for ModelErrorMessage {}

impl std::fmt::Display for ModelErrorMessage {
    /// Writes the wrapped message verbatim.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(message) = self;
        f.write_str(message)
    }
}
39
/// Shutdown progress of a [`Streamer`].
enum DoneState {
    /// Responses are still being forwarded from the channel.
    Running,
    /// The last payload has been emitted; the next poll emits the `[DONE]` event.
    SendingDone,
    /// `[DONE]` has been emitted; the next poll ends the stream.
    Done,
}
45
46pub struct Streamer {
47    rx: Receiver<Response>,
48    done_state: DoneState,
49    state: Arc<MistralRs>,
50}
51
52impl futures::Stream for Streamer {
53    type Item = Result<Event, axum::Error>;
54
55    fn poll_next(
56        mut self: Pin<&mut Self>,
57        cx: &mut std::task::Context<'_>,
58    ) -> Poll<Option<Self::Item>> {
59        match self.done_state {
60            DoneState::SendingDone => {
61                // https://platform.openai.com/docs/api-reference/completions/create
62                // If true, returns a stream of events that happen during the Run as server-sent events, terminating when the Run enters a terminal state with a data: [DONE] message.
63                self.done_state = DoneState::Done;
64                return Poll::Ready(Some(Ok(Event::default().data("[DONE]"))));
65            }
66            DoneState::Done => {
67                return Poll::Ready(None);
68            }
69            DoneState::Running => (),
70        }
71
72        match self.rx.poll_recv(cx) {
73            Poll::Ready(Some(resp)) => match resp {
74                Response::ModelError(msg, _) => {
75                    MistralRs::maybe_log_error(
76                        self.state.clone(),
77                        &ModelErrorMessage(msg.to_string()),
78                    );
79                    // Done now, just need to send the [DONE]
80                    self.done_state = DoneState::SendingDone;
81                    Poll::Ready(Some(Ok(Event::default().data(msg))))
82                }
83                Response::ValidationError(e) => {
84                    Poll::Ready(Some(Ok(Event::default().data(e.to_string()))))
85                }
86                Response::InternalError(e) => {
87                    MistralRs::maybe_log_error(self.state.clone(), &*e);
88                    Poll::Ready(Some(Ok(Event::default().data(e.to_string()))))
89                }
90                Response::Chunk(response) => {
91                    if response.choices.iter().all(|x| x.finish_reason.is_some()) {
92                        self.done_state = DoneState::SendingDone;
93                    }
94                    // Done now, just need to send the [DONE]
95                    MistralRs::maybe_log_response(self.state.clone(), &response);
96                    Poll::Ready(Some(Event::default().json_data(response)))
97                }
98                Response::Done(_) => unreachable!(),
99                Response::CompletionDone(_) => unreachable!(),
100                Response::CompletionModelError(_, _) => unreachable!(),
101                Response::CompletionChunk(_) => unreachable!(),
102                Response::ImageGeneration(_) => unreachable!(),
103                Response::Raw { .. } => unreachable!(),
104            },
105            Poll::Pending | Poll::Ready(None) => Poll::Pending,
106        }
107    }
108}
109
110pub enum ChatCompletionResponder {
111    Sse(Sse<Streamer>),
112    Json(ChatCompletionResponse),
113    ModelError(String, ChatCompletionResponse),
114    InternalError(Box<dyn Error>),
115    ValidationError(Box<dyn Error>),
116}
117
118trait ErrorToResponse: Serialize {
119    fn to_response(&self, code: StatusCode) -> axum::response::Response {
120        let mut r = Json(self).into_response();
121        *r.status_mut() = code;
122        r
123    }
124}
125
126#[derive(Serialize)]
127struct JsonError {
128    message: String,
129}
130
131impl JsonError {
132    fn new(message: String) -> Self {
133        Self { message }
134    }
135}
136impl ErrorToResponse for JsonError {}
137
138#[derive(Serialize)]
139struct JsonModelError {
140    message: String,
141    partial_response: ChatCompletionResponse,
142}
143
144impl JsonModelError {
145    fn new(message: String, partial_response: ChatCompletionResponse) -> Self {
146        Self {
147            message,
148            partial_response,
149        }
150    }
151}
152
153impl ErrorToResponse for JsonModelError {}
154
155impl IntoResponse for ChatCompletionResponder {
156    fn into_response(self) -> axum::response::Response {
157        match self {
158            ChatCompletionResponder::Sse(s) => s.into_response(),
159            ChatCompletionResponder::Json(s) => Json(s).into_response(),
160            ChatCompletionResponder::InternalError(e) => {
161                JsonError::new(e.to_string()).to_response(http::StatusCode::INTERNAL_SERVER_ERROR)
162            }
163            ChatCompletionResponder::ValidationError(e) => {
164                JsonError::new(e.to_string()).to_response(http::StatusCode::UNPROCESSABLE_ENTITY)
165            }
166            ChatCompletionResponder::ModelError(msg, response) => {
167                JsonModelError::new(msg, response)
168                    .to_response(http::StatusCode::INTERNAL_SERVER_ERROR)
169            }
170        }
171    }
172}
173
174async fn parse_request(
175    oairequest: ChatCompletionRequest,
176    state: Arc<MistralRs>,
177    tx: Sender<Response>,
178) -> Result<(Request, bool)> {
179    let repr = serde_json::to_string(&oairequest).expect("Serialization of request failed.");
180    MistralRs::maybe_log_request(state.clone(), repr);
181
182    let stop_toks = match oairequest.stop_seqs {
183        Some(StopTokens::Multi(m)) => Some(InternalStopTokens::Seqs(m)),
184        Some(StopTokens::Single(s)) => Some(InternalStopTokens::Seqs(vec![s])),
185        None => None,
186    };
187    let messages = match oairequest.messages {
188        Either::Left(req_messages) => {
189            let mut messages = Vec::new();
190            let mut image_urls = Vec::new();
191            for message in req_messages {
192                let content = match message.content.as_deref() {
193                    Some(content) => content.clone(),
194                    None => {
195                        // Handle tool call
196                        let calls = message
197                            .tool_calls
198                            .as_ref()
199                            .context(
200                                "No content was provided, expected tool calls to be provided.",
201                            )?
202                            .iter()
203                            .map(|call| &call.function)
204                            .collect::<Vec<_>>();
205
206                        Either::Left(serde_json::to_string(&calls)?)
207                    }
208                };
209
210                match &content {
211                    Either::Left(content) => {
212                        let mut message_map: IndexMap<
213                            String,
214                            Either<String, Vec<IndexMap<String, Value>>>,
215                        > = IndexMap::new();
216                        message_map.insert("role".to_string(), Either::Left(message.role));
217                        message_map.insert("content".to_string(), Either::Left(content.clone()));
218                        messages.push(message_map);
219                    }
220                    Either::Right(image_messages) => {
221                        // If there is only one message, it is possible a text message
222                        // found when rig is used as client. In this case, we need to check if
223                        // the message is a text message or an image message.
224                        if image_messages.len() == 1 {
225                            if !image_messages[0].contains_key("text") {
226                                anyhow::bail!("Expected `text` key in input message.");
227                            }
228                            let content = match image_messages[0]["text"].deref() {
229                                Either::Left(left) => left.to_string(),
230                                Either::Right(right) => format!("{:?}", right),
231                            };
232                            let mut message_map: IndexMap<
233                                String,
234                                Either<String, Vec<IndexMap<String, Value>>>,
235                            > = IndexMap::new();
236                            message_map.insert("role".to_string(), Either::Left(message.role));
237                            message_map.insert("content".to_string(), Either::Left(content));
238                            messages.push(message_map);
239                            continue;
240                        }
241                        if message.role != "user" {
242                            anyhow::bail!(
243                                "Role for an image message must be `user`, but it is {}",
244                                message.role
245                            );
246                        }
247
248                        enum ContentPart {
249                            Text { text: String },
250                            Image { image_url: String },
251                        }
252
253                        let mut items = Vec::new();
254                        for image_message in image_messages {
255                            match image_message.get("type") {
256                                Some(MessageInnerContent(Either::Left(x))) if x == "text" => {
257                                    items.push(ContentPart::Text {
258                                        text: image_message
259                                            .get("text").as_ref()
260                                            .context("Text sub-content must have `text` key.")?.as_ref()
261                                            .left().context("Text sub-content `text` key must be a string.")?.clone(),
262                                    });
263                                }
264                                Some(MessageInnerContent(Either::Left(x))) if x == "image_url" => {
265                                    items.push(ContentPart::Image {
266                                        image_url: image_message.get("image_url").as_ref()
267                                            .context("Image sub-content must have `image_url` key.")?.as_ref()
268                                            .right()
269                                            .context("Image sub-content `image_url` key must be an object.")?
270                                            .get("url")
271                                            .context("Image sub-content `image_url` object must have a `url` key.")?.clone()
272                                    });
273                                }
274                                _ => anyhow::bail!("Expected array content sub-content to be of format {{`type`: `text`, `text`: ...}} and {{`type`: `url`, `image_url`: {{`url`: ...}}}}")
275                            }
276                        }
277
278                        let text_content = items
279                            .iter()
280                            .filter_map(|item| match item {
281                                ContentPart::Text { text } => Some(text),
282                                _ => None,
283                            })
284                            .join(" ");
285                        let image_urls_iter = items
286                            .iter()
287                            .filter_map(|item| match item {
288                                ContentPart::Image { image_url } => Some(image_url.clone()),
289                                _ => None,
290                            })
291                            .collect::<Vec<_>>();
292
293                        let mut message_map: IndexMap<
294                            String,
295                            Either<String, Vec<IndexMap<String, Value>>>,
296                        > = IndexMap::new();
297                        message_map.insert("role".to_string(), Either::Left(message.role));
298
299                        let mut content_map: Vec<IndexMap<String, Value>> = Vec::new();
300                        for _ in &image_urls_iter {
301                            let mut content_image_map = IndexMap::new();
302                            content_image_map
303                                .insert("type".to_string(), Value::String("image".to_string()));
304                            content_map.push(content_image_map);
305                        }
306                        {
307                            let mut content_text_map = IndexMap::new();
308                            content_text_map
309                                .insert("type".to_string(), Value::String("text".to_string()));
310                            content_text_map
311                                .insert("text".to_string(), Value::String(text_content));
312                            content_map.push(content_text_map);
313                        }
314
315                        message_map.insert("content".to_string(), Either::Right(content_map));
316                        messages.push(message_map);
317                        image_urls.extend(image_urls_iter);
318                    }
319                }
320            }
321            if !image_urls.is_empty() {
322                let mut images = Vec::new();
323                for url_unparsed in image_urls {
324                    let image = util::parse_image_url(&url_unparsed)
325                        .await
326                        .context(format!("Failed to parse image resource: {}", url_unparsed))?;
327
328                    images.push(image);
329                }
330                RequestMessage::VisionChat { messages, images }
331            } else {
332                RequestMessage::Chat(messages)
333            }
334        }
335        Either::Right(prompt) => {
336            let mut messages = Vec::new();
337            let mut message_map: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
338                IndexMap::new();
339            message_map.insert("role".to_string(), Either::Left("user".to_string()));
340            message_map.insert("content".to_string(), Either::Left(prompt));
341            messages.push(message_map);
342            RequestMessage::Chat(messages)
343        }
344    };
345
346    let dry_params = if let Some(dry_multiplier) = oairequest.dry_multiplier {
347        Some(DrySamplingParams::new_with_defaults(
348            dry_multiplier,
349            oairequest.dry_sequence_breakers,
350            oairequest.dry_base,
351            oairequest.dry_allowed_length,
352        )?)
353    } else {
354        None
355    };
356
357    let is_streaming = oairequest.stream.unwrap_or(false);
358
359    if oairequest.grammar.is_some() && oairequest.response_format.is_some() {
360        anyhow::bail!("Request `grammar` and `response_format` were both provided but are mutually exclusive.")
361    }
362
363    let constraint = match oairequest.grammar {
364        Some(Grammar::Regex(regex)) => Constraint::Regex(regex),
365        Some(Grammar::Lark(lark)) => Constraint::Lark(lark),
366        Some(Grammar::JsonSchema(schema)) => Constraint::JsonSchema(schema),
367        Some(Grammar::Llguidance(llguidance)) => Constraint::Llguidance(llguidance),
368        None => match oairequest.response_format {
369            Some(ResponseFormat::JsonSchema {
370                json_schema: JsonSchemaResponseFormat { name: _, schema },
371            }) => Constraint::JsonSchema(schema),
372            Some(ResponseFormat::Text) => Constraint::None,
373            None => Constraint::None,
374        },
375    };
376
377    Ok((
378        Request::Normal(NormalRequest {
379            id: state.next_request_id(),
380            messages,
381            sampling_params: SamplingParams {
382                temperature: oairequest.temperature,
383                top_k: oairequest.top_k,
384                top_p: oairequest.top_p,
385                min_p: oairequest.min_p,
386                top_n_logprobs: oairequest.top_logprobs.unwrap_or(1),
387                frequency_penalty: oairequest.frequency_penalty,
388                presence_penalty: oairequest.presence_penalty,
389                max_len: oairequest.max_tokens,
390                stop_toks,
391                logits_bias: oairequest.logit_bias,
392                n_choices: oairequest.n_choices,
393                dry_params,
394            },
395            response: tx,
396            return_logprobs: oairequest.logprobs,
397            is_streaming,
398            suffix: None,
399            constraint,
400            adapters: oairequest.adapters,
401            tool_choice: oairequest.tool_choice,
402            tools: oairequest.tools,
403            logits_processors: None,
404            return_raw_logits: false,
405            web_search_options: oairequest.web_search_options,
406        }),
407        is_streaming,
408    ))
409}
410
411#[utoipa::path(
412    post,
413    tag = "Mistral.rs",
414    path = "/v1/chat/completions",
415    request_body = ChatCompletionRequest,
416    responses((status = 200, description = "Chat completions"))
417)]
418pub async fn chatcompletions(
419    State(state): State<Arc<MistralRs>>,
420    Json(oairequest): Json<ChatCompletionRequest>,
421) -> ChatCompletionResponder {
422    let (tx, mut rx) = channel(10_000);
423    let (request, is_streaming) = match parse_request(oairequest, state.clone(), tx).await {
424        Ok(x) => x,
425        Err(e) => {
426            let e = anyhow::Error::msg(e.to_string());
427            MistralRs::maybe_log_error(state, &*e);
428            return ChatCompletionResponder::InternalError(e.into());
429        }
430    };
431    let sender = state.get_sender().unwrap();
432
433    if let Err(e) = sender.send(request).await {
434        let e = anyhow::Error::msg(e.to_string());
435        MistralRs::maybe_log_error(state, &*e);
436        return ChatCompletionResponder::InternalError(e.into());
437    }
438
439    if is_streaming {
440        let streamer = Streamer {
441            rx,
442            done_state: DoneState::Running,
443            state,
444        };
445
446        let keep_alive_interval = env::var("KEEP_ALIVE_INTERVAL")
447            .map(|val| val.parse::<u64>().unwrap_or(10000))
448            .unwrap_or(10000);
449        ChatCompletionResponder::Sse(
450            Sse::new(streamer)
451                .keep_alive(KeepAlive::new().interval(Duration::from_millis(keep_alive_interval))),
452        )
453    } else {
454        let response = match rx.recv().await {
455            Some(response) => response,
456            None => {
457                let e = anyhow::Error::msg("No response received from the model.");
458                MistralRs::maybe_log_error(state, &*e);
459                return ChatCompletionResponder::InternalError(e.into());
460            }
461        };
462
463        match response {
464            Response::InternalError(e) => {
465                MistralRs::maybe_log_error(state, &*e);
466                ChatCompletionResponder::InternalError(e)
467            }
468            Response::ModelError(msg, response) => {
469                MistralRs::maybe_log_error(state.clone(), &ModelErrorMessage(msg.to_string()));
470                MistralRs::maybe_log_response(state, &response);
471                ChatCompletionResponder::ModelError(msg, response)
472            }
473            Response::ValidationError(e) => ChatCompletionResponder::ValidationError(e),
474            Response::Done(response) => {
475                MistralRs::maybe_log_response(state, &response);
476                ChatCompletionResponder::Json(response)
477            }
478            Response::Chunk(_) => unreachable!(),
479            Response::CompletionDone(_) => unreachable!(),
480            Response::CompletionModelError(_, _) => unreachable!(),
481            Response::CompletionChunk(_) => unreachable!(),
482            Response::ImageGeneration(_) => unreachable!(),
483            Response::Raw { .. } => unreachable!(),
484        }
485    }
486}