1use std::{ops::Deref, pin::Pin, task::Poll, time::Duration};
4
5use anyhow::{Context, Result};
6use axum::{
7 extract::{Json, State},
8 http::{self},
9 response::{
10 sse::{Event, KeepAlive},
11 IntoResponse, Sse,
12 },
13};
14use either::Either;
15use indexmap::IndexMap;
16use itertools::Itertools;
17use mistralrs_core::{
18 ChatCompletionChunkResponse, ChatCompletionResponse, Constraint, MistralRs, NormalRequest,
19 Request, RequestMessage, Response, SamplingParams,
20};
21use serde_json::Value;
22use tokio::sync::mpsc::{Receiver, Sender};
23
24use crate::{
25 completion_core::{
26 convert_stop_tokens, get_dry_sampling_params, handle_completion_error,
27 BaseCompletionResponder,
28 },
29 handler_core::{
30 base_process_non_streaming_response, create_response_channel, send_request_with_model,
31 BaseJsonModelError, ErrorToResponse, JsonError, ModelErrorMessage,
32 },
33 openai::{
34 ChatCompletionRequest, Grammar, JsonSchemaResponseFormat, MessageInnerContent,
35 ResponseFormat,
36 },
37 streaming::{base_create_streamer, get_keep_alive_interval, BaseStreamer, DoneState},
38 types::{ExtractedMistralRsState, OnChunkCallback, OnDoneCallback, SharedMistralRsState},
39 util::{parse_audio_url, parse_image_url, sanitize_error_message, validate_model_name},
40};
41
/// Callback invoked for each streamed chat completion chunk; may transform the chunk before it is sent.
pub type ChatCompletionOnChunkCallback = OnChunkCallback<ChatCompletionChunkResponse>;
60
/// Callback invoked once after the stream finishes, receiving all collected chunks.
pub type ChatCompletionOnDoneCallback = OnDoneCallback<ChatCompletionChunkResponse>;
77
/// SSE streamer specialized for chat completion chunks, with optional
/// per-chunk and on-done callbacks.
pub type ChatCompletionStreamer = BaseStreamer<
    ChatCompletionChunkResponse,
    ChatCompletionOnChunkCallback,
    ChatCompletionOnDoneCallback,
>;
87
impl futures::Stream for ChatCompletionStreamer {
    type Item = Result<Event, axum::Error>;

    /// Pulls responses off the channel and converts them into SSE events.
    ///
    /// Emits chunk events while `Running`, then a final `[DONE]` sentinel,
    /// then terminates the stream (invoking `on_done` once, if set).
    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        match self.done_state {
            DoneState::SendingDone => {
                // The last content chunk was already emitted; send the
                // OpenAI-style `[DONE]` sentinel and finish on the next poll.
                self.done_state = DoneState::Done;
                return Poll::Ready(Some(Ok(Event::default().data("[DONE]"))));
            }
            DoneState::Done => {
                if let Some(on_done) = &self.on_done {
                    on_done(&self.chunks);
                }
                return Poll::Ready(None);
            }
            DoneState::Running => (),
        }

        match self.rx.poll_recv(cx) {
            Poll::Ready(Some(resp)) => match resp {
                Response::ModelError(msg, _) => {
                    MistralRs::maybe_log_error(
                        self.state.clone(),
                        &ModelErrorMessage(msg.to_string()),
                    );
                    // Surface the error text as a data event, then wind down.
                    self.done_state = DoneState::SendingDone;
                    Poll::Ready(Some(Ok(Event::default().data(msg))))
                }
                Response::ValidationError(e) => Poll::Ready(Some(Ok(
                    Event::default().data(sanitize_error_message(e.as_ref()))
                ))),
                Response::InternalError(e) => {
                    MistralRs::maybe_log_error(self.state.clone(), &*e);
                    Poll::Ready(Some(Ok(
                        Event::default().data(sanitize_error_message(e.as_ref()))
                    )))
                }
                Response::Chunk(mut response) => {
                    // When every choice carries a finish reason this is the
                    // last content chunk; transition to `[DONE]` on next poll.
                    if response.choices.iter().all(|x| x.finish_reason.is_some()) {
                        self.done_state = DoneState::SendingDone;
                    }
                    MistralRs::maybe_log_response(self.state.clone(), &response);

                    // Callback may rewrite the chunk before serialization.
                    if let Some(on_chunk) = &self.on_chunk {
                        response = on_chunk(response);
                    }

                    if self.store_chunks {
                        self.chunks.push(response.clone());
                    }

                    Poll::Ready(Some(Event::default().json_data(response)))
                }
                // A streaming chat request can only produce the variants above.
                Response::Done(_) => unreachable!(),
                Response::CompletionDone(_) => unreachable!(),
                Response::CompletionModelError(_, _) => unreachable!(),
                Response::CompletionChunk(_) => unreachable!(),
                Response::ImageGeneration(_) => unreachable!(),
                Response::Speech { .. } => unreachable!(),
                Response::Raw { .. } => unreachable!(),
            },
            // NOTE(review): `Poll::Ready(None)` (sender dropped) also maps to
            // `Pending` here without re-registering a waker; presumably the
            // sender is never dropped before a terminal chunk — confirm.
            Poll::Pending | Poll::Ready(None) => Poll::Pending,
        }
    }
}
167
/// HTTP responder for chat completions: SSE stream, JSON body, or an error variant.
pub type ChatCompletionResponder =
    BaseCompletionResponder<ChatCompletionResponse, ChatCompletionStreamer>;
