use std::{ops::Deref, pin::Pin, task::Poll, time::Duration};

use anyhow::{Context, Result};
use axum::{
    extract::{Json, State},
    http::{self},
    response::{
        sse::{Event, KeepAlive, KeepAliveStream},
        IntoResponse, Sse,
    },
};
use either::Either;
use indexmap::IndexMap;
use itertools::Itertools;
use mistralrs_core::{
    ChatCompletionChunkResponse, ChatCompletionResponse, Constraint, MistralRs, NormalRequest,
    ReasoningEffort, Request, RequestMessage, Response, SamplingParams,
};
use serde_json::Value;
use tokio::sync::mpsc::{Receiver, Sender};

use crate::{
    completion_core::{
        convert_stop_tokens, get_dry_sampling_params, handle_completion_error,
        BaseCompletionResponder,
    },
    handler_core::{
        base_process_non_streaming_response, create_response_channel, send_request_with_model,
        BaseJsonModelError, ErrorToResponse, JsonError, ModelErrorMessage,
    },
    openai::{
        ChatCompletionRequest, Grammar, JsonSchemaResponseFormat, MessageInnerContent,
        ResponseFormat,
    },
    streaming::{base_create_streamer, get_keep_alive_interval, BaseStreamer, DoneState},
    types::{ExtractedMistralRsState, OnChunkCallback, OnDoneCallback, SharedMistralRsState},
    util::{parse_audio_url, parse_image_url, sanitize_error_message, validate_model_name},
};

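/// Callback applied to each streamed [`ChatCompletionChunkResponse`]; the (possibly
/// modified) chunk it returns is what gets sent to the client.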
pub type ChatCompletionOnChunkCallback = OnChunkCallback<ChatCompletionChunkResponse>;

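/// Callback invoked once streaming completes, receiving every chunk collected while
/// `store_chunks` is enabled on the streamer.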
pub type ChatCompletionOnDoneCallback = OnDoneCallback<ChatCompletionChunkResponse>;

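/// SSE streamer over chat completion chunks with optional per-chunk and on-done hooks.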
pub type ChatCompletionStreamer = BaseStreamer<
    ChatCompletionChunkResponse,
    ChatCompletionOnChunkCallback,
    ChatCompletionOnDoneCallback,
>;

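/// Turns responses from the model channel into SSE [`Event`]s, closing the stream with a
/// final `[DONE]` message once every choice has reported a finish reason.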
impl futures::Stream for ChatCompletionStreamer {
    type Item = Result<Event, axum::Error>;

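    /// Polls the response channel and maps chunks, model errors, and validation errors to
    /// SSE events; applies `on_chunk`, records chunks when `store_chunks` is set, and calls
    /// `on_done` after the `[DONE]` event has been emitted.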
    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        match self.done_state {
            DoneState::SendingDone => {
                self.done_state = DoneState::Done;
                return Poll::Ready(Some(Ok(Event::default().data("[DONE]"))));
            }
            DoneState::Done => {
                if let Some(on_done) = &self.on_done {
                    on_done(&self.chunks);
                }
                return Poll::Ready(None);
            }
            DoneState::Running => (),
        }

        match self.rx.poll_recv(cx) {
            Poll::Ready(Some(resp)) => match resp {
                Response::ModelError(msg, _) => {
                    MistralRs::maybe_log_error(
                        self.state.clone(),
                        &ModelErrorMessage(msg.to_string()),
                    );
                    self.done_state = DoneState::SendingDone;
                    Poll::Ready(Some(Ok(Event::default().data(msg))))
                }
                Response::ValidationError(e) => Poll::Ready(Some(Ok(
                    Event::default().data(sanitize_error_message(e.as_ref()))
                ))),
                Response::InternalError(e) => {
                    MistralRs::maybe_log_error(self.state.clone(), &*e);
                    Poll::Ready(Some(Ok(
                        Event::default().data(sanitize_error_message(e.as_ref()))
                    )))
                }
                Response::Chunk(mut response) => {
                    if response.choices.iter().all(|x| x.finish_reason.is_some()) {
                        self.done_state = DoneState::SendingDone;
                    }
                    MistralRs::maybe_log_response(self.state.clone(), &response);

                    if let Some(on_chunk) = &self.on_chunk {
                        response = on_chunk(response);
                    }

                    if self.store_chunks {
                        self.chunks.push(response.clone());
                    }

                    Poll::Ready(Some(Event::default().json_data(response)))
                }
                Response::Done(_) => unreachable!(),
                Response::CompletionDone(_) => unreachable!(),
                Response::CompletionModelError(_, _) => unreachable!(),
                Response::CompletionChunk(_) => unreachable!(),
                Response::ImageGeneration(_) => unreachable!(),
                Response::Speech { .. } => unreachable!(),
                Response::Raw { .. } => unreachable!(),
                Response::Embeddings { .. } => unreachable!(),
            },
            Poll::Pending | Poll::Ready(None) => Poll::Pending,
        }
    }
}

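/// Response returned by the chat completions endpoint: an SSE stream when streaming is
/// requested, a JSON [`ChatCompletionResponse`] otherwise, or one of the error variants.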
pub type ChatCompletionResponder =
    BaseCompletionResponder<ChatCompletionResponse, KeepAliveStream<ChatCompletionStreamer>>;

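/// JSON error body that also carries the partial response produced before a model error.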
type JsonModelError = BaseJsonModelError<ChatCompletionResponse>;
impl ErrorToResponse for JsonModelError {}

impl IntoResponse for ChatCompletionResponder {
    fn into_response(self) -> axum::response::Response {
        match self {
            ChatCompletionResponder::Sse(s) => s.into_response(),
            ChatCompletionResponder::Json(s) => Json(s).into_response(),
            ChatCompletionResponder::InternalError(e) => {
                JsonError::new(sanitize_error_message(e.as_ref()))
                    .to_response(http::StatusCode::INTERNAL_SERVER_ERROR)
            }
            ChatCompletionResponder::ValidationError(e) => {
                JsonError::new(sanitize_error_message(e.as_ref()))
                    .to_response(http::StatusCode::UNPROCESSABLE_ENTITY)
            }
            ChatCompletionResponder::ModelError(msg, response) => {
                JsonModelError::new(msg, response)
                    .to_response(http::StatusCode::INTERNAL_SERVER_ERROR)
            }
        }
    }
}

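/// Maps an optional OpenAI-style `reasoning_effort` string ("low", "medium", or "high",
/// case-insensitive) to [`ReasoningEffort`]; any other value is ignored.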
fn parse_reasoning_effort(effort: &Option<String>) -> Option<ReasoningEffort> {
    effort
        .as_ref()
        .and_then(|e| match e.to_lowercase().as_str() {
            "low" => Some(ReasoningEffort::Low),
            "medium" => Some(ReasoningEffort::Medium),
            "high" => Some(ReasoningEffort::High),
            _ => None,
        })
}

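/// Converts an OpenAI-compatible [`ChatCompletionRequest`] into an internal [`Request`],
/// resolving text, tool-call, image, and audio content, sampling parameters, and
/// constraints. Also returns whether the client asked for a streaming response.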
pub async fn parse_request(
    oairequest: ChatCompletionRequest,
    state: SharedMistralRsState,
    tx: Sender<Response>,
) -> Result<(Request, bool)> {
    let repr = serde_json::to_string(&oairequest).expect("Serialization of request failed.");
    MistralRs::maybe_log_request(state.clone(), repr);

    validate_model_name(&oairequest.model, state.clone())?;

    let reasoning_effort = parse_reasoning_effort(&oairequest.reasoning_effort);

    let stop_toks = convert_stop_tokens(oairequest.stop_seqs);

    let messages = match oairequest.messages {
        Either::Left(req_messages) => {
            let mut messages = Vec::new();
            let mut image_urls = Vec::new();
            let mut audio_urls = Vec::new();
            for message in req_messages {
                let content = match message.content.as_deref() {
                    Some(content) => content.clone(),
                    None => {
                        let calls = message
                            .tool_calls
                            .as_ref()
                            .context(
                                "No content was provided, expected tool calls to be provided.",
                            )?
                            .iter()
                            .map(|call| &call.function)
                            .collect::<Vec<_>>();

                        Either::Left(serde_json::to_string(&calls)?)
                    }
                };

                match &content {
                    Either::Left(content) => {
                        let mut message_map: IndexMap<
                            String,
                            Either<String, Vec<IndexMap<String, Value>>>,
                        > = IndexMap::new();
                        message_map.insert("role".to_string(), Either::Left(message.role.clone()));
                        message_map.insert("content".to_string(), Either::Left(content.clone()));

                        if let Some(ref tool_calls) = message.tool_calls {
                            let tool_calls_vec: Vec<IndexMap<String, Value>> = tool_calls
                                .iter()
                                .map(|tc| {
                                    let mut tc_map = IndexMap::new();
                                    let id =
                                        tc.id.clone().unwrap_or_else(|| tc.function.name.clone());
                                    tc_map.insert("id".to_string(), Value::String(id));
                                    tc_map.insert(
                                        "type".to_string(),
                                        Value::String("function".to_string()),
                                    );
                                    let mut function_map = serde_json::Map::new();
                                    function_map.insert(
                                        "name".to_string(),
                                        Value::String(tc.function.name.clone()),
                                    );
                                    function_map.insert(
                                        "arguments".to_string(),
                                        Value::String(tc.function.arguments.clone()),
                                    );
                                    tc_map.insert(
                                        "function".to_string(),
                                        Value::Object(function_map),
                                    );
                                    tc_map
                                })
                                .collect();
                            message_map
                                .insert("tool_calls".to_string(), Either::Right(tool_calls_vec));
                        }

                        if let Some(ref tool_call_id) = message.tool_call_id {
                            message_map.insert(
                                "tool_call_id".to_string(),
                                Either::Left(tool_call_id.clone()),
                            );
                        }

                        if let Some(ref name) = message.name {
                            message_map.insert("name".to_string(), Either::Left(name.clone()));
                        }

                        messages.push(message_map);
                    }
                    Either::Right(image_messages) => {
                        if image_messages.len() == 1 {
                            if !image_messages[0].contains_key("text") {
                                anyhow::bail!("Expected `text` key in input message.");
                            }
                            let content = match image_messages[0]["text"].deref() {
                                Either::Left(left) => left.to_string(),
                                Either::Right(right) => format!("{right:?}"),
                            };
                            let mut message_map: IndexMap<
                                String,
                                Either<String, Vec<IndexMap<String, Value>>>,
                            > = IndexMap::new();
                            message_map.insert("role".to_string(), Either::Left(message.role));
                            message_map.insert("content".to_string(), Either::Left(content));
                            messages.push(message_map);
                            continue;
                        }
                        if message.role != "user" {
                            anyhow::bail!(
                                "Role for an image message must be `user`, but it is {}",
                                message.role
                            );
                        }

                        enum ContentPart {
                            Text { text: String },
                            Image { image_url: String },
                            Audio { audio_url: String },
                        }

                        let mut items = Vec::new();
                        for image_message in image_messages {
                            match image_message.get("type") {
                                Some(MessageInnerContent(Either::Left(x))) if x == "text" => {
                                    items.push(ContentPart::Text {
                                        text: image_message
                                            .get("text")
                                            .as_ref()
                                            .context("Text sub-content must have `text` key.")?
                                            .as_ref()
                                            .left()
                                            .context("Text sub-content `text` key must be a string.")?
                                            .clone(),
                                    });
                                }
                                Some(MessageInnerContent(Either::Left(x))) if x == "image_url" => {
                                    items.push(ContentPart::Image {
                                        image_url: image_message
                                            .get("image_url")
                                            .as_ref()
                                            .context("Image sub-content must have `image_url` key.")?
                                            .as_ref()
                                            .right()
                                            .context("Image sub-content `image_url` key must be an object.")?
                                            .get("url")
                                            .context("Image sub-content `image_url` object must have a `url` key.")?
                                            .clone(),
                                    });
                                }
                                Some(MessageInnerContent(Either::Left(x))) if x == "audio_url" => {
                                    items.push(ContentPart::Audio {
                                        audio_url: image_message
                                            .get("audio_url")
                                            .as_ref()
                                            .context("Audio sub-content must have `audio_url` key.")?
                                            .as_ref()
                                            .right()
                                            .context("Audio sub-content `audio_url` key must be an object.")?
                                            .get("url")
                                            .context("Audio sub-content `audio_url` object must have a `url` key.")?
                                            .clone(),
                                    });
                                }
                                _ => anyhow::bail!("Expected array content sub-content to be of format {{`type`: `text`, `text`: ...}} and {{`type`: `url`, `image_url`: {{`url`: ...}}}}")
                            }
                        }

                        let text_content = items
                            .iter()
                            .filter_map(|item| match item {
                                ContentPart::Text { text } => Some(text),
                                _ => None,
                            })
                            .join(" ");
                        let image_urls_iter = items
                            .iter()
                            .filter_map(|item| match item {
                                ContentPart::Image { image_url } => Some(image_url.clone()),
                                _ => None,
                            })
                            .collect::<Vec<_>>();

                        let audio_urls_iter = items
                            .iter()
                            .filter_map(|item| match item {
                                ContentPart::Audio { audio_url } => Some(audio_url.clone()),
                                _ => None,
                            })
                            .collect::<Vec<_>>();

                        let mut message_map: IndexMap<
                            String,
                            Either<String, Vec<IndexMap<String, Value>>>,
                        > = IndexMap::new();
                        message_map.insert("role".to_string(), Either::Left(message.role));

                        let mut content_map: Vec<IndexMap<String, Value>> = Vec::new();
                        for _ in &image_urls_iter {
                            let mut content_image_map = IndexMap::new();
                            content_image_map
                                .insert("type".to_string(), Value::String("image".to_string()));
                            content_map.push(content_image_map);
                        }
                        for _ in &audio_urls_iter {
                            let mut content_audio_map = IndexMap::new();
                            content_audio_map
                                .insert("type".to_string(), Value::String("audio".to_string()));
                            content_map.push(content_audio_map);
                        }
                        {
                            let mut content_text_map = IndexMap::new();
                            content_text_map
                                .insert("type".to_string(), Value::String("text".to_string()));
                            content_text_map
                                .insert("text".to_string(), Value::String(text_content));
                            content_map.push(content_text_map);
                        }

                        message_map.insert("content".to_string(), Either::Right(content_map));
                        messages.push(message_map);
                        image_urls.extend(image_urls_iter);
                        audio_urls.extend(audio_urls_iter);
                    }
                }
            }
            if !image_urls.is_empty() || !audio_urls.is_empty() {
                let mut images = Vec::new();
                for url_unparsed in image_urls {
                    let image = parse_image_url(&url_unparsed)
                        .await
                        .context(format!("Failed to parse image resource: {url_unparsed}"))?;
                    images.push(image);
                }

                let mut audios = Vec::new();
                for url_unparsed in audio_urls {
                    let audio = parse_audio_url(&url_unparsed)
                        .await
                        .context(format!("Failed to parse audio resource: {url_unparsed}"))?;
                    audios.push(audio);
                }

                RequestMessage::VisionChat {
                    messages,
                    images,
                    audios,
                    enable_thinking: oairequest.enable_thinking,
                    reasoning_effort,
                }
            } else {
                RequestMessage::Chat {
                    messages,
                    enable_thinking: oairequest.enable_thinking,
                    reasoning_effort,
                }
            }
        }
        Either::Right(prompt) => {
            let mut messages = Vec::new();
            let mut message_map: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
                IndexMap::new();
            message_map.insert("role".to_string(), Either::Left("user".to_string()));
            message_map.insert("content".to_string(), Either::Left(prompt));
            messages.push(message_map);
            RequestMessage::Chat {
                messages,
                enable_thinking: oairequest.enable_thinking,
                reasoning_effort,
            }
        }
    };

    let dry_params = get_dry_sampling_params(
        oairequest.dry_multiplier,
        oairequest.dry_sequence_breakers,
        oairequest.dry_base,
        oairequest.dry_allowed_length,
    )?;

    let is_streaming = oairequest.stream.unwrap_or(false);

    if oairequest.grammar.is_some() && oairequest.response_format.is_some() {
        anyhow::bail!("Request `grammar` and `response_format` were both provided but are mutually exclusive.")
    }

    let constraint = match oairequest.grammar {
        Some(Grammar::Regex(regex)) => Constraint::Regex(regex),
        Some(Grammar::Lark(lark)) => Constraint::Lark(lark),
        Some(Grammar::JsonSchema(schema)) => Constraint::JsonSchema(schema),
        Some(Grammar::Llguidance(llguidance)) => Constraint::Llguidance(llguidance),
        None => match oairequest.response_format {
            Some(ResponseFormat::JsonSchema {
                json_schema: JsonSchemaResponseFormat { name: _, schema },
            }) => Constraint::JsonSchema(schema),
            Some(ResponseFormat::Text) => Constraint::None,
            None => Constraint::None,
        },
    };

    Ok((
        Request::Normal(Box::new(NormalRequest {
            id: state.next_request_id(),
            messages,
            sampling_params: SamplingParams {
                temperature: oairequest.temperature,
                top_k: oairequest.top_k,
                top_p: oairequest.top_p,
                min_p: oairequest.min_p,
                top_n_logprobs: oairequest.top_logprobs.unwrap_or(1),
                frequency_penalty: oairequest.frequency_penalty,
                presence_penalty: oairequest.presence_penalty,
                repetition_penalty: oairequest.repetition_penalty,
                max_len: oairequest.max_tokens,
                stop_toks,
                logits_bias: oairequest.logit_bias,
                n_choices: oairequest.n_choices,
                dry_params,
            },
            response: tx,
            return_logprobs: oairequest.logprobs,
            is_streaming,
            suffix: None,
            constraint,
            tool_choice: oairequest.tool_choice,
            tools: oairequest.tools,
            logits_processors: None,
            return_raw_logits: false,
            web_search_options: oairequest.web_search_options,
            model_id: if oairequest.model == "default" {
                None
            } else {
                Some(oairequest.model.clone())
            },
            truncate_sequence: oairequest.truncate_sequence.unwrap_or(false),
        })),
        is_streaming,
    ))
}

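/// OpenAI-compatible chat completions endpoint handler.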
#[utoipa::path(
    post,
    tag = "Mistral.rs",
    path = "/v1/chat/completions",
    request_body = ChatCompletionRequest,
    responses((status = 200, description = "Chat completions"))
)]
pub async fn chatcompletions(
    State(state): ExtractedMistralRsState,
    Json(oairequest): Json<ChatCompletionRequest>,
) -> ChatCompletionResponder {
    let (tx, mut rx) = create_response_channel(None);

    let model_id = if oairequest.model == "default" {
        None
    } else {
        Some(oairequest.model.clone())
    };

    let (request, is_streaming) = match parse_request(oairequest, state.clone(), tx).await {
        Ok(x) => x,
        Err(e) => return handle_error(state, e.into()),
    };

    if let Err(e) = send_request_with_model(&state, request, model_id.as_deref()).await {
        return handle_error(state, e.into());
    }

    if is_streaming {
        ChatCompletionResponder::Sse(create_streamer(rx, state, None, None))
    } else {
        process_non_streaming_response(&mut rx, state).await
    }
}

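/// Converts an error into a [`ChatCompletionResponder`] via the shared completion error handler.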
pub fn handle_error(
    state: SharedMistralRsState,
    e: Box<dyn std::error::Error + Send + Sync + 'static>,
) -> ChatCompletionResponder {
    handle_completion_error(state, e)
}

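/// Builds an SSE response around a [`ChatCompletionStreamer`], attaching the configured
/// keep-alive interval.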
pub fn create_streamer(
    rx: Receiver<Response>,
    state: SharedMistralRsState,
    on_chunk: Option<ChatCompletionOnChunkCallback>,
    on_done: Option<ChatCompletionOnDoneCallback>,
) -> Sse<KeepAliveStream<ChatCompletionStreamer>> {
    let streamer = base_create_streamer(rx, state, on_chunk, on_done);
    let keep_alive_interval = get_keep_alive_interval();

    Sse::new(streamer)
        .keep_alive(KeepAlive::new().interval(Duration::from_millis(keep_alive_interval)))
}

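/// Awaits a single response from the channel and converts it into a [`ChatCompletionResponder`].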
pub async fn process_non_streaming_response(
    rx: &mut Receiver<Response>,
    state: SharedMistralRsState,
) -> ChatCompletionResponder {
    base_process_non_streaming_response(rx, state, match_responses, handle_error).await
}

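/// Maps a non-streaming [`Response`] from the model to the appropriate responder variant,
/// logging errors and successful responses along the way.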
pub fn match_responses(state: SharedMistralRsState, response: Response) -> ChatCompletionResponder {
    match response {
        Response::InternalError(e) => {
            MistralRs::maybe_log_error(state, &*e);
            ChatCompletionResponder::InternalError(e)
        }
        Response::ModelError(msg, response) => {
            MistralRs::maybe_log_error(state.clone(), &ModelErrorMessage(msg.to_string()));
            MistralRs::maybe_log_response(state, &response);
            ChatCompletionResponder::ModelError(msg, response)
        }
        Response::ValidationError(e) => ChatCompletionResponder::ValidationError(e),
        Response::Done(response) => {
            MistralRs::maybe_log_response(state, &response);
            ChatCompletionResponder::Json(response)
        }
        Response::Chunk(_) => unreachable!(),
        Response::CompletionDone(_) => unreachable!(),
        Response::CompletionModelError(_, _) => unreachable!(),
        Response::CompletionChunk(_) => unreachable!(),
        Response::ImageGeneration(_) => unreachable!(),
        Response::Speech { .. } => unreachable!(),
        Response::Raw { .. } => unreachable!(),
        Response::Embeddings { .. } => unreachable!(),
    }
}