1use serde_json::Value;
2use std::{env, error::Error, ops::Deref, pin::Pin, sync::Arc, task::Poll, time::Duration};
3use tokio::sync::mpsc::{channel, Receiver, Sender};
4
5use crate::{
6 openai::{
7 ChatCompletionRequest, Grammar, JsonSchemaResponseFormat, MessageInnerContent,
8 ResponseFormat, StopTokens,
9 },
10 util,
11};
12use anyhow::Context;
13use anyhow::Result;
14use axum::{
15 extract::{Json, State},
16 http::{self, StatusCode},
17 response::{
18 sse::{Event, KeepAlive},
19 IntoResponse, Sse,
20 },
21};
22use either::Either;
23use indexmap::IndexMap;
24use itertools::Itertools;
25use mistralrs_core::{
26 ChatCompletionResponse, Constraint, DrySamplingParams, MistralRs, NormalRequest, Request,
27 RequestMessage, Response, SamplingParams, StopTokens as InternalStopTokens,
28};
29use serde::Serialize;
30
/// Newtype around a model error string so it can be passed to logging helpers
/// that take a `&dyn std::error::Error`.
#[derive(Debug)]
struct ModelErrorMessage(String);

impl std::fmt::Display for ModelErrorMessage {
    /// Writes the wrapped message verbatim; formatter flags such as width are ignored.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}

impl std::error::Error for ModelErrorMessage {}
39
/// Progress of a [`Streamer`] towards emitting its terminal `[DONE]` SSE event.
enum DoneState {
    /// Still forwarding responses received from the engine channel.
    Running,
    /// The final response was forwarded; the next poll yields `[DONE]`.
    SendingDone,
    /// `[DONE]` was sent; every further poll yields end-of-stream.
    Done,
}
45
46pub struct Streamer {
47 rx: Receiver<Response>,
48 done_state: DoneState,
49 state: Arc<MistralRs>,
50}
51
52impl futures::Stream for Streamer {
53 type Item = Result<Event, axum::Error>;
54
55 fn poll_next(
56 mut self: Pin<&mut Self>,
57 cx: &mut std::task::Context<'_>,
58 ) -> Poll<Option<Self::Item>> {
59 match self.done_state {
60 DoneState::SendingDone => {
61 self.done_state = DoneState::Done;
64 return Poll::Ready(Some(Ok(Event::default().data("[DONE]"))));
65 }
66 DoneState::Done => {
67 return Poll::Ready(None);
68 }
69 DoneState::Running => (),
70 }
71
72 match self.rx.poll_recv(cx) {
73 Poll::Ready(Some(resp)) => match resp {
74 Response::ModelError(msg, _) => {
75 MistralRs::maybe_log_error(
76 self.state.clone(),
77 &ModelErrorMessage(msg.to_string()),
78 );
79 self.done_state = DoneState::SendingDone;
81 Poll::Ready(Some(Ok(Event::default().data(msg))))
82 }
83 Response::ValidationError(e) => {
84 Poll::Ready(Some(Ok(Event::default().data(e.to_string()))))
85 }
86 Response::InternalError(e) => {
87 MistralRs::maybe_log_error(self.state.clone(), &*e);
88 Poll::Ready(Some(Ok(Event::default().data(e.to_string()))))
89 }
90 Response::Chunk(response) => {
91 if response.choices.iter().all(|x| x.finish_reason.is_some()) {
92 self.done_state = DoneState::SendingDone;
93 }
94 MistralRs::maybe_log_response(self.state.clone(), &response);
96 Poll::Ready(Some(Event::default().json_data(response)))
97 }
98 Response::Done(_) => unreachable!(),
99 Response::CompletionDone(_) => unreachable!(),
100 Response::CompletionModelError(_, _) => unreachable!(),
101 Response::CompletionChunk(_) => unreachable!(),
102 Response::ImageGeneration(_) => unreachable!(),
103 Response::Raw { .. } => unreachable!(),
104 },
105 Poll::Pending | Poll::Ready(None) => Poll::Pending,
106 }
107 }
108}
109
110pub enum ChatCompletionResponder {
111 Sse(Sse<Streamer>),
112 Json(ChatCompletionResponse),
113 ModelError(String, ChatCompletionResponse),
114 InternalError(Box<dyn Error>),
115 ValidationError(Box<dyn Error>),
116}
117
118trait ErrorToResponse: Serialize {
119 fn to_response(&self, code: StatusCode) -> axum::response::Response {
120 let mut r = Json(self).into_response();
121 *r.status_mut() = code;
122 r
123 }
124}
125
126#[derive(Serialize)]
127struct JsonError {
128 message: String,
129}
130
131impl JsonError {
132 fn new(message: String) -> Self {
133 Self { message }
134 }
135}
136impl ErrorToResponse for JsonError {}
137
138#[derive(Serialize)]
139struct JsonModelError {
140 message: String,
141 partial_response: ChatCompletionResponse,
142}
143
144impl JsonModelError {
145 fn new(message: String, partial_response: ChatCompletionResponse) -> Self {
146 Self {
147 message,
148 partial_response,
149 }
150 }
151}
152
153impl ErrorToResponse for JsonModelError {}
154
155impl IntoResponse for ChatCompletionResponder {
156 fn into_response(self) -> axum::response::Response {
157 match self {
158 ChatCompletionResponder::Sse(s) => s.into_response(),
159 ChatCompletionResponder::Json(s) => Json(s).into_response(),
160 ChatCompletionResponder::InternalError(e) => {
161 JsonError::new(e.to_string()).to_response(http::StatusCode::INTERNAL_SERVER_ERROR)
162 }
163 ChatCompletionResponder::ValidationError(e) => {
164 JsonError::new(e.to_string()).to_response(http::StatusCode::UNPROCESSABLE_ENTITY)
165 }
166 ChatCompletionResponder::ModelError(msg, response) => {
167 JsonModelError::new(msg, response)
168 .to_response(http::StatusCode::INTERNAL_SERVER_ERROR)
169 }
170 }
171 }
172}
173
174async fn parse_request(
175 oairequest: ChatCompletionRequest,
176 state: Arc<MistralRs>,
177 tx: Sender<Response>,
178) -> Result<(Request, bool)> {
179 let repr = serde_json::to_string(&oairequest).expect("Serialization of request failed.");
180 MistralRs::maybe_log_request(state.clone(), repr);
181
182 let stop_toks = match oairequest.stop_seqs {
183 Some(StopTokens::Multi(m)) => Some(InternalStopTokens::Seqs(m)),
184 Some(StopTokens::Single(s)) => Some(InternalStopTokens::Seqs(vec![s])),
185 None => None,
186 };
187 let messages = match oairequest.messages {
188 Either::Left(req_messages) => {
189 let mut messages = Vec::new();
190 let mut image_urls = Vec::new();
191 for message in req_messages {
192 let content = match message.content.as_deref() {
193 Some(content) => content.clone(),
194 None => {
195 let calls = message
197 .tool_calls
198 .as_ref()
199 .context(
200 "No content was provided, expected tool calls to be provided.",
201 )?
202 .iter()
203 .map(|call| &call.function)
204 .collect::<Vec<_>>();
205
206 Either::Left(serde_json::to_string(&calls)?)
207 }
208 };
209
210 match &content {
211 Either::Left(content) => {
212 let mut message_map: IndexMap<
213 String,
214 Either<String, Vec<IndexMap<String, Value>>>,
215 > = IndexMap::new();
216 message_map.insert("role".to_string(), Either::Left(message.role));
217 message_map.insert("content".to_string(), Either::Left(content.clone()));
218 messages.push(message_map);
219 }
220 Either::Right(image_messages) => {
221 if image_messages.len() == 1 {
225 if !image_messages[0].contains_key("text") {
226 anyhow::bail!("Expected `text` key in input message.");
227 }
228 let content = match image_messages[0]["text"].deref() {
229 Either::Left(left) => left.to_string(),
230 Either::Right(right) => format!("{:?}", right),
231 };
232 let mut message_map: IndexMap<
233 String,
234 Either<String, Vec<IndexMap<String, Value>>>,
235 > = IndexMap::new();
236 message_map.insert("role".to_string(), Either::Left(message.role));
237 message_map.insert("content".to_string(), Either::Left(content));
238 messages.push(message_map);
239 continue;
240 }
241 if message.role != "user" {
242 anyhow::bail!(
243 "Role for an image message must be `user`, but it is {}",
244 message.role
245 );
246 }
247
248 enum ContentPart {
249 Text { text: String },
250 Image { image_url: String },
251 }
252
253 let mut items = Vec::new();
254 for image_message in image_messages {
255 match image_message.get("type") {
256 Some(MessageInnerContent(Either::Left(x))) if x == "text" => {
257 items.push(ContentPart::Text {
258 text: image_message
259 .get("text").as_ref()
260 .context("Text sub-content must have `text` key.")?.as_ref()
261 .left().context("Text sub-content `text` key must be a string.")?.clone(),
262 });
263 }
264 Some(MessageInnerContent(Either::Left(x))) if x == "image_url" => {
265 items.push(ContentPart::Image {
266 image_url: image_message.get("image_url").as_ref()
267 .context("Image sub-content must have `image_url` key.")?.as_ref()
268 .right()
269 .context("Image sub-content `image_url` key must be an object.")?
270 .get("url")
271 .context("Image sub-content `image_url` object must have a `url` key.")?.clone()
272 });
273 }
274 _ => anyhow::bail!("Expected array content sub-content to be of format {{`type`: `text`, `text`: ...}} and {{`type`: `url`, `image_url`: {{`url`: ...}}}}")
275 }
276 }
277
278 let text_content = items
279 .iter()
280 .filter_map(|item| match item {
281 ContentPart::Text { text } => Some(text),
282 _ => None,
283 })
284 .join(" ");
285 let image_urls_iter = items
286 .iter()
287 .filter_map(|item| match item {
288 ContentPart::Image { image_url } => Some(image_url.clone()),
289 _ => None,
290 })
291 .collect::<Vec<_>>();
292
293 let mut message_map: IndexMap<
294 String,
295 Either<String, Vec<IndexMap<String, Value>>>,
296 > = IndexMap::new();
297 message_map.insert("role".to_string(), Either::Left(message.role));
298
299 let mut content_map: Vec<IndexMap<String, Value>> = Vec::new();
300 for _ in &image_urls_iter {
301 let mut content_image_map = IndexMap::new();
302 content_image_map
303 .insert("type".to_string(), Value::String("image".to_string()));
304 content_map.push(content_image_map);
305 }
306 {
307 let mut content_text_map = IndexMap::new();
308 content_text_map
309 .insert("type".to_string(), Value::String("text".to_string()));
310 content_text_map
311 .insert("text".to_string(), Value::String(text_content));
312 content_map.push(content_text_map);
313 }
314
315 message_map.insert("content".to_string(), Either::Right(content_map));
316 messages.push(message_map);
317 image_urls.extend(image_urls_iter);
318 }
319 }
320 }
321 if !image_urls.is_empty() {
322 let mut images = Vec::new();
323 for url_unparsed in image_urls {
324 let image = util::parse_image_url(&url_unparsed)
325 .await
326 .context(format!("Failed to parse image resource: {}", url_unparsed))?;
327
328 images.push(image);
329 }
330 RequestMessage::VisionChat { messages, images }
331 } else {
332 RequestMessage::Chat(messages)
333 }
334 }
335 Either::Right(prompt) => {
336 let mut messages = Vec::new();
337 let mut message_map: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
338 IndexMap::new();
339 message_map.insert("role".to_string(), Either::Left("user".to_string()));
340 message_map.insert("content".to_string(), Either::Left(prompt));
341 messages.push(message_map);
342 RequestMessage::Chat(messages)
343 }
344 };
345
346 let dry_params = if let Some(dry_multiplier) = oairequest.dry_multiplier {
347 Some(DrySamplingParams::new_with_defaults(
348 dry_multiplier,
349 oairequest.dry_sequence_breakers,
350 oairequest.dry_base,
351 oairequest.dry_allowed_length,
352 )?)
353 } else {
354 None
355 };
356
357 let is_streaming = oairequest.stream.unwrap_or(false);
358
359 if oairequest.grammar.is_some() && oairequest.response_format.is_some() {
360 anyhow::bail!("Request `grammar` and `response_format` were both provided but are mutually exclusive.")
361 }
362
363 let constraint = match oairequest.grammar {
364 Some(Grammar::Regex(regex)) => Constraint::Regex(regex),
365 Some(Grammar::Lark(lark)) => Constraint::Lark(lark),
366 Some(Grammar::JsonSchema(schema)) => Constraint::JsonSchema(schema),
367 Some(Grammar::Llguidance(llguidance)) => Constraint::Llguidance(llguidance),
368 None => match oairequest.response_format {
369 Some(ResponseFormat::JsonSchema {
370 json_schema: JsonSchemaResponseFormat { name: _, schema },
371 }) => Constraint::JsonSchema(schema),
372 Some(ResponseFormat::Text) => Constraint::None,
373 None => Constraint::None,
374 },
375 };
376
377 Ok((
378 Request::Normal(NormalRequest {
379 id: state.next_request_id(),
380 messages,
381 sampling_params: SamplingParams {
382 temperature: oairequest.temperature,
383 top_k: oairequest.top_k,
384 top_p: oairequest.top_p,
385 min_p: oairequest.min_p,
386 top_n_logprobs: oairequest.top_logprobs.unwrap_or(1),
387 frequency_penalty: oairequest.frequency_penalty,
388 presence_penalty: oairequest.presence_penalty,
389 max_len: oairequest.max_tokens,
390 stop_toks,
391 logits_bias: oairequest.logit_bias,
392 n_choices: oairequest.n_choices,
393 dry_params,
394 },
395 response: tx,
396 return_logprobs: oairequest.logprobs,
397 is_streaming,
398 suffix: None,
399 constraint,
400 adapters: oairequest.adapters,
401 tool_choice: oairequest.tool_choice,
402 tools: oairequest.tools,
403 logits_processors: None,
404 return_raw_logits: false,
405 web_search_options: oairequest.web_search_options,
406 }),
407 is_streaming,
408 ))
409}
410
411#[utoipa::path(
412 post,
413 tag = "Mistral.rs",
414 path = "/v1/chat/completions",
415 request_body = ChatCompletionRequest,
416 responses((status = 200, description = "Chat completions"))
417)]
418pub async fn chatcompletions(
419 State(state): State<Arc<MistralRs>>,
420 Json(oairequest): Json<ChatCompletionRequest>,
421) -> ChatCompletionResponder {
422 let (tx, mut rx) = channel(10_000);
423 let (request, is_streaming) = match parse_request(oairequest, state.clone(), tx).await {
424 Ok(x) => x,
425 Err(e) => {
426 let e = anyhow::Error::msg(e.to_string());
427 MistralRs::maybe_log_error(state, &*e);
428 return ChatCompletionResponder::InternalError(e.into());
429 }
430 };
431 let sender = state.get_sender().unwrap();
432
433 if let Err(e) = sender.send(request).await {
434 let e = anyhow::Error::msg(e.to_string());
435 MistralRs::maybe_log_error(state, &*e);
436 return ChatCompletionResponder::InternalError(e.into());
437 }
438
439 if is_streaming {
440 let streamer = Streamer {
441 rx,
442 done_state: DoneState::Running,
443 state,
444 };
445
446 let keep_alive_interval = env::var("KEEP_ALIVE_INTERVAL")
447 .map(|val| val.parse::<u64>().unwrap_or(10000))
448 .unwrap_or(10000);
449 ChatCompletionResponder::Sse(
450 Sse::new(streamer)
451 .keep_alive(KeepAlive::new().interval(Duration::from_millis(keep_alive_interval))),
452 )
453 } else {
454 let response = match rx.recv().await {
455 Some(response) => response,
456 None => {
457 let e = anyhow::Error::msg("No response received from the model.");
458 MistralRs::maybe_log_error(state, &*e);
459 return ChatCompletionResponder::InternalError(e.into());
460 }
461 };
462
463 match response {
464 Response::InternalError(e) => {
465 MistralRs::maybe_log_error(state, &*e);
466 ChatCompletionResponder::InternalError(e)
467 }
468 Response::ModelError(msg, response) => {
469 MistralRs::maybe_log_error(state.clone(), &ModelErrorMessage(msg.to_string()));
470 MistralRs::maybe_log_response(state, &response);
471 ChatCompletionResponder::ModelError(msg, response)
472 }
473 Response::ValidationError(e) => ChatCompletionResponder::ValidationError(e),
474 Response::Done(response) => {
475 MistralRs::maybe_log_response(state, &response);
476 ChatCompletionResponder::Json(response)
477 }
478 Response::Chunk(_) => unreachable!(),
479 Response::CompletionDone(_) => unreachable!(),
480 Response::CompletionModelError(_, _) => unreachable!(),
481 Response::CompletionChunk(_) => unreachable!(),
482 Response::ImageGeneration(_) => unreachable!(),
483 Response::Raw { .. } => unreachable!(),
484 }
485 }
486}