use std::{ops::Deref, pin::Pin, task::Poll, time::Duration};

use anyhow::{Context, Result};
use axum::{
    extract::{Json, State},
    http::{self},
    response::{
        sse::{Event, KeepAlive},
        IntoResponse, Sse,
    },
};
use either::Either;
use indexmap::IndexMap;
use itertools::Itertools;
use mistralrs_core::{
    ChatCompletionChunkResponse, ChatCompletionResponse, Constraint, MistralRs, NormalRequest,
    Request, RequestMessage, Response, SamplingParams,
};
use serde_json::Value;
use tokio::sync::mpsc::{Receiver, Sender};

use crate::{
    completion_core::{
        convert_stop_tokens, get_dry_sampling_params, handle_completion_error,
        BaseCompletionResponder,
    },
    handler_core::{
        base_process_non_streaming_response, create_response_channel, send_request_with_model,
        BaseJsonModelError, ErrorToResponse, JsonError, ModelErrorMessage,
    },
    openai::{
        ChatCompletionRequest, Grammar, JsonSchemaResponseFormat, MessageInnerContent,
        ResponseFormat,
    },
    streaming::{base_create_streamer, get_keep_alive_interval, BaseStreamer, DoneState},
    types::{ExtractedMistralRsState, OnChunkCallback, OnDoneCallback, SharedMistralRsState},
    util::{parse_audio_url, parse_image_url, validate_model_name},
};

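/// Callback invoked on each streamed chunk before it is written to the SSE
/// response; it receives the chunk by value and returns the (possibly
/// modified) chunk that will actually be sent to the client.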
pub type ChatCompletionOnChunkCallback = OnChunkCallback<ChatCompletionChunkResponse>;

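/// Callback invoked once the stream has finished; it receives all chunks
/// collected while streaming (only populated when chunk storage is enabled).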
pub type ChatCompletionOnDoneCallback = OnDoneCallback<ChatCompletionChunkResponse>;

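/// SSE streamer for chat completions, specializing [`BaseStreamer`] with the
/// chat chunk type and the callbacks above.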
pub type ChatCompletionStreamer = BaseStreamer<
    ChatCompletionChunkResponse,
    ChatCompletionOnChunkCallback,
    ChatCompletionOnDoneCallback,
>;

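// The stream yields one SSE event per received chunk (errors are emitted as
// data events), sends a final `[DONE]` event once every choice reports a
// finish reason, and then terminates, invoking `on_done` if present.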
impl futures::Stream for ChatCompletionStreamer {
    type Item = Result<Event, axum::Error>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        match self.done_state {
            DoneState::SendingDone => {
                self.done_state = DoneState::Done;
                return Poll::Ready(Some(Ok(Event::default().data("[DONE]"))));
            }
            DoneState::Done => {
                if let Some(on_done) = &self.on_done {
                    on_done(&self.chunks);
                }
                return Poll::Ready(None);
            }
            DoneState::Running => (),
        }

        match self.rx.poll_recv(cx) {
            Poll::Ready(Some(resp)) => match resp {
                Response::ModelError(msg, _) => {
                    MistralRs::maybe_log_error(
                        self.state.clone(),
                        &ModelErrorMessage(msg.to_string()),
                    );
                    self.done_state = DoneState::SendingDone;
                    Poll::Ready(Some(Ok(Event::default().data(msg))))
                }
                Response::ValidationError(e) => {
                    Poll::Ready(Some(Ok(Event::default().data(e.to_string()))))
                }
                Response::InternalError(e) => {
                    MistralRs::maybe_log_error(self.state.clone(), &*e);
                    Poll::Ready(Some(Ok(Event::default().data(e.to_string()))))
                }
                Response::Chunk(mut response) => {
                    if response.choices.iter().all(|x| x.finish_reason.is_some()) {
                        self.done_state = DoneState::SendingDone;
                    }
                    MistralRs::maybe_log_response(self.state.clone(), &response);

                    if let Some(on_chunk) = &self.on_chunk {
                        response = on_chunk(response);
                    }

                    if self.store_chunks {
                        self.chunks.push(response.clone());
                    }

                    Poll::Ready(Some(Event::default().json_data(response)))
                }
                Response::Done(_) => unreachable!(),
                Response::CompletionDone(_) => unreachable!(),
                Response::CompletionModelError(_, _) => unreachable!(),
                Response::CompletionChunk(_) => unreachable!(),
                Response::ImageGeneration(_) => unreachable!(),
                Response::Speech { .. } => unreachable!(),
                Response::Raw { .. } => unreachable!(),
            },
            Poll::Pending | Poll::Ready(None) => Poll::Pending,
        }
    }
}

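/// Responder returned by the chat completions route: either an SSE stream, a
/// JSON body, or one of the error variants mapped to an HTTP status code.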
pub type ChatCompletionResponder =
    BaseCompletionResponder<ChatCompletionResponse, ChatCompletionStreamer>;

type JsonModelError = BaseJsonModelError<ChatCompletionResponse>;
impl ErrorToResponse for JsonModelError {}

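// Converts each responder variant into an HTTP response: SSE and JSON pass
// through, validation errors map to 422, and internal/model errors map to 500.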
impl IntoResponse for ChatCompletionResponder {
    fn into_response(self) -> axum::response::Response {
        match self {
            ChatCompletionResponder::Sse(s) => s.into_response(),
            ChatCompletionResponder::Json(s) => Json(s).into_response(),
            ChatCompletionResponder::InternalError(e) => {
                JsonError::new(e.to_string()).to_response(http::StatusCode::INTERNAL_SERVER_ERROR)
            }
            ChatCompletionResponder::ValidationError(e) => {
                JsonError::new(e.to_string()).to_response(http::StatusCode::UNPROCESSABLE_ENTITY)
            }
            ChatCompletionResponder::ModelError(msg, response) => {
                JsonModelError::new(msg, response)
                    .to_response(http::StatusCode::INTERNAL_SERVER_ERROR)
            }
        }
    }
}

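/// Converts an OpenAI-style [`ChatCompletionRequest`] into an internal
/// [`Request`], resolving text/image/audio message content, sampling
/// parameters, and any grammar or response-format constraint. Returns the
/// request together with a flag indicating whether the client asked for
/// streaming.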
pub async fn parse_request(
    oairequest: ChatCompletionRequest,
    state: SharedMistralRsState,
    tx: Sender<Response>,
) -> Result<(Request, bool)> {
    let repr = serde_json::to_string(&oairequest).expect("Serialization of request failed.");
    MistralRs::maybe_log_request(state.clone(), repr);

    validate_model_name(&oairequest.model, state.clone())?;

    let stop_toks = convert_stop_tokens(oairequest.stop_seqs);

    let messages = match oairequest.messages {
        Either::Left(req_messages) => {
            let mut messages = Vec::new();
            let mut image_urls = Vec::new();
            let mut audio_urls = Vec::new();
            for message in req_messages {
                let content = match message.content.as_deref() {
                    Some(content) => content.clone(),
                    None => {
                        let calls = message
                            .tool_calls
                            .as_ref()
                            .context(
                                "No content was provided, expected tool calls to be provided.",
                            )?
                            .iter()
                            .map(|call| &call.function)
                            .collect::<Vec<_>>();

                        Either::Left(serde_json::to_string(&calls)?)
                    }
                };

                match &content {
                    Either::Left(content) => {
                        let mut message_map: IndexMap<
                            String,
                            Either<String, Vec<IndexMap<String, Value>>>,
                        > = IndexMap::new();
                        message_map.insert("role".to_string(), Either::Left(message.role));
                        message_map.insert("content".to_string(), Either::Left(content.clone()));
                        messages.push(message_map);
                    }
                    Either::Right(image_messages) => {
                        if image_messages.len() == 1 {
                            if !image_messages[0].contains_key("text") {
                                anyhow::bail!("Expected `text` key in input message.");
                            }
                            let content = match image_messages[0]["text"].deref() {
                                Either::Left(left) => left.to_string(),
                                Either::Right(right) => format!("{right:?}"),
                            };
                            let mut message_map: IndexMap<
                                String,
                                Either<String, Vec<IndexMap<String, Value>>>,
                            > = IndexMap::new();
                            message_map.insert("role".to_string(), Either::Left(message.role));
                            message_map.insert("content".to_string(), Either::Left(content));
                            messages.push(message_map);
                            continue;
                        }
                        if message.role != "user" {
                            anyhow::bail!(
                                "Role for an image message must be `user`, but it is {}",
                                message.role
                            );
                        }

                        enum ContentPart {
                            Text { text: String },
                            Image { image_url: String },
                            Audio { audio_url: String },
                        }

                        let mut items = Vec::new();
                        for image_message in image_messages {
                            match image_message.get("type") {
                                Some(MessageInnerContent(Either::Left(x))) if x == "text" => {
                                    items.push(ContentPart::Text {
                                        text: image_message
                                            .get("text")
                                            .as_ref()
                                            .context("Text sub-content must have `text` key.")?
                                            .as_ref()
                                            .left()
                                            .context("Text sub-content `text` key must be a string.")?
                                            .clone(),
                                    });
                                }
                                Some(MessageInnerContent(Either::Left(x))) if x == "image_url" => {
                                    items.push(ContentPart::Image {
                                        image_url: image_message
                                            .get("image_url")
                                            .as_ref()
                                            .context("Image sub-content must have `image_url` key.")?
                                            .as_ref()
                                            .right()
                                            .context("Image sub-content `image_url` key must be an object.")?
                                            .get("url")
                                            .context("Image sub-content `image_url` object must have a `url` key.")?
                                            .clone(),
                                    });
                                }
                                Some(MessageInnerContent(Either::Left(x))) if x == "audio_url" => {
                                    items.push(ContentPart::Audio {
                                        audio_url: image_message
                                            .get("audio_url")
                                            .as_ref()
                                            .context("Audio sub-content must have `audio_url` key.")?
                                            .as_ref()
                                            .right()
                                            .context("Audio sub-content `audio_url` key must be an object.")?
                                            .get("url")
                                            .context("Audio sub-content `audio_url` object must have a `url` key.")?
                                            .clone(),
                                    });
                                }
                                _ => anyhow::bail!(
                                    "Expected array content sub-content to be of format {{`type`: `text`, `text`: ...}}, {{`type`: `image_url`, `image_url`: {{`url`: ...}}}}, or {{`type`: `audio_url`, `audio_url`: {{`url`: ...}}}}"
                                )
                            }
                        }

                        let text_content = items
                            .iter()
                            .filter_map(|item| match item {
                                ContentPart::Text { text } => Some(text),
                                _ => None,
                            })
                            .join(" ");
                        let image_urls_iter = items
                            .iter()
                            .filter_map(|item| match item {
                                ContentPart::Image { image_url } => Some(image_url.clone()),
                                _ => None,
                            })
                            .collect::<Vec<_>>();

                        let audio_urls_iter = items
                            .iter()
                            .filter_map(|item| match item {
                                ContentPart::Audio { audio_url } => Some(audio_url.clone()),
                                _ => None,
                            })
                            .collect::<Vec<_>>();

                        let mut message_map: IndexMap<
                            String,
                            Either<String, Vec<IndexMap<String, Value>>>,
                        > = IndexMap::new();
                        message_map.insert("role".to_string(), Either::Left(message.role));

                        let mut content_map: Vec<IndexMap<String, Value>> = Vec::new();
                        for _ in &image_urls_iter {
                            let mut content_image_map = IndexMap::new();
                            content_image_map
                                .insert("type".to_string(), Value::String("image".to_string()));
                            content_map.push(content_image_map);
                        }
                        for _ in &audio_urls_iter {
                            let mut content_audio_map = IndexMap::new();
                            content_audio_map
                                .insert("type".to_string(), Value::String("audio".to_string()));
                            content_map.push(content_audio_map);
                        }
                        {
                            let mut content_text_map = IndexMap::new();
                            content_text_map
                                .insert("type".to_string(), Value::String("text".to_string()));
                            content_text_map
                                .insert("text".to_string(), Value::String(text_content));
                            content_map.push(content_text_map);
                        }

                        message_map.insert("content".to_string(), Either::Right(content_map));
                        messages.push(message_map);
                        image_urls.extend(image_urls_iter);
                        audio_urls.extend(audio_urls_iter);
                    }
                }
            }
            if !image_urls.is_empty() || !audio_urls.is_empty() {
                let mut images = Vec::new();
                for url_unparsed in image_urls {
                    let image = parse_image_url(&url_unparsed)
                        .await
                        .context(format!("Failed to parse image resource: {url_unparsed}"))?;
                    images.push(image);
                }

                let mut audios = Vec::new();
                for url_unparsed in audio_urls {
                    let audio = parse_audio_url(&url_unparsed)
                        .await
                        .context(format!("Failed to parse audio resource: {url_unparsed}"))?;
                    audios.push(audio);
                }

                RequestMessage::VisionChat {
                    messages,
                    images,
                    audios,
                    enable_thinking: oairequest.enable_thinking,
                }
            } else {
                RequestMessage::Chat {
                    messages,
                    enable_thinking: oairequest.enable_thinking,
                }
            }
        }
        Either::Right(prompt) => {
            let mut messages = Vec::new();
            let mut message_map: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
                IndexMap::new();
            message_map.insert("role".to_string(), Either::Left("user".to_string()));
            message_map.insert("content".to_string(), Either::Left(prompt));
            messages.push(message_map);
            RequestMessage::Chat {
                messages,
                enable_thinking: oairequest.enable_thinking,
            }
        }
    };

    let dry_params = get_dry_sampling_params(
        oairequest.dry_multiplier,
        oairequest.dry_sequence_breakers,
        oairequest.dry_base,
        oairequest.dry_allowed_length,
    )?;

    let is_streaming = oairequest.stream.unwrap_or(false);

    if oairequest.grammar.is_some() && oairequest.response_format.is_some() {
        anyhow::bail!("Request `grammar` and `response_format` were both provided but are mutually exclusive.")
    }

    let constraint = match oairequest.grammar {
        Some(Grammar::Regex(regex)) => Constraint::Regex(regex),
        Some(Grammar::Lark(lark)) => Constraint::Lark(lark),
        Some(Grammar::JsonSchema(schema)) => Constraint::JsonSchema(schema),
        Some(Grammar::Llguidance(llguidance)) => Constraint::Llguidance(llguidance),
        None => match oairequest.response_format {
            Some(ResponseFormat::JsonSchema {
                json_schema: JsonSchemaResponseFormat { name: _, schema },
            }) => Constraint::JsonSchema(schema),
            Some(ResponseFormat::Text) => Constraint::None,
            None => Constraint::None,
        },
    };

    Ok((
        Request::Normal(Box::new(NormalRequest {
            id: state.next_request_id(),
            messages,
            sampling_params: SamplingParams {
                temperature: oairequest.temperature,
                top_k: oairequest.top_k,
                top_p: oairequest.top_p,
                min_p: oairequest.min_p,
                top_n_logprobs: oairequest.top_logprobs.unwrap_or(1),
                frequency_penalty: oairequest.frequency_penalty,
                presence_penalty: oairequest.presence_penalty,
                max_len: oairequest.max_tokens,
                stop_toks,
                logits_bias: oairequest.logit_bias,
                n_choices: oairequest.n_choices,
                dry_params,
            },
            response: tx,
            return_logprobs: oairequest.logprobs,
            is_streaming,
            suffix: None,
            constraint,
            tool_choice: oairequest.tool_choice,
            tools: oairequest.tools,
            logits_processors: None,
            return_raw_logits: false,
            web_search_options: oairequest.web_search_options,
            model_id: if oairequest.model == "default" {
                None
            } else {
                Some(oairequest.model.clone())
            },
        })),
        is_streaming,
    ))
}

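/// OpenAI-compatible chat completions route handler. Parses the incoming
/// request, forwards it to the targeted model, and responds either with an
/// SSE stream or a single JSON body depending on the request's `stream` flag.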
#[utoipa::path(
    post,
    tag = "Mistral.rs",
    path = "/v1/chat/completions",
    request_body = ChatCompletionRequest,
    responses((status = 200, description = "Chat completions"))
)]
pub async fn chatcompletions(
    State(state): ExtractedMistralRsState,
    Json(oairequest): Json<ChatCompletionRequest>,
) -> ChatCompletionResponder {
    let (tx, mut rx) = create_response_channel(None);

    let model_id = if oairequest.model == "default" {
        None
    } else {
        Some(oairequest.model.clone())
    };

    let (request, is_streaming) = match parse_request(oairequest, state.clone(), tx).await {
        Ok(x) => x,
        Err(e) => return handle_error(state, e.into()),
    };

    if let Err(e) = send_request_with_model(&state, request, model_id.as_deref()).await {
        return handle_error(state, e.into());
    }

    if is_streaming {
        ChatCompletionResponder::Sse(create_streamer(rx, state, None, None))
    } else {
        process_non_streaming_response(&mut rx, state).await
    }
}

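/// Maps an error raised while parsing or dispatching a request onto the shared
/// completion error handling, producing the appropriate error responder.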
pub fn handle_error(
    state: SharedMistralRsState,
    e: Box<dyn std::error::Error + Send + Sync + 'static>,
) -> ChatCompletionResponder {
    handle_completion_error(state, e)
}

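/// Wraps the response receiver in an SSE stream with a keep-alive interval.
/// Optional `on_chunk`/`on_done` callbacks can observe or modify chunks as
/// they are emitted; the route handler above passes `None` for both.
///
/// ```ignore
/// // Minimal sketch: stream without callbacks, mirroring the call in
/// // `chatcompletions`; `rx` and `state` come from the surrounding handler.
/// let sse = create_streamer(rx, state, None, None);
/// ```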
pub fn create_streamer(
    rx: Receiver<Response>,
    state: SharedMistralRsState,
    on_chunk: Option<ChatCompletionOnChunkCallback>,
    on_done: Option<ChatCompletionOnDoneCallback>,
) -> Sse<ChatCompletionStreamer> {
    let streamer = base_create_streamer(rx, state, on_chunk, on_done);
    let keep_alive_interval = get_keep_alive_interval();

    Sse::new(streamer)
        .keep_alive(KeepAlive::new().interval(Duration::from_millis(keep_alive_interval)))
}

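/// Awaits a single response on the channel and converts it into a
/// [`ChatCompletionResponder`] via [`match_responses`].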
pub async fn process_non_streaming_response(
    rx: &mut Receiver<Response>,
    state: SharedMistralRsState,
) -> ChatCompletionResponder {
    base_process_non_streaming_response(rx, state, match_responses, handle_error).await
}

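/// Matches a core [`Response`] against the chat completion responder variants,
/// logging errors and completed responses. Streaming and non-chat variants are
/// unreachable for non-streaming chat requests.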
pub fn match_responses(state: SharedMistralRsState, response: Response) -> ChatCompletionResponder {
    match response {
        Response::InternalError(e) => {
            MistralRs::maybe_log_error(state, &*e);
            ChatCompletionResponder::InternalError(e)
        }
        Response::ModelError(msg, response) => {
            MistralRs::maybe_log_error(state.clone(), &ModelErrorMessage(msg.to_string()));
            MistralRs::maybe_log_response(state, &response);
            ChatCompletionResponder::ModelError(msg, response)
        }
        Response::ValidationError(e) => ChatCompletionResponder::ValidationError(e),
        Response::Done(response) => {
            MistralRs::maybe_log_response(state, &response);
            ChatCompletionResponder::Json(response)
        }
        Response::Chunk(_) => unreachable!(),
        Response::CompletionDone(_) => unreachable!(),
        Response::CompletionModelError(_, _) => unreachable!(),
        Response::CompletionChunk(_) => unreachable!(),
        Response::ImageGeneration(_) => unreachable!(),
        Response::Speech { .. } => unreachable!(),
        Response::Raw { .. } => unreachable!(),
    }
}