1use std::{env, error::Error, ops::Deref, pin::Pin, task::Poll, time::Duration};
4
5use anyhow::{Context, Result};
6use axum::{
7 extract::{Json, State},
8 http::{self, StatusCode},
9 response::{
10 sse::{Event, KeepAlive},
11 IntoResponse, Sse,
12 },
13};
14use either::Either;
15use indexmap::IndexMap;
16use itertools::Itertools;
17use mistralrs_core::{
18 ChatCompletionChunkResponse, ChatCompletionResponse, Constraint, DrySamplingParams, MistralRs,
19 NormalRequest, Request, RequestMessage, Response, SamplingParams,
20 StopTokens as InternalStopTokens,
21};
22use serde::Serialize;
23use serde_json::Value;
24use tokio::sync::mpsc::{channel, Receiver, Sender};
25
26use crate::{
27 openai::{
28 ChatCompletionRequest, Grammar, JsonSchemaResponseFormat, MessageInnerContent,
29 ResponseFormat, StopTokens,
30 },
31 types::{ExtractedMistralRsState, SharedMistralRsState},
32 util,
33};
34
/// Hook invoked on every streamed chunk; may transform the chunk before it
/// is emitted to the client (and before it is buffered for `OnDoneCallback`).
pub type OnChunkCallback =
    Box<dyn Fn(ChatCompletionChunkResponse) -> ChatCompletionChunkResponse + Send + Sync>;

/// Hook invoked once after streaming finishes, with every chunk produced.
pub type OnDoneCallback = Box<dyn Fn(&[ChatCompletionChunkResponse]) + Send + Sync>;

/// Default capacity of the engine-to-HTTP response channel.
pub const DEFAULT_CHANNEL_BUFFER_SIZE: usize = 10_000;

/// Default SSE keep-alive interval, in milliseconds.
pub const DEFAULT_KEEP_ALIVE_INTERVAL_MS: u64 = 10_000;
81
/// Newtype wrapping a model-produced error string so it can be passed
/// through the `std::error::Error` logging machinery.
#[derive(Debug)]
struct ModelErrorMessage(String);

impl std::fmt::Display for ModelErrorMessage {
    /// Print the wrapped message verbatim.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        formatter.write_str(&self.0)
    }
}

impl std::error::Error for ModelErrorMessage {}
94
/// Streaming lifecycle for [`Streamer`]: forward chunks, then emit the
/// `[DONE]` sentinel, then terminate the stream.
enum DoneState {
    /// Still forwarding chunks from the engine channel.
    Running,
    /// All choices finished (or the model errored); next poll emits `[DONE]`.
    SendingDone,
    /// `[DONE]` was sent; next poll runs `on_done` and ends the stream.
    Done,
}
104
/// SSE adapter that turns engine [`Response`]s into chat-completion
/// chunk events, implementing [`futures::Stream`].
pub struct Streamer {
    // Receiving end of the channel the engine writes responses into.
    rx: Receiver<Response>,
    // Where we are in the emit-chunks / send-[DONE] / finish sequence.
    done_state: DoneState,
    // Shared engine handle, used here for logging responses and errors.
    state: SharedMistralRsState,
    // Whether chunks are accumulated into `chunks` (set when `on_done` needs them).
    store_chunks: bool,
    // All chunks produced so far; handed to `on_done` at end of stream.
    chunks: Vec<ChatCompletionChunkResponse>,
    // Optional per-chunk transform applied before a chunk is emitted.
    on_chunk: Option<OnChunkCallback>,
    // Optional callback run once after the stream completes.
    on_done: Option<OnDoneCallback>,
}
125
impl futures::Stream for Streamer {
    type Item = Result<Event, axum::Error>;

    /// Drive the SSE stream: drain the terminal states first, then poll the
    /// engine channel for the next response.
    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        match self.done_state {
            DoneState::SendingDone => {
                // OpenAI-compatible streams terminate with a literal `[DONE]` event.
                self.done_state = DoneState::Done;
                return Poll::Ready(Some(Ok(Event::default().data("[DONE]"))));
            }
            DoneState::Done => {
                // `[DONE]` has been delivered; fire the completion hook and close.
                if let Some(on_done) = &self.on_done {
                    on_done(&self.chunks);
                }
                return Poll::Ready(None);
            }
            DoneState::Running => (),
        }

