use std::{pin::Pin, task::Poll, time::Duration};

use anyhow::Result;
use axum::{
    extract::{Json, Path, State},
    http::{self, StatusCode},
    response::{
        sse::{Event, KeepAlive},
        IntoResponse, Sse,
    },
};
use either::Either;
use mistralrs_core::{ChatCompletionResponse, MistralRs, Request, Response};
use serde_json::Value;
use tokio::sync::mpsc::Sender;
use uuid::Uuid;

use crate::{
    cached_responses::get_response_cache,
    chat_completion::parse_request as parse_chat_request,
    completion_core::{handle_completion_error, BaseCompletionResponder},
    handler_core::{
        create_response_channel, send_request_with_model, BaseJsonModelError, ErrorToResponse,
        JsonError, ModelErrorMessage,
    },
    openai::{
        ChatCompletionRequest, Message, MessageContent, ResponsesChunk, ResponsesContent,
        ResponsesCreateRequest, ResponsesDelta, ResponsesDeltaContent, ResponsesDeltaOutput,
        ResponsesError, ResponsesObject, ResponsesOutput, ResponsesUsage,
    },
    streaming::{get_keep_alive_interval, BaseStreamer, DoneState},
    types::{ExtractedMistralRsState, OnChunkCallback, OnDoneCallback, SharedMistralRsState},
    util::sanitize_error_message,
};

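/// SSE streamer for the Responses API, built on the shared [`BaseStreamer`]
/// with `ResponsesChunk` payloads and optional per-chunk / on-done callbacks.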
pub type ResponsesStreamer =
    BaseStreamer<ResponsesChunk, OnChunkCallback<ResponsesChunk>, OnDoneCallback<ResponsesChunk>>;

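// Streams Responses API chunks as SSE events: each core `Response` is translated into a
// `ResponsesChunk`, and the stream terminates with a final `[DONE]` event.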
impl futures::Stream for ResponsesStreamer {
    type Item = Result<Event, axum::Error>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        match self.done_state {
            DoneState::SendingDone => {
                self.done_state = DoneState::Done;
                return Poll::Ready(Some(Ok(Event::default().data("[DONE]"))));
            }
            DoneState::Done => {
                if let Some(on_done) = &self.on_done {
                    on_done(&self.chunks);
                }
                return Poll::Ready(None);
            }
            DoneState::Running => (),
        }

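        // Translate the next message from the model's response channel into an SSE event.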
        match self.rx.poll_recv(cx) {
            Poll::Ready(Some(resp)) => match resp {
                Response::ModelError(msg, _) => {
                    MistralRs::maybe_log_error(
                        self.state.clone(),
                        &ModelErrorMessage(msg.to_string()),
                    );
                    self.done_state = DoneState::SendingDone;
                    Poll::Ready(Some(Ok(Event::default().data(msg))))
                }
                Response::ValidationError(e) => Poll::Ready(Some(Ok(
                    Event::default().data(sanitize_error_message(e.as_ref()))
                ))),
                Response::InternalError(e) => {
                    MistralRs::maybe_log_error(self.state.clone(), &*e);
                    Poll::Ready(Some(Ok(
                        Event::default().data(sanitize_error_message(e.as_ref()))
                    )))
                }
                Response::Chunk(chat_chunk) => {
                    let mut delta_outputs = vec![];

                    let all_finished =
                        chat_chunk.choices.iter().all(|c| c.finish_reason.is_some());

                    for choice in &chat_chunk.choices {
                        let mut delta_content_items = Vec::new();

                        if let Some(content) = &choice.delta.content {
                            delta_content_items.push(ResponsesDeltaContent {
                                content_type: "output_text".to_string(),
                                text: Some(content.clone()),
                            });
                        }

                        if let Some(tool_calls) = &choice.delta.tool_calls {
                            for tool_call in tool_calls {
                                let tool_text = format!(
                                    "Tool: {} args: {}",
                                    tool_call.function.name, tool_call.function.arguments
                                );
                                delta_content_items.push(ResponsesDeltaContent {
                                    content_type: "tool_use".to_string(),
                                    text: Some(tool_text),
                                });
                            }
                        }

                        if !delta_content_items.is_empty() {
                            delta_outputs.push(ResponsesDeltaOutput {
                                id: format!("msg_{}", Uuid::new_v4()),
                                output_type: "message".to_string(),
                                content: Some(delta_content_items),
                            });
                        }
                    }

                    let mut response_chunk = ResponsesChunk {
                        id: chat_chunk.id.clone(),
                        object: "response.chunk",
                        created_at: chat_chunk.created as f64,
                        model: chat_chunk.model.clone(),
                        chunk_type: "delta".to_string(),
                        delta: Some(ResponsesDelta {
                            output: if delta_outputs.is_empty() {
                                None
                            } else {
                                Some(delta_outputs)
                            },
                            status: if all_finished {
                                Some("completed".to_string())
                            } else {
                                None
                            },
                        }),
                        usage: None,
                        metadata: None,
                    };

                    if all_finished {
                        self.done_state = DoneState::SendingDone;
                    }

                    MistralRs::maybe_log_response(self.state.clone(), &chat_chunk);

                    if let Some(on_chunk) = &self.on_chunk {
                        response_chunk = on_chunk(response_chunk);
                    }

                    if self.store_chunks {
                        self.chunks.push(response_chunk.clone());
                    }

                    Poll::Ready(Some(Event::default().json_data(response_chunk)))
                }
                _ => unreachable!(),
            },
            Poll::Pending | Poll::Ready(None) => Poll::Pending,
        }
    }
}

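/// Responder returned by the `/v1/responses` handler: either an SSE stream,
/// a JSON `ResponsesObject`, or one of the error variants.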
pub type ResponsesResponder = BaseCompletionResponder<ResponsesObject, ResponsesStreamer>;

type JsonModelError = BaseJsonModelError<ResponsesObject>;
impl ErrorToResponse for JsonModelError {}

impl IntoResponse for ResponsesResponder {
    fn into_response(self) -> axum::response::Response {
        match self {
            ResponsesResponder::Sse(s) => s.into_response(),
            ResponsesResponder::Json(s) => Json(s).into_response(),
            ResponsesResponder::InternalError(e) => {
                JsonError::new(sanitize_error_message(e.as_ref()))
                    .to_response(http::StatusCode::INTERNAL_SERVER_ERROR)
            }
            ResponsesResponder::ValidationError(e) => {
                JsonError::new(sanitize_error_message(e.as_ref()))
                    .to_response(http::StatusCode::UNPROCESSABLE_ENTITY)
            }
            ResponsesResponder::ModelError(msg, response) => JsonModelError::new(msg, response)
                .to_response(http::StatusCode::INTERNAL_SERVER_ERROR),
        }
    }
}

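/// Converts a completed `ChatCompletionResponse` into a Responses API
/// `ResponsesObject`, mapping message content and tool calls to output items.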
fn chat_response_to_responses_object(
    chat_resp: &ChatCompletionResponse,
    request_id: String,
    metadata: Option<Value>,
) -> ResponsesObject {
    let mut outputs = Vec::new();
    let mut output_text_parts = Vec::new();

    for choice in &chat_resp.choices {
        let mut content_items = Vec::new();
        let mut has_content = false;

        if let Some(text) = &choice.message.content {
            output_text_parts.push(text.clone());
            content_items.push(ResponsesContent {
                content_type: "output_text".to_string(),
                text: Some(text.clone()),
                annotations: None,
            });
            has_content = true;
        }

        if let Some(tool_calls) = &choice.message.tool_calls {
            for tool_call in tool_calls {
                let tool_text = format!(
                    "Tool call: {} with args: {}",
                    tool_call.function.name, tool_call.function.arguments
                );
                content_items.push(ResponsesContent {
                    content_type: "tool_use".to_string(),
                    text: Some(tool_text),
                    annotations: None,
                });
                has_content = true;
            }
        }

        if has_content {
            outputs.push(ResponsesOutput {
                id: format!("msg_{}", Uuid::new_v4()),
                output_type: "message".to_string(),
                role: choice.message.role.clone(),
                status: None,
                content: content_items,
            });
        }
    }

    ResponsesObject {
        id: request_id,
        object: "response",
        created_at: chat_resp.created as f64,
        model: chat_resp.model.clone(),
        status: "completed".to_string(),
        output: outputs,
        output_text: if output_text_parts.is_empty() {
            None
        } else {
            Some(output_text_parts.join(" "))
        },
        usage: Some(ResponsesUsage {
            input_tokens: chat_resp.usage.prompt_tokens,
            output_tokens: chat_resp.usage.completion_tokens,
            total_tokens: chat_resp.usage.total_tokens,
            input_tokens_details: None,
            output_tokens_details: None,
        }),
        error: None,
        metadata,
        instructions: None,
        incomplete_details: None,
    }
}

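/// Translates a `ResponsesCreateRequest` into an internal chat `Request`,
/// merging in any stored conversation history referenced by
/// `previous_response_id`. Returns the request, whether it streams, and the
/// full message history used to build it.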
async fn parse_responses_request(
    oairequest: ResponsesCreateRequest,
    state: SharedMistralRsState,
    tx: Sender<Response>,
) -> Result<(Request, bool, Option<Vec<Message>>)> {
    if oairequest.instructions.is_some() {
        return Err(anyhow::anyhow!(
            "The 'instructions' field is not supported in the Responses API"
        ));
    }

    let previous_messages = if let Some(prev_id) = &oairequest.previous_response_id {
        let cache = get_response_cache();
        cache.get_conversation_history(prev_id)?
    } else {
        None
    };

    let messages = oairequest.input.into_either();

    let mut chat_request = ChatCompletionRequest {
        messages: messages.clone(),
        model: oairequest.model,
        logit_bias: oairequest.logit_bias,
        logprobs: oairequest.logprobs,
        top_logprobs: oairequest.top_logprobs,
        max_tokens: oairequest.max_tokens,
        n_choices: oairequest.n_choices,
        presence_penalty: oairequest.presence_penalty,
        frequency_penalty: oairequest.frequency_penalty,
        stop_seqs: oairequest.stop_seqs,
        temperature: oairequest.temperature,
        top_p: oairequest.top_p,
        stream: oairequest.stream,
        tools: oairequest.tools,
        tool_choice: oairequest.tool_choice,
        response_format: oairequest.response_format,
        web_search_options: oairequest.web_search_options,
        top_k: oairequest.top_k,
        grammar: oairequest.grammar,
        min_p: oairequest.min_p,
        dry_multiplier: oairequest.dry_multiplier,
        dry_base: oairequest.dry_base,
        dry_allowed_length: oairequest.dry_allowed_length,
        dry_sequence_breakers: oairequest.dry_sequence_breakers,
        enable_thinking: oairequest.enable_thinking,
    };

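    // If this request continues a previous response, prepend the stored
    // conversation history to the new messages.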
    if let Some(prev_msgs) = previous_messages {
        match &mut chat_request.messages {
            Either::Left(msgs) => {
                let mut combined = prev_msgs;
                combined.extend(msgs.clone());
                chat_request.messages = Either::Left(combined);
            }
            Either::Right(_) => {
                let prompt = chat_request.messages.as_ref().right().unwrap().clone();
                let mut combined = prev_msgs;
                combined.push(Message {
                    content: Some(MessageContent::from_text(prompt)),
                    role: "user".to_string(),
                    name: None,
                    tool_calls: None,
                });
                chat_request.messages = Either::Left(combined);
            }
        }
    }

    let all_messages = match &chat_request.messages {
        Either::Left(msgs) => msgs.clone(),
        Either::Right(prompt) => vec![Message {
            content: Some(MessageContent::from_text(prompt.clone())),
            role: "user".to_string(),
            name: None,
            tool_calls: None,
        }],
    };

    let (request, is_streaming) = parse_chat_request(chat_request, state, tx).await?;
    Ok((request, is_streaming, Some(all_messages)))
}

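/// Handler for `POST /v1/responses`: creates a response, optionally streaming
/// it as SSE, and optionally storing the result and conversation history in the
/// response cache for later retrieval via `previous_response_id`.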
#[utoipa::path(
    post,
    tag = "Mistral.rs",
    path = "/v1/responses",
    request_body = ResponsesCreateRequest,
    responses((status = 200, description = "Response created"))
)]
pub async fn create_response(
    State(state): ExtractedMistralRsState,
    Json(oairequest): Json<ResponsesCreateRequest>,
) -> ResponsesResponder {
    let (tx, mut rx) = create_response_channel(None);
    let request_id = format!("resp_{}", Uuid::new_v4());
    let metadata = oairequest.metadata.clone();
    let store = oairequest.store.unwrap_or(true);

    let model_id = if oairequest.model == "default" {
        None
    } else {
        Some(oairequest.model.clone())
    };

    let (request, is_streaming, conversation_history) =
        match parse_responses_request(oairequest, state.clone(), tx).await {
            Ok(x) => x,
            Err(e) => return handle_error(state, e.into()),
        };

    if let Err(e) = send_request_with_model(&state, request, model_id.as_deref()).await {
        return handle_error(state, e.into());
    }

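    // Streaming path: wrap the receiver in an SSE streamer; when `store` is set,
    // persist the emitted chunks and the assistant's reply once the stream finishes.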
    if is_streaming {
        let streamer = ResponsesStreamer {
            rx,
            done_state: DoneState::Running,
            state: state.clone(),
            on_chunk: None,
            on_done: None,
            chunks: Vec::new(),
            store_chunks: store,
        };

        if store {
            let cache = get_response_cache();
            let id = request_id.clone();
            let chunks_cache = cache.clone();

            let history_for_streaming = conversation_history.clone();
            let on_done: OnDoneCallback<ResponsesChunk> = Box::new(move |chunks| {
                let _ = chunks_cache.store_chunks(id.clone(), chunks.to_vec());

                if let Some(history) = history_for_streaming.clone() {
                    let mut history = history;
                    let mut assistant_message = String::new();

                    for chunk in chunks {
                        if let Some(delta) = &chunk.delta {
                            if let Some(outputs) = &delta.output {
                                for output in outputs {
                                    if let Some(contents) = &output.content {
                                        for content in contents {
                                            if let Some(text) = &content.text {
                                                assistant_message.push_str(text);
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }

                    if !assistant_message.is_empty() {
                        history.push(Message {
                            content: Some(MessageContent::from_text(assistant_message)),
                            role: "assistant".to_string(),
                            name: None,
                            tool_calls: None,
                        });
                    }

                    let _ = chunks_cache.store_conversation_history(id.clone(), history);
                }
            });

            ResponsesResponder::Sse(create_streamer(streamer, Some(on_done)))
        } else {
            ResponsesResponder::Sse(create_streamer(streamer, None))
        }
    } else {
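        // Non-streaming path: wait for the complete response, convert it to a
        // `ResponsesObject`, and optionally cache it along with the conversation.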
        match rx.recv().await {
            Some(Response::Done(chat_resp)) => {
                let response_obj =
                    chat_response_to_responses_object(&chat_resp, request_id.clone(), metadata);

                if store {
                    let cache = get_response_cache();
                    let _ = cache.store_response(request_id.clone(), response_obj.clone());

                    if let Some(mut history) = conversation_history.clone() {
                        for choice in &chat_resp.choices {
                            if let Some(content) = &choice.message.content {
                                history.push(Message {
                                    content: Some(MessageContent::from_text(content.clone())),
                                    role: choice.message.role.clone(),
                                    name: None,
                                    tool_calls: None,
                                });
                            }
                        }
                        let _ = cache.store_conversation_history(request_id, history);
                    }
                }

                ResponsesResponder::Json(response_obj)
            }
            Some(Response::ModelError(msg, partial_resp)) => {
                let mut response_obj =
                    chat_response_to_responses_object(&partial_resp, request_id.clone(), metadata);
                response_obj.error = Some(ResponsesError {
                    error_type: "model_error".to_string(),
                    message: msg.to_string(),
                });
                response_obj.status = "failed".to_string();

                if store {
                    let cache = get_response_cache();
                    let _ = cache.store_response(request_id.clone(), response_obj.clone());

                    if let Some(mut history) = conversation_history.clone() {
                        for choice in &partial_resp.choices {
                            if let Some(content) = &choice.message.content {
                                history.push(Message {
                                    content: Some(MessageContent::from_text(content.clone())),
                                    role: choice.message.role.clone(),
                                    name: None,
                                    tool_calls: None,
                                });
                            }
                        }
                        let _ = cache.store_conversation_history(request_id, history);
                    }
                }
                ResponsesResponder::ModelError(msg.to_string(), response_obj)
            }
            Some(Response::ValidationError(e)) => ResponsesResponder::ValidationError(e),
            Some(Response::InternalError(e)) => ResponsesResponder::InternalError(e),
            _ => ResponsesResponder::InternalError(
                anyhow::anyhow!("Unexpected response type").into(),
            ),
        }
    }
}

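/// Handler for `GET /v1/responses/{response_id}`: retrieves a previously
/// stored response object from the cache.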
#[utoipa::path(
    get,
    tag = "Mistral.rs",
    path = "/v1/responses/{response_id}",
    params(("response_id" = String, Path, description = "The ID of the response to retrieve")),
    responses((status = 200, description = "Response object"))
)]
pub async fn get_response(
    State(_state): ExtractedMistralRsState,
    Path(response_id): Path<String>,
) -> impl IntoResponse {
    let cache = get_response_cache();

    match cache.get_response(&response_id) {
        Ok(Some(response)) => (StatusCode::OK, Json(response)).into_response(),
        Ok(None) => JsonError::new(format!("Response with ID '{response_id}' not found"))
            .to_response(StatusCode::NOT_FOUND),
        Err(e) => JsonError::new(format!(
            "Error retrieving response: {}",
            sanitize_error_message(&*e)
        ))
        .to_response(StatusCode::INTERNAL_SERVER_ERROR),
    }
}

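/// Handler for `DELETE /v1/responses/{response_id}`: removes a stored response
/// from the cache.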
#[utoipa::path(
    delete,
    tag = "Mistral.rs",
    path = "/v1/responses/{response_id}",
    params(("response_id" = String, Path, description = "The ID of the response to delete")),
    responses((status = 200, description = "Response deleted"))
)]
pub async fn delete_response(
    State(_state): ExtractedMistralRsState,
    Path(response_id): Path<String>,
) -> impl IntoResponse {
    let cache = get_response_cache();

    match cache.delete_response(&response_id) {
        Ok(true) => (
            StatusCode::OK,
            Json(serde_json::json!({
                "deleted": true,
                "id": response_id,
                "object": "response.deleted"
            })),
        )
            .into_response(),
        Ok(false) => JsonError::new(format!("Response with ID '{response_id}' not found"))
            .to_response(StatusCode::NOT_FOUND),
        Err(e) => JsonError::new(format!(
            "Error deleting response: {}",
            sanitize_error_message(&*e)
        ))
        .to_response(StatusCode::INTERNAL_SERVER_ERROR),
    }
}

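/// Maps an error into a `ResponsesResponder` via the shared completion error
/// handler.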
fn handle_error(
    state: SharedMistralRsState,
    e: Box<dyn std::error::Error + Send + Sync + 'static>,
) -> ResponsesResponder {
    handle_completion_error(state, e)
}

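/// Wraps a `ResponsesStreamer` in an SSE response, attaching the optional
/// on-done callback and a keep-alive interval.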
fn create_streamer(
    streamer: ResponsesStreamer,
    on_done: Option<OnDoneCallback<ResponsesChunk>>,
) -> Sse<ResponsesStreamer> {
    let keep_alive_interval = get_keep_alive_interval();

    let streamer_with_callback = ResponsesStreamer {
        on_done,
        ..streamer
    };

    Sse::new(streamer_with_callback)
        .keep_alive(KeepAlive::new().interval(Duration::from_millis(keep_alive_interval)))
}