use std::{pin::Pin, task::Poll, time::Duration};

use anyhow::Result;
use axum::{
    extract::{Json, Path, State},
    http::{self, StatusCode},
    response::{
        sse::{Event, KeepAlive, KeepAliveStream},
        IntoResponse, Sse,
    },
};
use either::Either;
use mistralrs_core::{ChatCompletionResponse, MistralRs, Request, Response};
use serde_json::Value;
use tokio::sync::mpsc::Sender;
use uuid::Uuid;

use crate::{
    cached_responses::get_response_cache,
    chat_completion::parse_request as parse_chat_request,
    completion_core::{handle_completion_error, BaseCompletionResponder},
    handler_core::{
        create_response_channel, send_request_with_model, BaseJsonModelError, ErrorToResponse,
        JsonError, ModelErrorMessage,
    },
    openai::{
        ChatCompletionRequest, Message, MessageContent, ResponsesChunk, ResponsesContent,
        ResponsesCreateRequest, ResponsesDelta, ResponsesDeltaContent, ResponsesDeltaOutput,
        ResponsesError, ResponsesObject, ResponsesOutput, ResponsesUsage,
    },
    streaming::{get_keep_alive_interval, BaseStreamer, DoneState},
    types::{ExtractedMistralRsState, OnChunkCallback, OnDoneCallback, SharedMistralRsState},
    util::sanitize_error_message,
};

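/// SSE streamer for the Responses API: a `BaseStreamer` specialized to emit `ResponsesChunk`s,
/// optionally buffering them for the response cache.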
pub type ResponsesStreamer =
    BaseStreamer<ResponsesChunk, OnChunkCallback<ResponsesChunk>, OnDoneCallback<ResponsesChunk>>;

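// Each engine `Response::Chunk` is mapped to a `ResponsesChunk` delta event; once every choice
// reports a finish reason, the next poll emits `[DONE]`, runs the optional `on_done` callback,
// and ends the stream. Errors are forwarded to the client as SSE data events.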
impl futures::Stream for ResponsesStreamer {
    type Item = Result<Event, axum::Error>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        match self.done_state {
            DoneState::SendingDone => {
                self.done_state = DoneState::Done;
                return Poll::Ready(Some(Ok(Event::default().data("[DONE]"))));
            }
            DoneState::Done => {
                if let Some(on_done) = &self.on_done {
                    on_done(&self.chunks);
                }
                return Poll::Ready(None);
            }
            DoneState::Running => (),
        }

        match self.rx.poll_recv(cx) {
            Poll::Ready(Some(resp)) => match resp {
                Response::ModelError(msg, _) => {
                    MistralRs::maybe_log_error(
                        self.state.clone(),
                        &ModelErrorMessage(msg.to_string()),
                    );
                    self.done_state = DoneState::SendingDone;
                    Poll::Ready(Some(Ok(Event::default().data(msg))))
                }
                Response::ValidationError(e) => Poll::Ready(Some(Ok(
                    Event::default().data(sanitize_error_message(e.as_ref()))
                ))),
                Response::InternalError(e) => {
                    MistralRs::maybe_log_error(self.state.clone(), &*e);
                    Poll::Ready(Some(Ok(
                        Event::default().data(sanitize_error_message(e.as_ref()))
                    )))
                }
                Response::Chunk(chat_chunk) => {
                    let mut delta_outputs = vec![];

                    let all_finished = chat_chunk.choices.iter().all(|c| c.finish_reason.is_some());

                    for choice in &chat_chunk.choices {
                        let mut delta_content_items = Vec::new();

                        if let Some(content) = &choice.delta.content {
                            delta_content_items.push(ResponsesDeltaContent {
                                content_type: "output_text".to_string(),
                                text: Some(content.clone()),
                            });
                        }

                        if let Some(tool_calls) = &choice.delta.tool_calls {
                            for tool_call in tool_calls {
                                let tool_text = format!(
                                    "Tool: {} args: {}",
                                    tool_call.function.name, tool_call.function.arguments
                                );
                                delta_content_items.push(ResponsesDeltaContent {
                                    content_type: "tool_use".to_string(),
                                    text: Some(tool_text),
                                });
                            }
                        }

                        if !delta_content_items.is_empty() {
                            delta_outputs.push(ResponsesDeltaOutput {
                                id: format!("msg_{}", Uuid::new_v4()),
                                output_type: "message".to_string(),
                                content: Some(delta_content_items),
                            });
                        }
                    }

                    let mut response_chunk = ResponsesChunk {
                        id: chat_chunk.id.clone(),
                        object: "response.chunk",
                        created_at: chat_chunk.created as f64,
                        model: chat_chunk.model.clone(),
                        chunk_type: "delta".to_string(),
                        delta: Some(ResponsesDelta {
                            output: if delta_outputs.is_empty() {
                                None
                            } else {
                                Some(delta_outputs)
                            },
                            status: if all_finished {
                                Some("completed".to_string())
                            } else {
                                None
                            },
                        }),
                        usage: None,
                        metadata: None,
                    };

                    if all_finished {
                        self.done_state = DoneState::SendingDone;
                    }

                    MistralRs::maybe_log_response(self.state.clone(), &chat_chunk);

                    if let Some(on_chunk) = &self.on_chunk {
                        response_chunk = on_chunk(response_chunk);
                    }

                    if self.store_chunks {
                        self.chunks.push(response_chunk.clone());
                    }

                    Poll::Ready(Some(Event::default().json_data(response_chunk)))
                }
                _ => unreachable!(),
            },
            Poll::Pending | Poll::Ready(None) => Poll::Pending,
        }
    }
}

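/// Responder for the `/v1/responses` endpoint: an SSE stream for streaming requests, or a
/// `ResponsesObject` (or error variant) for non-streaming ones.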
pub type ResponsesResponder =
    BaseCompletionResponder<ResponsesObject, KeepAliveStream<ResponsesStreamer>>;

type JsonModelError = BaseJsonModelError<ResponsesObject>;
impl ErrorToResponse for JsonModelError {}

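// Maps each responder variant to an HTTP response; internal and validation error messages are
// sanitized before being returned as JSON errors.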
impl IntoResponse for ResponsesResponder {
    fn into_response(self) -> axum::response::Response {
        match self {
            ResponsesResponder::Sse(s) => s.into_response(),
            ResponsesResponder::Json(s) => Json(s).into_response(),
            ResponsesResponder::InternalError(e) => {
                JsonError::new(sanitize_error_message(e.as_ref()))
                    .to_response(http::StatusCode::INTERNAL_SERVER_ERROR)
            }
            ResponsesResponder::ValidationError(e) => {
                JsonError::new(sanitize_error_message(e.as_ref()))
                    .to_response(http::StatusCode::UNPROCESSABLE_ENTITY)
            }
            ResponsesResponder::ModelError(msg, response) => JsonModelError::new(msg, response)
                .to_response(http::StatusCode::INTERNAL_SERVER_ERROR),
        }
    }
}

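/// Converts a completed `ChatCompletionResponse` into a Responses API object, flattening message
/// text and tool calls into `message` outputs and joining all text parts into `output_text`.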
fn chat_response_to_responses_object(
    chat_resp: &ChatCompletionResponse,
    request_id: String,
    metadata: Option<Value>,
) -> ResponsesObject {
    let mut outputs = Vec::new();
    let mut output_text_parts = Vec::new();

    for choice in &chat_resp.choices {
        let mut content_items = Vec::new();
        let mut has_content = false;

        if let Some(text) = &choice.message.content {
            output_text_parts.push(text.clone());
            content_items.push(ResponsesContent {
                content_type: "output_text".to_string(),
                text: Some(text.clone()),
                annotations: None,
            });
            has_content = true;
        }

        if let Some(tool_calls) = &choice.message.tool_calls {
            for tool_call in tool_calls {
                let tool_text = format!(
                    "Tool call: {} with args: {}",
                    tool_call.function.name, tool_call.function.arguments
                );
                content_items.push(ResponsesContent {
                    content_type: "tool_use".to_string(),
                    text: Some(tool_text),
                    annotations: None,
                });
                has_content = true;
            }
        }

        if has_content {
            outputs.push(ResponsesOutput {
                id: format!("msg_{}", Uuid::new_v4()),
                output_type: "message".to_string(),
                role: choice.message.role.clone(),
                status: None,
                content: content_items,
            });
        }
    }

    ResponsesObject {
        id: request_id,
        object: "response",
        created_at: chat_resp.created as f64,
        model: chat_resp.model.clone(),
        status: "completed".to_string(),
        output: outputs,
        output_text: if output_text_parts.is_empty() {
            None
        } else {
            Some(output_text_parts.join(" "))
        },
        usage: Some(ResponsesUsage {
            input_tokens: chat_resp.usage.prompt_tokens,
            output_tokens: chat_resp.usage.completion_tokens,
            total_tokens: chat_resp.usage.total_tokens,
            input_tokens_details: None,
            output_tokens_details: None,
        }),
        error: None,
        metadata,
        instructions: None,
        incomplete_details: None,
    }
}

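/// Translates a `ResponsesCreateRequest` into an internal chat `Request`, rejecting the
/// unsupported `instructions` field and prepending any cached conversation history referenced by
/// `previous_response_id`. Returns the request, whether it is streaming, and the full message
/// history used to build it.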
async fn parse_responses_request(
    oairequest: ResponsesCreateRequest,
    state: SharedMistralRsState,
    tx: Sender<Response>,
) -> Result<(Request, bool, Option<Vec<Message>>)> {
    if oairequest.instructions.is_some() {
        return Err(anyhow::anyhow!(
            "The 'instructions' field is not supported in the Responses API"
        ));
    }
    let previous_messages = if let Some(prev_id) = &oairequest.previous_response_id {
        let cache = get_response_cache();
        cache.get_conversation_history(prev_id)?
    } else {
        None
    };

    let messages = oairequest.input.into_either();

    let mut chat_request = ChatCompletionRequest {
        messages: messages.clone(),
        model: oairequest.model,
        logit_bias: oairequest.logit_bias,
        logprobs: oairequest.logprobs,
        top_logprobs: oairequest.top_logprobs,
        max_tokens: oairequest.max_tokens,
        n_choices: oairequest.n_choices,
        presence_penalty: oairequest.presence_penalty,
        frequency_penalty: oairequest.frequency_penalty,
        repetition_penalty: oairequest.repetition_penalty,
        stop_seqs: oairequest.stop_seqs,
        temperature: oairequest.temperature,
        top_p: oairequest.top_p,
        stream: oairequest.stream,
        tools: oairequest.tools,
        tool_choice: oairequest.tool_choice,
        response_format: oairequest.response_format,
        web_search_options: oairequest.web_search_options,
        top_k: oairequest.top_k,
        grammar: oairequest.grammar,
        min_p: oairequest.min_p,
        dry_multiplier: oairequest.dry_multiplier,
        dry_base: oairequest.dry_base,
        dry_allowed_length: oairequest.dry_allowed_length,
        dry_sequence_breakers: oairequest.dry_sequence_breakers,
        enable_thinking: oairequest.enable_thinking,
        truncate_sequence: oairequest.truncate_sequence,
        reasoning_effort: oairequest.reasoning_effort,
    };

    if let Some(prev_msgs) = previous_messages {
        match &mut chat_request.messages {
            Either::Left(msgs) => {
                let mut combined = prev_msgs;
                combined.extend(msgs.clone());
                chat_request.messages = Either::Left(combined);
            }
            Either::Right(_) => {
                let prompt = chat_request.messages.as_ref().right().unwrap().clone();
                let mut combined = prev_msgs;
                combined.push(Message {
                    content: Some(MessageContent::from_text(prompt)),
                    role: "user".to_string(),
                    name: None,
                    tool_calls: None,
                    tool_call_id: None,
                });
                chat_request.messages = Either::Left(combined);
            }
        }
    }

    let all_messages = match &chat_request.messages {
        Either::Left(msgs) => msgs.clone(),
        Either::Right(prompt) => vec![Message {
            content: Some(MessageContent::from_text(prompt.clone())),
            role: "user".to_string(),
            name: None,
            tool_calls: None,
            tool_call_id: None,
        }],
    };

    let (request, is_streaming) = parse_chat_request(chat_request, state, tx).await?;
    Ok((request, is_streaming, Some(all_messages)))
}

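/// Handler for `POST /v1/responses`. Dispatches the request to the engine and returns either an
/// SSE stream of `ResponsesChunk`s or a complete `ResponsesObject`, caching the result and
/// conversation history when `store` is enabled.
///
/// Illustrative request body (a sketch only; field names follow `ResponsesCreateRequest`):
///
/// ```json
/// { "model": "default", "input": "Hello!", "stream": false, "store": true }
/// ```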
#[utoipa::path(
    post,
    tag = "Mistral.rs",
    path = "/v1/responses",
    request_body = ResponsesCreateRequest,
    responses((status = 200, description = "Response created"))
)]
pub async fn create_response(
    State(state): ExtractedMistralRsState,
    Json(oairequest): Json<ResponsesCreateRequest>,
) -> ResponsesResponder {
    let (tx, mut rx) = create_response_channel(None);
    let request_id = format!("resp_{}", Uuid::new_v4());
    let metadata = oairequest.metadata.clone();
    let store = oairequest.store.unwrap_or(true);

    let model_id = if oairequest.model == "default" {
        None
    } else {
        Some(oairequest.model.clone())
    };

    let (request, is_streaming, conversation_history) =
        match parse_responses_request(oairequest, state.clone(), tx).await {
            Ok(x) => x,
            Err(e) => return handle_error(state, e.into()),
        };

    if let Err(e) = send_request_with_model(&state, request, model_id.as_deref()).await {
        return handle_error(state, e.into());
    }

    if is_streaming {
        let streamer = ResponsesStreamer {
            rx,
            done_state: DoneState::Running,
            state: state.clone(),
            on_chunk: None,
            on_done: None,
            chunks: Vec::new(),
            store_chunks: store,
        };

        if store {
            let cache = get_response_cache();
            let id = request_id.clone();
            let chunks_cache = cache.clone();

            let history_for_streaming = conversation_history.clone();
            let on_done: OnDoneCallback<ResponsesChunk> = Box::new(move |chunks| {
                let _ = chunks_cache.store_chunks(id.clone(), chunks.to_vec());

                if let Some(history) = history_for_streaming.clone() {
                    let mut history = history;
                    let mut assistant_message = String::new();

                    for chunk in chunks {
                        if let Some(delta) = &chunk.delta {
                            if let Some(outputs) = &delta.output {
                                for output in outputs {
                                    if let Some(contents) = &output.content {
                                        for content in contents {
                                            if let Some(text) = &content.text {
                                                assistant_message.push_str(text);
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }

                    if !assistant_message.is_empty() {
                        history.push(Message {
                            content: Some(MessageContent::from_text(assistant_message)),
                            role: "assistant".to_string(),
                            name: None,
                            tool_calls: None,
                            tool_call_id: None,
                        });
                    }

                    let _ = chunks_cache.store_conversation_history(id.clone(), history);
                }
            });

            ResponsesResponder::Sse(create_streamer(streamer, Some(on_done)))
        } else {
            ResponsesResponder::Sse(create_streamer(streamer, None))
        }
    } else {
        match rx.recv().await {
            Some(Response::Done(chat_resp)) => {
                let response_obj =
                    chat_response_to_responses_object(&chat_resp, request_id.clone(), metadata);

                if store {
                    let cache = get_response_cache();
                    let _ = cache.store_response(request_id.clone(), response_obj.clone());

                    if let Some(mut history) = conversation_history.clone() {
                        for choice in &chat_resp.choices {
                            if let Some(content) = &choice.message.content {
                                history.push(Message {
                                    content: Some(MessageContent::from_text(content.clone())),
                                    role: choice.message.role.clone(),
                                    name: None,
                                    tool_calls: None,
                                    tool_call_id: None,
                                });
                            }
                        }
                        let _ = cache.store_conversation_history(request_id, history);
                    }
                }

                ResponsesResponder::Json(response_obj)
            }
            Some(Response::ModelError(msg, partial_resp)) => {
                let mut response_obj =
                    chat_response_to_responses_object(&partial_resp, request_id.clone(), metadata);
                response_obj.error = Some(ResponsesError {
                    error_type: "model_error".to_string(),
                    message: msg.to_string(),
                });
                response_obj.status = "failed".to_string();

                if store {
                    let cache = get_response_cache();
                    let _ = cache.store_response(request_id.clone(), response_obj.clone());

                    if let Some(mut history) = conversation_history.clone() {
                        for choice in &partial_resp.choices {
                            if let Some(content) = &choice.message.content {
                                history.push(Message {
                                    content: Some(MessageContent::from_text(content.clone())),
                                    role: choice.message.role.clone(),
                                    name: None,
                                    tool_calls: None,
                                    tool_call_id: None,
                                });
                            }
                        }
                        let _ = cache.store_conversation_history(request_id, history);
                    }
                }
                ResponsesResponder::ModelError(msg.to_string(), response_obj)
            }
            Some(Response::ValidationError(e)) => ResponsesResponder::ValidationError(e),
            Some(Response::InternalError(e)) => ResponsesResponder::InternalError(e),
            _ => ResponsesResponder::InternalError(
                anyhow::anyhow!("Unexpected response type").into(),
            ),
        }
    }
}

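/// Handler for `GET /v1/responses/{response_id}`. Looks up a previously stored response in the
/// cache, returning 404 if it is unknown.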
#[utoipa::path(
    get,
    tag = "Mistral.rs",
    path = "/v1/responses/{response_id}",
    params(("response_id" = String, Path, description = "The ID of the response to retrieve")),
    responses((status = 200, description = "Response object"))
)]
pub async fn get_response(
    State(_state): ExtractedMistralRsState,
    Path(response_id): Path<String>,
) -> impl IntoResponse {
    let cache = get_response_cache();

    match cache.get_response(&response_id) {
        Ok(Some(response)) => (StatusCode::OK, Json(response)).into_response(),
        Ok(None) => JsonError::new(format!("Response with ID '{response_id}' not found"))
            .to_response(StatusCode::NOT_FOUND),
        Err(e) => JsonError::new(format!(
            "Error retrieving response: {}",
            sanitize_error_message(&*e)
        ))
        .to_response(StatusCode::INTERNAL_SERVER_ERROR),
    }
}

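/// Handler for `DELETE /v1/responses/{response_id}`. Removes a stored response from the cache and
/// reports whether anything was deleted.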
#[utoipa::path(
    delete,
    tag = "Mistral.rs",
    path = "/v1/responses/{response_id}",
    params(("response_id" = String, Path, description = "The ID of the response to delete")),
    responses((status = 200, description = "Response deleted"))
)]
pub async fn delete_response(
    State(_state): ExtractedMistralRsState,
    Path(response_id): Path<String>,
) -> impl IntoResponse {
    let cache = get_response_cache();

    match cache.delete_response(&response_id) {
        Ok(true) => (
            StatusCode::OK,
            Json(serde_json::json!({
                "deleted": true,
                "id": response_id,
                "object": "response.deleted"
            })),
        )
            .into_response(),
        Ok(false) => JsonError::new(format!("Response with ID '{response_id}' not found"))
            .to_response(StatusCode::NOT_FOUND),
        Err(e) => JsonError::new(format!(
            "Error deleting response: {}",
            sanitize_error_message(&*e)
        ))
        .to_response(StatusCode::INTERNAL_SERVER_ERROR),
    }
}

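/// Delegates request-level errors to the shared completion error handling path.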
fn handle_error(
    state: SharedMistralRsState,
    e: Box<dyn std::error::Error + Send + Sync + 'static>,
) -> ResponsesResponder {
    handle_completion_error(state, e)
}

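/// Builds the SSE response for a streamer, attaching the optional `on_done` callback and the
/// configured keep-alive interval.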
fn create_streamer(
    streamer: ResponsesStreamer,
    on_done: Option<OnDoneCallback<ResponsesChunk>>,
) -> Sse<KeepAliveStream<ResponsesStreamer>> {
    let keep_alive_interval = get_keep_alive_interval();

    let streamer_with_callback = ResponsesStreamer {
        on_done,
        ..streamer
    };

    Sse::new(streamer_with_callback)
        .keep_alive(KeepAlive::new().interval(Duration::from_millis(keep_alive_interval)))
}