171
// JSON error payload that also carries the partial model response.
type JsonModelError = BaseJsonModelError<ChatCompletionResponse>;
impl ErrorToResponse for JsonModelError {}
174
175impl IntoResponse for ChatCompletionResponder {
176 fn into_response(self) -> axum::response::Response {
178 match self {
179 ChatCompletionResponder::Sse(s) => s.into_response(),
180 ChatCompletionResponder::Json(s) => Json(s).into_response(),
181 ChatCompletionResponder::InternalError(e) => {
182 JsonError::new(sanitize_error_message(e.as_ref()))
183 .to_response(http::StatusCode::INTERNAL_SERVER_ERROR)
184 }
185 ChatCompletionResponder::ValidationError(e) => {
186 JsonError::new(sanitize_error_message(e.as_ref()))
187 .to_response(http::StatusCode::UNPROCESSABLE_ENTITY)
188 }
189 ChatCompletionResponder::ModelError(msg, response) => {
190 JsonModelError::new(msg, response)
191 .to_response(http::StatusCode::INTERNAL_SERVER_ERROR)
192 }
193 }
194 }
195}
196
/// Converts an OpenAI-compatible [`ChatCompletionRequest`] into an internal
/// [`Request`], returning it together with whether the client requested a
/// streaming (SSE) response.
///
/// Handles three message forms: plain-text messages, array-of-parts messages
/// (text / `image_url` / `audio_url` sub-content), and a raw string prompt.
/// Referenced image/audio resources are fetched and decoded here so the core
/// receives the media alongside the chat messages.
///
/// # Errors
/// Fails on an unknown model name, malformed message content, unparsable
/// media URLs, invalid DRY sampling parameters, or when both `grammar` and
/// `response_format` are supplied.
pub async fn parse_request(
    oairequest: ChatCompletionRequest,
    state: SharedMistralRsState,
    tx: Sender<Response>,
) -> Result<(Request, bool)> {
    let repr = serde_json::to_string(&oairequest).expect("Serialization of request failed.");
    MistralRs::maybe_log_request(state.clone(), repr);

    // Reject requests that target a model this server does not host.
    validate_model_name(&oairequest.model, state.clone())?;

    let stop_toks = convert_stop_tokens(oairequest.stop_seqs);

    let messages = match oairequest.messages {
        // Structured `messages` array (the usual OpenAI form).
        Either::Left(req_messages) => {
            let mut messages = Vec::new();
            let mut image_urls = Vec::new();
            let mut audio_urls = Vec::new();
            for message in req_messages {
                let content = match message.content.as_deref() {
                    Some(content) => content.clone(),
                    None => {
                        // A message without `content` must carry tool calls
                        // (assistant tool-call turn); serialize those as text.
                        let calls = message
                            .tool_calls
                            .as_ref()
                            .context(
                                "No content was provided, expected tool calls to be provided.",
                            )?
                            .iter()
                            .map(|call| &call.function)
                            .collect::<Vec<_>>();

                        Either::Left(serde_json::to_string(&calls)?)
                    }
                };

                match &content {
                    // Plain text content: forward role/content as-is.
                    Either::Left(content) => {
                        let mut message_map: IndexMap<
                            String,
                            Either<String, Vec<IndexMap<String, Value>>>,
                        > = IndexMap::new();
                        message_map.insert("role".to_string(), Either::Left(message.role));
                        message_map.insert("content".to_string(), Either::Left(content.clone()));
                        messages.push(message_map);
                    }
                    // Array-of-parts content (may mix text, images, audio).
                    Either::Right(image_messages) => {
                        // A single-element parts array is collapsed to a plain
                        // text message; it must contain a `text` part.
                        if image_messages.len() == 1 {
                            if !image_messages[0].contains_key("text") {
                                anyhow::bail!("Expected `text` key in input message.");
                            }
                            let content = match image_messages[0]["text"].deref() {
                                Either::Left(left) => left.to_string(),
                                Either::Right(right) => format!("{right:?}"),
                            };
                            let mut message_map: IndexMap<
                                String,
                                Either<String, Vec<IndexMap<String, Value>>>,
                            > = IndexMap::new();
                            message_map.insert("role".to_string(), Either::Left(message.role));
                            message_map.insert("content".to_string(), Either::Left(content));
                            messages.push(message_map);
                            continue;
                        }
                        // Multi-part (multimodal) content is only accepted on
                        // user messages.
                        if message.role != "user" {
                            anyhow::bail!(
                                "Role for an image message must be `user`, but it is {}",
                                message.role
                            );
                        }

                        // Typed view over the supported sub-content kinds.
                        enum ContentPart {
                            Text { text: String },
                            Image { image_url: String },
                            Audio { audio_url: String },
                        }

                        let mut items = Vec::new();
                        for image_message in image_messages {
                            match image_message.get("type") {
                                Some(MessageInnerContent(Either::Left(x))) if x == "text" => {
                                    items.push(ContentPart::Text {
                                        text: image_message
                                            .get("text").as_ref()
                                            .context("Text sub-content must have `text` key.")?.as_ref()
                                            .left().context("Text sub-content `text` key must be a string.")?.clone(),
                                    });
                                }
                                Some(MessageInnerContent(Either::Left(x))) if x == "image_url" => {
                                    items.push(ContentPart::Image {
                                        image_url: image_message
                                            .get("image_url")
                                            .as_ref()
                                            .context("Image sub-content must have `image_url` key.")?
                                            .as_ref()
                                            .right()
                                            .context("Image sub-content `image_url` key must be an object.")?
                                            .get("url")
                                            .context("Image sub-content `image_url` object must have a `url` key.")?
                                            .clone(),
                                    });
                                }
                                Some(MessageInnerContent(Either::Left(x))) if x == "audio_url" => {
                                    items.push(ContentPart::Audio {
                                        audio_url: image_message
                                            .get("audio_url")
                                            .as_ref()
                                            .context("Audio sub-content must have `audio_url` key.")?
                                            .as_ref()
                                            .right()
                                            .context("Audio sub-content `audio_url` key must be an object.")?
                                            .get("url")
                                            .context("Audio sub-content `audio_url` object must have a `url` key.")?
                                            .clone(),
                                    });
                                }
                                _ => anyhow::bail!("Expected array content sub-content to be of format {{`type`: `text`, `text`: ...}} and {{`type`: `url`, `image_url`: {{`url`: ...}}}}")
                            }
                        }

                        // Merge all text parts into one string; gather media URLs.
                        let text_content = items
                            .iter()
                            .filter_map(|item| match item {
                                ContentPart::Text { text } => Some(text),
                                _ => None,
                            })
                            .join(" ");
                        let image_urls_iter = items
                            .iter()
                            .filter_map(|item| match item {
                                ContentPart::Image { image_url } => Some(image_url.clone()),
                                _ => None,
                            })
                            .collect::<Vec<_>>();

                        let audio_urls_iter = items
                            .iter()
                            .filter_map(|item| match item {
                                ContentPart::Audio { audio_url } => Some(audio_url.clone()),
                                _ => None,
                            })
                            .collect::<Vec<_>>();

                        let mut message_map: IndexMap<
                            String,
                            Either<String, Vec<IndexMap<String, Value>>>,
                        > = IndexMap::new();
                        message_map.insert("role".to_string(), Either::Left(message.role));

                        // Rebuild the content as typed parts: one `image` stub
                        // per image, one `audio` stub per clip, then the merged
                        // text. The decoded media travels separately in the
                        // `images`/`audios` vectors below.
                        let mut content_map: Vec<IndexMap<String, Value>> = Vec::new();
                        for _ in &image_urls_iter {
                            let mut content_image_map = IndexMap::new();
                            content_image_map
                                .insert("type".to_string(), Value::String("image".to_string()));
                            content_map.push(content_image_map);
                        }
                        for _ in &audio_urls_iter {
                            let mut content_audio_map = IndexMap::new();
                            content_audio_map
                                .insert("type".to_string(), Value::String("audio".to_string()));
                            content_map.push(content_audio_map);
                        }
                        {
                            let mut content_text_map = IndexMap::new();
                            content_text_map
                                .insert("type".to_string(), Value::String("text".to_string()));
                            content_text_map
                                .insert("text".to_string(), Value::String(text_content));
                            content_map.push(content_text_map);
                        }

                        message_map.insert("content".to_string(), Either::Right(content_map));
                        messages.push(message_map);
                        image_urls.extend(image_urls_iter);
                        audio_urls.extend(audio_urls_iter);
                    }
                }
            }
            // Any media URLs force the vision-chat request variant.
            if !image_urls.is_empty() || !audio_urls.is_empty() {
                let mut images = Vec::new();
                for url_unparsed in image_urls {
                    let image = parse_image_url(&url_unparsed)
                        .await
                        .context(format!("Failed to parse image resource: {url_unparsed}"))?;
                    images.push(image);
                }

                let mut audios = Vec::new();
                for url_unparsed in audio_urls {
                    let audio = parse_audio_url(&url_unparsed)
                        .await
                        .context(format!("Failed to parse audio resource: {url_unparsed}"))?;
                    audios.push(audio);
                }

                RequestMessage::VisionChat {
                    messages,
                    images,
                    audios,
                    enable_thinking: oairequest.enable_thinking,
                }
            } else {
                RequestMessage::Chat {
                    messages,
                    enable_thinking: oairequest.enable_thinking,
                }
            }
        }
        // Raw string prompt: wrap it as a single user message.
        Either::Right(prompt) => {
            let mut messages = Vec::new();
            let mut message_map: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
                IndexMap::new();
            message_map.insert("role".to_string(), Either::Left("user".to_string()));
            message_map.insert("content".to_string(), Either::Left(prompt));
            messages.push(message_map);
            RequestMessage::Chat {
                messages,
                enable_thinking: oairequest.enable_thinking,
            }
        }
    };

    let dry_params = get_dry_sampling_params(
        oairequest.dry_multiplier,
        oairequest.dry_sequence_breakers,
        oairequest.dry_base,
        oairequest.dry_allowed_length,
    )?;

    let is_streaming = oairequest.stream.unwrap_or(false);

    // Both `grammar` and `response_format` constrain output; allow only one.
    if oairequest.grammar.is_some() && oairequest.response_format.is_some() {
        anyhow::bail!("Request `grammar` and `response_format` were both provided but are mutually exclusive.")
    }

    // `grammar` takes precedence; otherwise fall back to `response_format`.
    let constraint = match oairequest.grammar {
        Some(Grammar::Regex(regex)) => Constraint::Regex(regex),
        Some(Grammar::Lark(lark)) => Constraint::Lark(lark),
        Some(Grammar::JsonSchema(schema)) => Constraint::JsonSchema(schema),
        Some(Grammar::Llguidance(llguidance)) => Constraint::Llguidance(llguidance),
        None => match oairequest.response_format {
            Some(ResponseFormat::JsonSchema {
                json_schema: JsonSchemaResponseFormat { name: _, schema },
            }) => Constraint::JsonSchema(schema),
            Some(ResponseFormat::Text) => Constraint::None,
            None => Constraint::None,
        },
    };

    Ok((
        Request::Normal(Box::new(NormalRequest {
            id: state.next_request_id(),
            messages,
            sampling_params: SamplingParams {
                temperature: oairequest.temperature,
                top_k: oairequest.top_k,
                top_p: oairequest.top_p,
                min_p: oairequest.min_p,
                top_n_logprobs: oairequest.top_logprobs.unwrap_or(1),
                frequency_penalty: oairequest.frequency_penalty,
                presence_penalty: oairequest.presence_penalty,
                max_len: oairequest.max_tokens,
                stop_toks,
                logits_bias: oairequest.logit_bias,
                n_choices: oairequest.n_choices,
                dry_params,
            },
            response: tx,
            return_logprobs: oairequest.logprobs,
            is_streaming,
            suffix: None,
            constraint,
            tool_choice: oairequest.tool_choice,
            tools: oairequest.tools,
            logits_processors: None,
            return_raw_logits: false,
            web_search_options: oairequest.web_search_options,
            // "default" selects the server's default model.
            model_id: if oairequest.model == "default" {
                None
            } else {
                Some(oairequest.model.clone())
            },
        })),
        is_streaming,
    ))
}
493
494#[utoipa::path(
496 post,
497 tag = "Mistral.rs",
498 path = "/v1/chat/completions",
499 request_body = ChatCompletionRequest,
500 responses((status = 200, description = "Chat completions"))
501)]
502pub async fn chatcompletions(
503 State(state): ExtractedMistralRsState,
504 Json(oairequest): Json<ChatCompletionRequest>,
505) -> ChatCompletionResponder {
506 let (tx, mut rx) = create_response_channel(None);
507
508 let model_id = if oairequest.model == "default" {
510 None
511 } else {
512 Some(oairequest.model.clone())
513 };
514
515 let (request, is_streaming) = match parse_request(oairequest, state.clone(), tx).await {
516 Ok(x) => x,
517 Err(e) => return handle_error(state, e.into()),
518 };
519
520 if let Err(e) = send_request_with_model(&state, request, model_id.as_deref()).await {
521 return handle_error(state, e.into());
522 }
523
524 if is_streaming {
525 ChatCompletionResponder::Sse(create_streamer(rx, state, None, None))
526 } else {
527 process_non_streaming_response(&mut rx, state).await
528 }
529}
530
/// Converts any boxed error into a [`ChatCompletionResponder`] via the shared
/// completion error handler (which also takes care of logging).
pub fn handle_error(
    state: SharedMistralRsState,
    e: Box<dyn std::error::Error + Send + Sync + 'static>,
) -> ChatCompletionResponder {
    handle_completion_error(state, e)
}
538
539pub fn create_streamer(
541 rx: Receiver<Response>,
542 state: SharedMistralRsState,
543 on_chunk: Option<ChatCompletionOnChunkCallback>,
544 on_done: Option<ChatCompletionOnDoneCallback>,
545) -> Sse<ChatCompletionStreamer> {
546 let streamer = base_create_streamer(rx, state, on_chunk, on_done);
547 let keep_alive_interval = get_keep_alive_interval();
548
549 Sse::new(streamer)
550 .keep_alive(KeepAlive::new().interval(Duration::from_millis(keep_alive_interval)))
551}
552
/// Awaits a single (non-streaming) engine response and converts it into a
/// [`ChatCompletionResponder`] using this module's matcher and error handler.
pub async fn process_non_streaming_response(
    rx: &mut Receiver<Response>,
    state: SharedMistralRsState,
) -> ChatCompletionResponder {
    base_process_non_streaming_response(rx, state, match_responses, handle_error).await
}
560
561pub fn match_responses(state: SharedMistralRsState, response: Response) -> ChatCompletionResponder {
563 match response {
564 Response::InternalError(e) => {
565 MistralRs::maybe_log_error(state, &*e);
566 ChatCompletionResponder::InternalError(e)
567 }
568 Response::ModelError(msg, response) => {
569 MistralRs::maybe_log_error(state.clone(), &ModelErrorMessage(msg.to_string()));
570 MistralRs::maybe_log_response(state, &response);
571 ChatCompletionResponder::ModelError(msg, response)
572 }
573 Response::ValidationError(e) => ChatCompletionResponder::ValidationError(e),
574 Response::Done(response) => {
575 MistralRs::maybe_log_response(state, &response);
576 ChatCompletionResponder::Json(response)
577 }
578 Response::Chunk(_) => unreachable!(),
579 Response::CompletionDone(_) => unreachable!(),
580 Response::CompletionModelError(_, _) => unreachable!(),
581 Response::CompletionChunk(_) => unreachable!(),
582 Response::ImageGeneration(_) => unreachable!(),
583 Response::Speech { .. } => unreachable!(),
584 Response::Raw { .. } => unreachable!(),
585 }
586}