        match self.rx.poll_recv(cx) {
            Poll::Ready(Some(resp)) => match resp {
                Response::ModelError(msg, _) => {
                    MistralRs::maybe_log_error(
                        self.state.clone(),
                        &ModelErrorMessage(msg.to_string()),
                    );
                    // Emit the error text as a data event, then wrap up with `[DONE]`.
                    self.done_state = DoneState::SendingDone;
                    Poll::Ready(Some(Ok(Event::default().data(msg))))
                }
                Response::ValidationError(e) => {
                    // NOTE(review): unlike ModelError, this does not advance
                    // `done_state`, so the stream stays open after a validation
                    // error — confirm this is intended.
                    Poll::Ready(Some(Ok(Event::default().data(e.to_string()))))
                }
                Response::InternalError(e) => {
                    MistralRs::maybe_log_error(self.state.clone(), &*e);
                    Poll::Ready(Some(Ok(Event::default().data(e.to_string()))))
                }
                Response::Chunk(mut response) => {
                    // Every choice carrying a finish reason marks the final
                    // content chunk; schedule `[DONE]` for the next poll.
                    if response.choices.iter().all(|x| x.finish_reason.is_some()) {
                        self.done_state = DoneState::SendingDone;
                    }
                    MistralRs::maybe_log_response(self.state.clone(), &response);

                    // Let the caller-supplied hook rewrite the chunk first.
                    if let Some(on_chunk) = &self.on_chunk {
                        response = on_chunk(response);
                    }

                    if self.store_chunks {
                        self.chunks.push(response.clone());
                    }

                    Poll::Ready(Some(Event::default().json_data(response)))
                }
                // Chat streaming only ever produces the variants handled above.
                Response::Done(_) => unreachable!(),
                Response::CompletionDone(_) => unreachable!(),
                Response::CompletionModelError(_, _) => unreachable!(),
                Response::CompletionChunk(_) => unreachable!(),
                Response::ImageGeneration(_) => unreachable!(),
                Response::Speech { .. } => unreachable!(),
                Response::Raw { .. } => unreachable!(),
            },
            // NOTE(review): `Ready(None)` (sender dropped mid-stream) is folded
            // into `Pending` without any future wakeup being registered, which
            // can leave the connection hanging until the client times out —
            // verify the engine always sends a terminal response before
            // dropping `tx`.
            Poll::Pending | Poll::Ready(None) => Poll::Pending,
        }
    }
}
203
/// Every shape a `/v1/chat/completions` reply can take.
pub enum ChatCompletionResponder {
    /// Streaming reply: chunks delivered as server-sent events.
    Sse(Sse<Streamer>),
    /// Successful non-streaming reply.
    Json(ChatCompletionResponse),
    /// The model failed mid-generation; carries the message plus the partial response.
    ModelError(String, ChatCompletionResponse),
    /// Server-side failure (rendered as HTTP 500).
    InternalError(Box<dyn Error>),
    /// Invalid request (rendered as HTTP 422).
    ValidationError(Box<dyn Error>),
}
217
218trait ErrorToResponse: Serialize {
220 fn to_response(&self, code: StatusCode) -> axum::response::Response {
222 let mut r = Json(self).into_response();
223 *r.status_mut() = code;
224 r
225 }
226}
227
228#[derive(Serialize)]
230struct JsonError {
231 message: String,
232}
233
234impl JsonError {
235 fn new(message: String) -> Self {
237 Self { message }
238 }
239}
240impl ErrorToResponse for JsonError {}
241
242#[derive(Serialize)]
244struct JsonModelError {
245 message: String,
246 partial_response: ChatCompletionResponse,
248}
249
250impl JsonModelError {
251 fn new(message: String, partial_response: ChatCompletionResponse) -> Self {
253 Self {
254 message,
255 partial_response,
256 }
257 }
258}
259
260impl ErrorToResponse for JsonModelError {}
261
262impl IntoResponse for ChatCompletionResponder {
263 fn into_response(self) -> axum::response::Response {
265 match self {
266 ChatCompletionResponder::Sse(s) => s.into_response(),
267 ChatCompletionResponder::Json(s) => Json(s).into_response(),
268 ChatCompletionResponder::InternalError(e) => {
269 JsonError::new(e.to_string()).to_response(http::StatusCode::INTERNAL_SERVER_ERROR)
270 }
271 ChatCompletionResponder::ValidationError(e) => {
272 JsonError::new(e.to_string()).to_response(http::StatusCode::UNPROCESSABLE_ENTITY)
273 }
274 ChatCompletionResponder::ModelError(msg, response) => {
275 JsonModelError::new(msg, response)
276 .to_response(http::StatusCode::INTERNAL_SERVER_ERROR)
277 }
278 }
279 }
280}
281
282pub async fn parse_request(
287 oairequest: ChatCompletionRequest,
288 state: SharedMistralRsState,
289 tx: Sender<Response>,
290) -> Result<(Request, bool)> {
291 let repr = serde_json::to_string(&oairequest).expect("Serialization of request failed.");
292 MistralRs::maybe_log_request(state.clone(), repr);
293
294 let stop_toks = match oairequest.stop_seqs {
295 Some(StopTokens::Multi(m)) => Some(InternalStopTokens::Seqs(m)),
296 Some(StopTokens::Single(s)) => Some(InternalStopTokens::Seqs(vec![s])),
297 None => None,
298 };
299 let messages = match oairequest.messages {
300 Either::Left(req_messages) => {
301 let mut messages = Vec::new();
302 let mut image_urls = Vec::new();
303 for message in req_messages {
304 let content = match message.content.as_deref() {
305 Some(content) => content.clone(),
306 None => {
307 let calls = message
309 .tool_calls
310 .as_ref()
311 .context(
312 "No content was provided, expected tool calls to be provided.",
313 )?
314 .iter()
315 .map(|call| &call.function)
316 .collect::<Vec<_>>();
317
318 Either::Left(serde_json::to_string(&calls)?)
319 }
320 };
321
322 match &content {
323 Either::Left(content) => {
324 let mut message_map: IndexMap<
325 String,
326 Either<String, Vec<IndexMap<String, Value>>>,
327 > = IndexMap::new();
328 message_map.insert("role".to_string(), Either::Left(message.role));
329 message_map.insert("content".to_string(), Either::Left(content.clone()));
330 messages.push(message_map);
331 }
332 Either::Right(image_messages) => {
333 if image_messages.len() == 1 {
337 if !image_messages[0].contains_key("text") {
338 anyhow::bail!("Expected `text` key in input message.");
339 }
340 let content = match image_messages[0]["text"].deref() {
341 Either::Left(left) => left.to_string(),
342 Either::Right(right) => format!("{:?}", right),
343 };
344 let mut message_map: IndexMap<
345 String,
346 Either<String, Vec<IndexMap<String, Value>>>,
347 > = IndexMap::new();
348 message_map.insert("role".to_string(), Either::Left(message.role));
349 message_map.insert("content".to_string(), Either::Left(content));
350 messages.push(message_map);
351 continue;
352 }
353 if message.role != "user" {
354 anyhow::bail!(
355 "Role for an image message must be `user`, but it is {}",
356 message.role
357 );
358 }
359
360 enum ContentPart {
361 Text { text: String },
362 Image { image_url: String },
363 }
364
365 let mut items = Vec::new();
366 for image_message in image_messages {
367 match image_message.get("type") {
368 Some(MessageInnerContent(Either::Left(x))) if x == "text" => {
369 items.push(ContentPart::Text {
370 text: image_message
371 .get("text").as_ref()
372 .context("Text sub-content must have `text` key.")?.as_ref()
373 .left().context("Text sub-content `text` key must be a string.")?.clone(),
374 });
375 }
376 Some(MessageInnerContent(Either::Left(x))) if x == "image_url" => {
377 items.push(ContentPart::Image {
378 image_url: image_message.get("image_url").as_ref()
379 .context("Image sub-content must have `image_url` key.")?.as_ref()
380 .right()
381 .context("Image sub-content `image_url` key must be an object.")?
382 .get("url")
383 .context("Image sub-content `image_url` object must have a `url` key.")?.clone()
384 });
385 }
386 _ => anyhow::bail!("Expected array content sub-content to be of format {{`type`: `text`, `text`: ...}} and {{`type`: `url`, `image_url`: {{`url`: ...}}}}")
387 }
388 }
389
390 let text_content = items
391 .iter()
392 .filter_map(|item| match item {
393 ContentPart::Text { text } => Some(text),
394 _ => None,
395 })
396 .join(" ");
397 let image_urls_iter = items
398 .iter()
399 .filter_map(|item| match item {
400 ContentPart::Image { image_url } => Some(image_url.clone()),
401 _ => None,
402 })
403 .collect::<Vec<_>>();
404
405 let mut message_map: IndexMap<
406 String,
407 Either<String, Vec<IndexMap<String, Value>>>,
408 > = IndexMap::new();
409 message_map.insert("role".to_string(), Either::Left(message.role));
410
411 let mut content_map: Vec<IndexMap<String, Value>> = Vec::new();
412 for _ in &image_urls_iter {
413 let mut content_image_map = IndexMap::new();
414 content_image_map
415 .insert("type".to_string(), Value::String("image".to_string()));
416 content_map.push(content_image_map);
417 }
418 {
419 let mut content_text_map = IndexMap::new();
420 content_text_map
421 .insert("type".to_string(), Value::String("text".to_string()));
422 content_text_map
423 .insert("text".to_string(), Value::String(text_content));
424 content_map.push(content_text_map);
425 }
426
427 message_map.insert("content".to_string(), Either::Right(content_map));
428 messages.push(message_map);
429 image_urls.extend(image_urls_iter);
430 }
431 }
432 }
433 if !image_urls.is_empty() {
434 let mut images = Vec::new();
435 for url_unparsed in image_urls {
436 let image = util::parse_image_url(&url_unparsed)
437 .await
438 .context(format!("Failed to parse image resource: {}", url_unparsed))?;
439
440 images.push(image);
441 }
442 RequestMessage::VisionChat {
443 messages,
444 images,
445 enable_thinking: oairequest.enable_thinking,
446 }
447 } else {
448 RequestMessage::Chat {
449 messages,
450 enable_thinking: oairequest.enable_thinking,
451 }
452 }
453 }
454 Either::Right(prompt) => {
455 let mut messages = Vec::new();
456 let mut message_map: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
457 IndexMap::new();
458 message_map.insert("role".to_string(), Either::Left("user".to_string()));
459 message_map.insert("content".to_string(), Either::Left(prompt));
460 messages.push(message_map);
461 RequestMessage::Chat {
462 messages,
463 enable_thinking: oairequest.enable_thinking,
464 }
465 }
466 };
467
468 let dry_params = if let Some(dry_multiplier) = oairequest.dry_multiplier {
469 Some(DrySamplingParams::new_with_defaults(
470 dry_multiplier,
471 oairequest.dry_sequence_breakers,
472 oairequest.dry_base,
473 oairequest.dry_allowed_length,
474 )?)
475 } else {
476 None
477 };
478
479 let is_streaming = oairequest.stream.unwrap_or(false);
480
481 if oairequest.grammar.is_some() && oairequest.response_format.is_some() {
482 anyhow::bail!("Request `grammar` and `response_format` were both provided but are mutually exclusive.")
483 }
484
485 let constraint = match oairequest.grammar {
486 Some(Grammar::Regex(regex)) => Constraint::Regex(regex),
487 Some(Grammar::Lark(lark)) => Constraint::Lark(lark),
488 Some(Grammar::JsonSchema(schema)) => Constraint::JsonSchema(schema),
489 Some(Grammar::Llguidance(llguidance)) => Constraint::Llguidance(llguidance),
490 None => match oairequest.response_format {
491 Some(ResponseFormat::JsonSchema {
492 json_schema: JsonSchemaResponseFormat { name: _, schema },
493 }) => Constraint::JsonSchema(schema),
494 Some(ResponseFormat::Text) => Constraint::None,
495 None => Constraint::None,
496 },
497 };
498
499 Ok((
500 Request::Normal(Box::new(NormalRequest {
501 id: state.next_request_id(),
502 messages,
503 sampling_params: SamplingParams {
504 temperature: oairequest.temperature,
505 top_k: oairequest.top_k,
506 top_p: oairequest.top_p,
507 min_p: oairequest.min_p,
508 top_n_logprobs: oairequest.top_logprobs.unwrap_or(1),
509 frequency_penalty: oairequest.frequency_penalty,
510 presence_penalty: oairequest.presence_penalty,
511 max_len: oairequest.max_tokens,
512 stop_toks,
513 logits_bias: oairequest.logit_bias,
514 n_choices: oairequest.n_choices,
515 dry_params,
516 },
517 response: tx,
518 return_logprobs: oairequest.logprobs,
519 is_streaming,
520 suffix: None,
521 constraint,
522 tool_choice: oairequest.tool_choice,
523 tools: oairequest.tools,
524 logits_processors: None,
525 return_raw_logits: false,
526 web_search_options: oairequest.web_search_options,
527 })),
528 is_streaming,
529 ))
530}
531
532#[utoipa::path(
534 post,
535 tag = "Mistral.rs",
536 path = "/v1/chat/completions",
537 request_body = ChatCompletionRequest,
538 responses((status = 200, description = "Chat completions"))
539)]
540pub async fn chatcompletions(
541 State(state): ExtractedMistralRsState,
542 Json(oairequest): Json<ChatCompletionRequest>,
543) -> ChatCompletionResponder {
544 let (tx, mut rx) = create_response_channel(None);
545
546 let (request, is_streaming) = match parse_request(oairequest, state.clone(), tx).await {
547 Ok(x) => x,
548 Err(e) => return handle_chat_completion_error(state, e.into()),
549 };
550
551 if let Err(e) = send_request(&state, request).await {
552 return handle_chat_completion_error(state, e.into());
553 }
554
555 if is_streaming {
556 ChatCompletionResponder::Sse(create_chat_streamer(rx, state, None, None))
557 } else {
558 process_non_streaming_chat_response(&mut rx, state).await
559 }
560}
561
562pub fn handle_chat_completion_error(
564 state: SharedMistralRsState,
565 e: Box<dyn std::error::Error + Send + Sync + 'static>,
566) -> ChatCompletionResponder {
567 let e = anyhow::Error::msg(e.to_string());
568 MistralRs::maybe_log_error(state, &*e);
569 ChatCompletionResponder::InternalError(e.into())
570}
571
572pub fn create_response_channel(
574 buffer_size: Option<usize>,
575) -> (Sender<Response>, Receiver<Response>) {
576 let channel_buffer_size = buffer_size.unwrap_or(DEFAULT_CHANNEL_BUFFER_SIZE);
577
578 channel(channel_buffer_size)
579}
580
581pub fn get_keep_alive_interval() -> u64 {
583 env::var("KEEP_ALIVE_INTERVAL")
584 .map(|val| {
585 val.parse::<u64>().unwrap_or_else(|e| {
586 tracing::warn!("Failed to parse KEEP_ALIVE_INTERVAL: {}. Using default.", e);
587 DEFAULT_KEEP_ALIVE_INTERVAL_MS
588 })
589 })
590 .unwrap_or(DEFAULT_KEEP_ALIVE_INTERVAL_MS)
591}
592
593pub async fn send_request(state: &SharedMistralRsState, request: Request) -> Result<()> {
595 let sender = state
596 .get_sender()
597 .context("mistral.rs sender not available.")?;
598
599 sender.send(request).await.map_err(|e| e.into())
600}
601
602pub fn create_chat_streamer(
604 rx: Receiver<Response>,
605 state: SharedMistralRsState,
606 on_chunk: Option<OnChunkCallback>,
607 on_done: Option<OnDoneCallback>,
608) -> Sse<Streamer> {
609 let store_chunks = on_done.is_some();
610
611 let streamer = Streamer {
612 rx,
613 done_state: DoneState::Running,
614 store_chunks,
615 state,
616 chunks: Vec::new(),
617 on_chunk,
618 on_done,
619 };
620
621 let keep_alive_interval = get_keep_alive_interval();
622
623 Sse::new(streamer)
624 .keep_alive(KeepAlive::new().interval(Duration::from_millis(keep_alive_interval)))
625}
626
627pub async fn process_non_streaming_chat_response(
629 rx: &mut Receiver<Response>,
630 state: SharedMistralRsState,
631) -> ChatCompletionResponder {
632 let response = match rx.recv().await {
633 Some(response) => response,
634 None => {
635 let e = anyhow::Error::msg("No response received from the model.");
636 return handle_chat_completion_error(state, e.into());
637 }
638 };
639
640 match_responses(state, response)
641}
642
643pub fn match_responses(state: SharedMistralRsState, response: Response) -> ChatCompletionResponder {
645 match response {
646 Response::InternalError(e) => {
647 MistralRs::maybe_log_error(state, &*e);
648 ChatCompletionResponder::InternalError(e)
649 }
650 Response::ModelError(msg, response) => {
651 MistralRs::maybe_log_error(state.clone(), &ModelErrorMessage(msg.to_string()));
652 MistralRs::maybe_log_response(state, &response);
653 ChatCompletionResponder::ModelError(msg, response)
654 }
655 Response::ValidationError(e) => ChatCompletionResponder::ValidationError(e),
656 Response::Done(response) => {
657 MistralRs::maybe_log_response(state, &response);
658 ChatCompletionResponder::Json(response)
659 }
660 Response::Chunk(_) => unreachable!(),
661 Response::CompletionDone(_) => unreachable!(),
662 Response::CompletionModelError(_, _) => unreachable!(),
663 Response::CompletionChunk(_) => unreachable!(),
664 Response::ImageGeneration(_) => unreachable!(),
665 Response::Speech { .. } => unreachable!(),
666 Response::Raw { .. } => unreachable!(),
667 }
668}