use std::{pin::Pin, task::Poll, time::Duration};

use anyhow::Result;
use axum::{
    extract::{Json, Path, State},
    http::{self, StatusCode},
    response::{
        sse::{Event, KeepAlive},
        IntoResponse, Sse,
    },
};
use either::Either;
use mistralrs_core::{ChatCompletionResponse, MistralRs, Request, Response};
use serde_json::Value;
use tokio::sync::mpsc::Sender;
use uuid::Uuid;

use crate::{
    cached_responses::get_response_cache,
    chat_completion::parse_request as parse_chat_request,
    completion_core::{handle_completion_error, BaseCompletionResponder},
    handler_core::{
        create_response_channel, send_request_with_model, BaseJsonModelError, ErrorToResponse,
        JsonError, ModelErrorMessage,
    },
    openai::{
        ChatCompletionRequest, Message, MessageContent, ResponsesChunk, ResponsesContent,
        ResponsesCreateRequest, ResponsesDelta, ResponsesDeltaContent, ResponsesDeltaOutput,
        ResponsesError, ResponsesObject, ResponsesOutput, ResponsesUsage,
    },
    streaming::{get_keep_alive_interval, BaseStreamer, DoneState},
    types::{ExtractedMistralRsState, OnChunkCallback, OnDoneCallback, SharedMistralRsState},
    util::sanitize_error_message,
};

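/// SSE streamer for the Responses API: a `BaseStreamer` specialized to emit
/// `ResponsesChunk` payloads, with optional per-chunk and on-done callbacks.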
pub type ResponsesStreamer =
    BaseStreamer<ResponsesChunk, OnChunkCallback<ResponsesChunk>, OnDoneCallback<ResponsesChunk>>;

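/// Polls the engine's response channel and converts each incoming `Response`
/// into an SSE `Event`, closing the stream with a `[DONE]` marker.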
impl futures::Stream for ResponsesStreamer {
    type Item = Result<Event, axum::Error>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        match self.done_state {
            DoneState::SendingDone => {
                self.done_state = DoneState::Done;
                return Poll::Ready(Some(Ok(Event::default().data("[DONE]"))));
            }
            DoneState::Done => {
                if let Some(on_done) = &self.on_done {
                    on_done(&self.chunks);
                }
                return Poll::Ready(None);
            }
            DoneState::Running => (),
        }

        match self.rx.poll_recv(cx) {
            Poll::Ready(Some(resp)) => match resp {
                Response::ModelError(msg, _) => {
                    MistralRs::maybe_log_error(
                        self.state.clone(),
                        &ModelErrorMessage(msg.to_string()),
                    );
                    self.done_state = DoneState::SendingDone;
                    Poll::Ready(Some(Ok(Event::default().data(msg))))
                }
                Response::ValidationError(e) => Poll::Ready(Some(Ok(
                    Event::default().data(sanitize_error_message(e.as_ref()))
                ))),
                Response::InternalError(e) => {
                    MistralRs::maybe_log_error(self.state.clone(), &*e);
                    Poll::Ready(Some(Ok(
                        Event::default().data(sanitize_error_message(e.as_ref()))
                    )))
                }
                Response::Chunk(chat_chunk) => {
                    let mut delta_outputs = vec![];

                    // The stream is finished once every choice reports a finish reason.
                    let all_finished =
                        chat_chunk.choices.iter().all(|c| c.finish_reason.is_some());

                    for choice in &chat_chunk.choices {
                        let mut delta_content_items = Vec::new();

                        // Plain text deltas become `output_text` content items.
                        if let Some(content) = &choice.delta.content {
                            delta_content_items.push(ResponsesDeltaContent {
                                content_type: "output_text".to_string(),
                                text: Some(content.clone()),
                            });
                        }

                        // Tool call deltas are rendered as `tool_use` content items.
                        if let Some(tool_calls) = &choice.delta.tool_calls {
                            for tool_call in tool_calls {
                                let tool_text = format!(
                                    "Tool: {} args: {}",
                                    tool_call.function.name, tool_call.function.arguments
                                );
                                delta_content_items.push(ResponsesDeltaContent {
                                    content_type: "tool_use".to_string(),
                                    text: Some(tool_text),
                                });
                            }
                        }

                        if !delta_content_items.is_empty() {
                            delta_outputs.push(ResponsesDeltaOutput {
                                id: format!("msg_{}", Uuid::new_v4()),
                                output_type: "message".to_string(),
                                content: Some(delta_content_items),
                            });
                        }
                    }

                    let mut response_chunk = ResponsesChunk {
                        id: chat_chunk.id.clone(),
                        object: "response.chunk",
                        created_at: chat_chunk.created as f64,
                        model: chat_chunk.model.clone(),
                        chunk_type: "delta".to_string(),
                        delta: Some(ResponsesDelta {
                            output: if delta_outputs.is_empty() {
                                None
                            } else {
                                Some(delta_outputs)
                            },
                            status: if all_finished {
                                Some("completed".to_string())
                            } else {
                                None
                            },
                        }),
                        usage: None,
                        metadata: None,
                    };

                    if all_finished {
                        self.done_state = DoneState::SendingDone;
                    }

                    MistralRs::maybe_log_response(self.state.clone(), &chat_chunk);

                    if let Some(on_chunk) = &self.on_chunk {
                        response_chunk = on_chunk(response_chunk);
                    }

                    if self.store_chunks {
                        self.chunks.push(response_chunk.clone());
                    }

                    Poll::Ready(Some(Event::default().json_data(response_chunk)))
                }
                _ => unreachable!(),
            },
            Poll::Pending | Poll::Ready(None) => Poll::Pending,
        }
    }
}

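/// Responder for the Responses API endpoints: an SSE stream, a JSON `ResponsesObject`,
/// or one of the validation / internal / model error variants.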
pub type ResponsesResponder = BaseCompletionResponder<ResponsesObject, ResponsesStreamer>;

type JsonModelError = BaseJsonModelError<ResponsesObject>;
impl ErrorToResponse for JsonModelError {}

impl IntoResponse for ResponsesResponder {
    fn into_response(self) -> axum::response::Response {
        match self {
            ResponsesResponder::Sse(s) => s.into_response(),
            ResponsesResponder::Json(s) => Json(s).into_response(),
            ResponsesResponder::InternalError(e) => {
                JsonError::new(sanitize_error_message(e.as_ref()))
                    .to_response(http::StatusCode::INTERNAL_SERVER_ERROR)
            }
            ResponsesResponder::ValidationError(e) => {
                JsonError::new(sanitize_error_message(e.as_ref()))
                    .to_response(http::StatusCode::UNPROCESSABLE_ENTITY)
            }
            ResponsesResponder::ModelError(msg, response) => JsonModelError::new(msg, response)
                .to_response(http::StatusCode::INTERNAL_SERVER_ERROR),
        }
    }
}

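/// Converts a completed `ChatCompletionResponse` into a Responses API `ResponsesObject`,
/// mapping message text to `output_text` content, tool calls to `tool_use` content, and
/// copying token usage.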
fn chat_response_to_responses_object(
    chat_resp: &ChatCompletionResponse,
    request_id: String,
    metadata: Option<Value>,
) -> ResponsesObject {
    let mut outputs = Vec::new();
    let mut output_text_parts = Vec::new();

    for choice in &chat_resp.choices {
        let mut content_items = Vec::new();
        let mut has_content = false;

        // Assistant text becomes an `output_text` content item and also contributes
        // to the aggregated `output_text` field.
        if let Some(text) = &choice.message.content {
            output_text_parts.push(text.clone());
            content_items.push(ResponsesContent {
                content_type: "output_text".to_string(),
                text: Some(text.clone()),
                annotations: None,
            });
            has_content = true;
        }

        // Tool calls are rendered as `tool_use` content items.
        if let Some(tool_calls) = &choice.message.tool_calls {
            for tool_call in tool_calls {
                let tool_text = format!(
                    "Tool call: {} with args: {}",
                    tool_call.function.name, tool_call.function.arguments
                );
                content_items.push(ResponsesContent {
                    content_type: "tool_use".to_string(),
                    text: Some(tool_text),
                    annotations: None,
                });
                has_content = true;
            }
        }

        if has_content {
            outputs.push(ResponsesOutput {
                id: format!("msg_{}", Uuid::new_v4()),
                output_type: "message".to_string(),
                role: choice.message.role.clone(),
                status: None,
                content: content_items,
            });
        }
    }

    ResponsesObject {
        id: request_id,
        object: "response",
        created_at: chat_resp.created as f64,
        model: chat_resp.model.clone(),
        status: "completed".to_string(),
        output: outputs,
        output_text: if output_text_parts.is_empty() {
            None
        } else {
            Some(output_text_parts.join(" "))
        },
        usage: Some(ResponsesUsage {
            input_tokens: chat_resp.usage.prompt_tokens,
            output_tokens: chat_resp.usage.completion_tokens,
            total_tokens: chat_resp.usage.total_tokens,
            input_tokens_details: None,
            output_tokens_details: None,
        }),
        error: None,
        metadata,
        instructions: None,
        incomplete_details: None,
    }
}

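/// Builds an internal `Request` from a `ResponsesCreateRequest` by constructing an
/// equivalent `ChatCompletionRequest`, prepending cached conversation history when
/// `previous_response_id` is set. Returns the request, whether it is streaming, and
/// the full message list (used later to persist the conversation).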
async fn parse_responses_request(
    oairequest: ResponsesCreateRequest,
    state: SharedMistralRsState,
    tx: Sender<Response>,
) -> Result<(Request, bool, Option<Vec<Message>>)> {
    if oairequest.instructions.is_some() {
        return Err(anyhow::anyhow!(
            "The 'instructions' field is not supported in the Responses API"
        ));
    }

    // Pull conversation history from the cache when continuing a previous response.
    let previous_messages = if let Some(prev_id) = &oairequest.previous_response_id {
        let cache = get_response_cache();
        cache.get_conversation_history(prev_id)?
    } else {
        None
    };

    let messages = oairequest.input.into_either();

    // Map the Responses request onto an equivalent chat completion request.
    let mut chat_request = ChatCompletionRequest {
        messages: messages.clone(),
        model: oairequest.model,
        logit_bias: oairequest.logit_bias,
        logprobs: oairequest.logprobs,
        top_logprobs: oairequest.top_logprobs,
        max_tokens: oairequest.max_tokens,
        n_choices: oairequest.n_choices,
        presence_penalty: oairequest.presence_penalty,
        frequency_penalty: oairequest.frequency_penalty,
        repetition_penalty: oairequest.repetition_penalty,
        stop_seqs: oairequest.stop_seqs,
        temperature: oairequest.temperature,
        top_p: oairequest.top_p,
        stream: oairequest.stream,
        tools: oairequest.tools,
        tool_choice: oairequest.tool_choice,
        response_format: oairequest.response_format,
        web_search_options: oairequest.web_search_options,
        top_k: oairequest.top_k,
        grammar: oairequest.grammar,
        min_p: oairequest.min_p,
        dry_multiplier: oairequest.dry_multiplier,
        dry_base: oairequest.dry_base,
        dry_allowed_length: oairequest.dry_allowed_length,
        dry_sequence_breakers: oairequest.dry_sequence_breakers,
        enable_thinking: oairequest.enable_thinking,
    };

    // Prepend the cached history to the new input.
    if let Some(prev_msgs) = previous_messages {
        match &mut chat_request.messages {
            Either::Left(msgs) => {
                let mut combined = prev_msgs;
                combined.extend(msgs.clone());
                chat_request.messages = Either::Left(combined);
            }
            Either::Right(_) => {
                let prompt = chat_request.messages.as_ref().right().unwrap().clone();
                let mut combined = prev_msgs;
                combined.push(Message {
                    content: Some(MessageContent::from_text(prompt)),
                    role: "user".to_string(),
                    name: None,
                    tool_calls: None,
                });
                chat_request.messages = Either::Left(combined);
            }
        }
    }

    // Snapshot the full message list so the conversation can be cached after completion.
    let all_messages = match &chat_request.messages {
        Either::Left(msgs) => msgs.clone(),
        Either::Right(prompt) => vec![Message {
            content: Some(MessageContent::from_text(prompt.clone())),
            role: "user".to_string(),
            name: None,
            tool_calls: None,
        }],
    };

    let (request, is_streaming) = parse_chat_request(chat_request, state, tx).await?;
    Ok((request, is_streaming, Some(all_messages)))
}

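/// Handler for `POST /v1/responses`. Creates a model response, either streamed as SSE
/// chunks or returned as a complete `ResponsesObject`. Unless `store` is `false`, the
/// result and conversation history are cached so they can be fetched later or continued
/// via `previous_response_id`.
///
/// Illustrative request (host and port are placeholders; field names follow
/// `ResponsesCreateRequest`, and the exact shape of `input` depends on that type):
///
/// ```text
/// curl http://localhost:8080/v1/responses \
///   -H "Content-Type: application/json" \
///   -d '{"model": "default", "input": "Hello!", "stream": false}'
/// ```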
#[utoipa::path(
    post,
    tag = "Mistral.rs",
    path = "/v1/responses",
    request_body = ResponsesCreateRequest,
    responses((status = 200, description = "Response created"))
)]
pub async fn create_response(
    State(state): ExtractedMistralRsState,
    Json(oairequest): Json<ResponsesCreateRequest>,
) -> ResponsesResponder {
    let (tx, mut rx) = create_response_channel(None);
    let request_id = format!("resp_{}", Uuid::new_v4());
    let metadata = oairequest.metadata.clone();
    let store = oairequest.store.unwrap_or(true);

    // "default" means no explicit model override.
    let model_id = if oairequest.model == "default" {
        None
    } else {
        Some(oairequest.model.clone())
    };

    let (request, is_streaming, conversation_history) =
        match parse_responses_request(oairequest, state.clone(), tx).await {
            Ok(x) => x,
            Err(e) => return handle_error(state, e.into()),
        };

    if let Err(e) = send_request_with_model(&state, request, model_id.as_deref()).await {
        return handle_error(state, e.into());
    }

    if is_streaming {
        let streamer = ResponsesStreamer {
            rx,
            done_state: DoneState::Running,
            state: state.clone(),
            on_chunk: None,
            on_done: None,
            chunks: Vec::new(),
            store_chunks: store,
        };

        if store {
            let cache = get_response_cache();
            let id = request_id.clone();
            let chunks_cache = cache.clone();
            let history_for_streaming = conversation_history.clone();

            // Once the stream finishes, persist the chunks and rebuild the assistant
            // message from the streamed deltas so the conversation history can be cached.
            let on_done: OnDoneCallback<ResponsesChunk> = Box::new(move |chunks| {
                let _ = chunks_cache.store_chunks(id.clone(), chunks.to_vec());

                if let Some(history) = history_for_streaming.clone() {
                    let mut history = history;
                    let mut assistant_message = String::new();

                    for chunk in chunks {
                        if let Some(delta) = &chunk.delta {
                            if let Some(outputs) = &delta.output {
                                for output in outputs {
                                    if let Some(contents) = &output.content {
                                        for content in contents {
                                            if let Some(text) = &content.text {
                                                assistant_message.push_str(text);
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }

                    if !assistant_message.is_empty() {
                        history.push(Message {
                            content: Some(MessageContent::from_text(assistant_message)),
                            role: "assistant".to_string(),
                            name: None,
                            tool_calls: None,
                        });
                    }

                    let _ = chunks_cache.store_conversation_history(id.clone(), history);
                }
            });

            ResponsesResponder::Sse(create_streamer(streamer, Some(on_done)))
        } else {
            ResponsesResponder::Sse(create_streamer(streamer, None))
        }
    } else {
        match rx.recv().await {
            Some(Response::Done(chat_resp)) => {
                let response_obj =
                    chat_response_to_responses_object(&chat_resp, request_id.clone(), metadata);

                if store {
                    let cache = get_response_cache();
                    let _ = cache.store_response(request_id.clone(), response_obj.clone());

                    // Append the assistant's reply to the conversation history and cache it.
                    if let Some(mut history) = conversation_history.clone() {
                        for choice in &chat_resp.choices {
                            if let Some(content) = &choice.message.content {
                                history.push(Message {
                                    content: Some(MessageContent::from_text(content.clone())),
                                    role: choice.message.role.clone(),
                                    name: None,
                                    tool_calls: None,
                                });
                            }
                        }
                        let _ = cache.store_conversation_history(request_id, history);
                    }
                }

                ResponsesResponder::Json(response_obj)
            }
            Some(Response::ModelError(msg, partial_resp)) => {
                // Build a response object from the partial result and mark it as failed.
                let mut response_obj =
                    chat_response_to_responses_object(&partial_resp, request_id.clone(), metadata);
                response_obj.error = Some(ResponsesError {
                    error_type: "model_error".to_string(),
                    message: msg.to_string(),
                });
                response_obj.status = "failed".to_string();

                if store {
                    let cache = get_response_cache();
                    let _ = cache.store_response(request_id.clone(), response_obj.clone());

                    if let Some(mut history) = conversation_history.clone() {
                        for choice in &partial_resp.choices {
                            if let Some(content) = &choice.message.content {
                                history.push(Message {
                                    content: Some(MessageContent::from_text(content.clone())),
                                    role: choice.message.role.clone(),
                                    name: None,
                                    tool_calls: None,
                                });
                            }
                        }
                        let _ = cache.store_conversation_history(request_id, history);
                    }
                }
                ResponsesResponder::ModelError(msg.to_string(), response_obj)
            }
            Some(Response::ValidationError(e)) => ResponsesResponder::ValidationError(e),
            Some(Response::InternalError(e)) => ResponsesResponder::InternalError(e),
            _ => ResponsesResponder::InternalError(
                anyhow::anyhow!("Unexpected response type").into(),
            ),
        }
    }
}

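/// Handler for `GET /v1/responses/{response_id}`: returns a previously stored response
/// from the cache, or a 404 error if it is unknown.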
#[utoipa::path(
    get,
    tag = "Mistral.rs",
    path = "/v1/responses/{response_id}",
    params(("response_id" = String, Path, description = "The ID of the response to retrieve")),
    responses((status = 200, description = "Response object"))
)]
pub async fn get_response(
    State(_state): ExtractedMistralRsState,
    Path(response_id): Path<String>,
) -> impl IntoResponse {
    let cache = get_response_cache();

    match cache.get_response(&response_id) {
        Ok(Some(response)) => (StatusCode::OK, Json(response)).into_response(),
        Ok(None) => JsonError::new(format!("Response with ID '{response_id}' not found"))
            .to_response(StatusCode::NOT_FOUND),
        Err(e) => JsonError::new(format!(
            "Error retrieving response: {}",
            sanitize_error_message(&*e)
        ))
        .to_response(StatusCode::INTERNAL_SERVER_ERROR),
    }
}

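/// Handler for `DELETE /v1/responses/{response_id}`: removes a stored response from the
/// cache and reports whether anything was deleted.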
#[utoipa::path(
    delete,
    tag = "Mistral.rs",
    path = "/v1/responses/{response_id}",
    params(("response_id" = String, Path, description = "The ID of the response to delete")),
    responses((status = 200, description = "Response deleted"))
)]
pub async fn delete_response(
    State(_state): ExtractedMistralRsState,
    Path(response_id): Path<String>,
) -> impl IntoResponse {
    let cache = get_response_cache();

    match cache.delete_response(&response_id) {
        Ok(true) => (
            StatusCode::OK,
            Json(serde_json::json!({
                "deleted": true,
                "id": response_id,
                "object": "response.deleted"
            })),
        )
            .into_response(),
        Ok(false) => JsonError::new(format!("Response with ID '{response_id}' not found"))
            .to_response(StatusCode::NOT_FOUND),
        Err(e) => JsonError::new(format!(
            "Error deleting response: {}",
            sanitize_error_message(&*e)
        ))
        .to_response(StatusCode::INTERNAL_SERVER_ERROR),
    }
}

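/// Forwards an error to the shared completion error handler, producing a `ResponsesResponder`.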
fn handle_error(
    state: SharedMistralRsState,
    e: Box<dyn std::error::Error + Send + Sync + 'static>,
) -> ResponsesResponder {
    handle_completion_error(state, e)
}

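/// Attaches the optional on-done callback to the streamer and wraps it in an SSE response
/// with the configured keep-alive interval.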
fn create_streamer(
    streamer: ResponsesStreamer,
    on_done: Option<OnDoneCallback<ResponsesChunk>>,
) -> Sse<ResponsesStreamer> {
    let keep_alive_interval = get_keep_alive_interval();

    let streamer_with_callback = ResponsesStreamer {
        on_done,
        ..streamer
    };

    Sse::new(streamer_with_callback)
        .keep_alive(KeepAlive::new().interval(Duration::from_millis(keep_alive_interval)))
}