use std::{pin::Pin, task::Poll, time::Duration};

use anyhow::Result;
use axum::{
    extract::{Json, Path, State},
    http::{self, StatusCode},
    response::{
        sse::{Event, KeepAlive, KeepAliveStream},
        IntoResponse, Sse,
    },
};
use either::Either;
use mistralrs_core::{ChatCompletionResponse, MistralRs, Request, Response};
use serde_json::Value;
use tokio::sync::mpsc::Sender;
use uuid::Uuid;

use crate::{
    cached_responses::get_response_cache,
    chat_completion::parse_request as parse_chat_request,
    completion_core::{handle_completion_error, BaseCompletionResponder},
    handler_core::{
        create_response_channel, send_request_with_model, BaseJsonModelError, ErrorToResponse,
        JsonError, ModelErrorMessage,
    },
    openai::{
        ChatCompletionRequest, Message, MessageContent, ResponsesChunk, ResponsesContent,
        ResponsesCreateRequest, ResponsesDelta, ResponsesDeltaContent, ResponsesDeltaOutput,
        ResponsesError, ResponsesObject, ResponsesOutput, ResponsesUsage,
    },
    streaming::{get_keep_alive_interval, BaseStreamer, DoneState},
    types::{ExtractedMistralRsState, OnChunkCallback, OnDoneCallback, SharedMistralRsState},
    util::sanitize_error_message,
};

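/// SSE streamer for the Responses API, specialized from the shared [`BaseStreamer`].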
pub type ResponsesStreamer =
    BaseStreamer<ResponsesChunk, OnChunkCallback<ResponsesChunk>, OnDoneCallback<ResponsesChunk>>;

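// Adapts the engine's streaming responses into Responses-API SSE events,
// finishing with a `[DONE]` marker and an optional `on_done` callback.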
impl futures::Stream for ResponsesStreamer {
    type Item = Result<Event, axum::Error>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        match self.done_state {
            DoneState::SendingDone => {
                self.done_state = DoneState::Done;
                return Poll::Ready(Some(Ok(Event::default().data("[DONE]"))));
            }
            DoneState::Done => {
                if let Some(on_done) = &self.on_done {
                    on_done(&self.chunks);
                }
                return Poll::Ready(None);
            }
            DoneState::Running => (),
        }

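        // Poll the engine's response channel and translate the next message
        // into an SSE event.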
        match self.rx.poll_recv(cx) {
            Poll::Ready(Some(resp)) => match resp {
                Response::ModelError(msg, _) => {
                    MistralRs::maybe_log_error(
                        self.state.clone(),
                        &ModelErrorMessage(msg.to_string()),
                    );
                    self.done_state = DoneState::SendingDone;
                    Poll::Ready(Some(Ok(Event::default().data(msg))))
                }
                Response::ValidationError(e) => Poll::Ready(Some(Ok(
                    Event::default().data(sanitize_error_message(e.as_ref()))
                ))),
                Response::InternalError(e) => {
                    MistralRs::maybe_log_error(self.state.clone(), &*e);
                    Poll::Ready(Some(Ok(
                        Event::default().data(sanitize_error_message(e.as_ref()))
                    )))
                }
                Response::Chunk(chat_chunk) => {
                    let mut delta_outputs = vec![];

                    let all_finished = chat_chunk
                        .choices
                        .iter()
                        .all(|c| c.finish_reason.is_some());

                    for choice in &chat_chunk.choices {
                        let mut delta_content_items = Vec::new();

                        if let Some(content) = &choice.delta.content {
                            delta_content_items.push(ResponsesDeltaContent {
                                content_type: "output_text".to_string(),
                                text: Some(content.clone()),
                            });
                        }

                        if let Some(tool_calls) = &choice.delta.tool_calls {
                            for tool_call in tool_calls {
                                let tool_text = format!(
                                    "Tool: {} args: {}",
                                    tool_call.function.name, tool_call.function.arguments
                                );
                                delta_content_items.push(ResponsesDeltaContent {
                                    content_type: "tool_use".to_string(),
                                    text: Some(tool_text),
                                });
                            }
                        }

                        if !delta_content_items.is_empty() {
                            delta_outputs.push(ResponsesDeltaOutput {
                                id: format!("msg_{}", Uuid::new_v4()),
                                output_type: "message".to_string(),
                                content: Some(delta_content_items),
                            });
                        }
                    }

                    let mut response_chunk = ResponsesChunk {
                        id: chat_chunk.id.clone(),
                        object: "response.chunk",
                        created_at: chat_chunk.created as f64,
                        model: chat_chunk.model.clone(),
                        chunk_type: "delta".to_string(),
                        delta: Some(ResponsesDelta {
                            output: if delta_outputs.is_empty() {
                                None
                            } else {
                                Some(delta_outputs)
                            },
                            status: if all_finished {
                                Some("completed".to_string())
                            } else {
                                None
                            },
                        }),
                        usage: None,
                        metadata: None,
                    };

                    if all_finished {
                        self.done_state = DoneState::SendingDone;
                    }

                    MistralRs::maybe_log_response(self.state.clone(), &chat_chunk);

                    if let Some(on_chunk) = &self.on_chunk {
                        response_chunk = on_chunk(response_chunk);
                    }

                    if self.store_chunks {
                        self.chunks.push(response_chunk.clone());
                    }

                    Poll::Ready(Some(Event::default().json_data(response_chunk)))
                }
                _ => unreachable!(),
            },
            Poll::Pending | Poll::Ready(None) => Poll::Pending,
        }
    }
}

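/// Responder for `/v1/responses`: either an SSE stream or a JSON response object.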
pub type ResponsesResponder =
    BaseCompletionResponder<ResponsesObject, KeepAliveStream<ResponsesStreamer>>;

type JsonModelError = BaseJsonModelError<ResponsesObject>;
impl ErrorToResponse for JsonModelError {}

impl IntoResponse for ResponsesResponder {
    fn into_response(self) -> axum::response::Response {
        match self {
            ResponsesResponder::Sse(s) => s.into_response(),
            ResponsesResponder::Json(s) => Json(s).into_response(),
            ResponsesResponder::InternalError(e) => {
                JsonError::new(sanitize_error_message(e.as_ref()))
                    .to_response(http::StatusCode::INTERNAL_SERVER_ERROR)
            }
            ResponsesResponder::ValidationError(e) => {
                JsonError::new(sanitize_error_message(e.as_ref()))
                    .to_response(http::StatusCode::UNPROCESSABLE_ENTITY)
            }
            ResponsesResponder::ModelError(msg, response) => JsonModelError::new(msg, response)
                .to_response(http::StatusCode::INTERNAL_SERVER_ERROR),
        }
    }
}

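/// Converts a completed [`ChatCompletionResponse`] into a Responses-API [`ResponsesObject`].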
fn chat_response_to_responses_object(
    chat_resp: &ChatCompletionResponse,
    request_id: String,
    metadata: Option<Value>,
) -> ResponsesObject {
    let mut outputs = Vec::new();
    let mut output_text_parts = Vec::new();

    for choice in &chat_resp.choices {
        let mut content_items = Vec::new();
        let mut has_content = false;

        if let Some(text) = &choice.message.content {
            output_text_parts.push(text.clone());
            content_items.push(ResponsesContent {
                content_type: "output_text".to_string(),
                text: Some(text.clone()),
                annotations: None,
            });
            has_content = true;
        }

        if let Some(tool_calls) = &choice.message.tool_calls {
            for tool_call in tool_calls {
                let tool_text = format!(
                    "Tool call: {} with args: {}",
                    tool_call.function.name, tool_call.function.arguments
                );
                content_items.push(ResponsesContent {
                    content_type: "tool_use".to_string(),
                    text: Some(tool_text),
                    annotations: None,
                });
                has_content = true;
            }
        }

        if has_content {
            outputs.push(ResponsesOutput {
                id: format!("msg_{}", Uuid::new_v4()),
                output_type: "message".to_string(),
                role: choice.message.role.clone(),
                status: None,
                content: content_items,
            });
        }
    }

    ResponsesObject {
        id: request_id,
        object: "response",
        created_at: chat_resp.created as f64,
        model: chat_resp.model.clone(),
        status: "completed".to_string(),
        output: outputs,
        output_text: if output_text_parts.is_empty() {
            None
        } else {
            Some(output_text_parts.join(" "))
        },
        usage: Some(ResponsesUsage {
            input_tokens: chat_resp.usage.prompt_tokens,
            output_tokens: chat_resp.usage.completion_tokens,
            total_tokens: chat_resp.usage.total_tokens,
            input_tokens_details: None,
            output_tokens_details: None,
        }),
        error: None,
        metadata,
        instructions: None,
        incomplete_details: None,
    }
}

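/// Translates a [`ResponsesCreateRequest`] into an internal [`Request`], prepending any
/// conversation history referenced by `previous_response_id`, and returns the request,
/// whether it is streaming, and the full message list (used for caching).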
async fn parse_responses_request(
    oairequest: ResponsesCreateRequest,
    state: SharedMistralRsState,
    tx: Sender<Response>,
) -> Result<(Request, bool, Option<Vec<Message>>)> {
    if oairequest.instructions.is_some() {
        return Err(anyhow::anyhow!(
            "The 'instructions' field is not supported in the Responses API"
        ));
    }

    let previous_messages = if let Some(prev_id) = &oairequest.previous_response_id {
        let cache = get_response_cache();
        cache.get_conversation_history(prev_id)?
    } else {
        None
    };

    let messages = oairequest.input.into_either();

    let mut chat_request = ChatCompletionRequest {
        messages: messages.clone(),
        model: oairequest.model,
        logit_bias: oairequest.logit_bias,
        logprobs: oairequest.logprobs,
        top_logprobs: oairequest.top_logprobs,
        max_tokens: oairequest.max_tokens,
        n_choices: oairequest.n_choices,
        presence_penalty: oairequest.presence_penalty,
        frequency_penalty: oairequest.frequency_penalty,
        repetition_penalty: oairequest.repetition_penalty,
        stop_seqs: oairequest.stop_seqs,
        temperature: oairequest.temperature,
        top_p: oairequest.top_p,
        stream: oairequest.stream,
        tools: oairequest.tools,
        tool_choice: oairequest.tool_choice,
        response_format: oairequest.response_format,
        web_search_options: oairequest.web_search_options,
        top_k: oairequest.top_k,
        grammar: oairequest.grammar,
        min_p: oairequest.min_p,
        dry_multiplier: oairequest.dry_multiplier,
        dry_base: oairequest.dry_base,
        dry_allowed_length: oairequest.dry_allowed_length,
        dry_sequence_breakers: oairequest.dry_sequence_breakers,
        enable_thinking: oairequest.enable_thinking,
        truncate_sequence: oairequest.truncate_sequence,
    };

    if let Some(prev_msgs) = previous_messages {
        match &mut chat_request.messages {
            Either::Left(msgs) => {
                let mut combined = prev_msgs;
                combined.extend(msgs.clone());
                chat_request.messages = Either::Left(combined);
            }
            Either::Right(_) => {
                let prompt = chat_request.messages.as_ref().right().unwrap().clone();
                let mut combined = prev_msgs;
                combined.push(Message {
                    content: Some(MessageContent::from_text(prompt)),
                    role: "user".to_string(),
                    name: None,
                    tool_calls: None,
                });
                chat_request.messages = Either::Left(combined);
            }
        }
    }

    let all_messages = match &chat_request.messages {
        Either::Left(msgs) => msgs.clone(),
        Either::Right(prompt) => vec![Message {
            content: Some(MessageContent::from_text(prompt.clone())),
            role: "user".to_string(),
            name: None,
            tool_calls: None,
        }],
    };

    let (request, is_streaming) = parse_chat_request(chat_request, state, tx).await?;
    Ok((request, is_streaming, Some(all_messages)))
}

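/// Handler for `POST /v1/responses`: runs the request, optionally streaming the result via
/// SSE, and optionally storing the response and conversation history in the response cache.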
#[utoipa::path(
    post,
    tag = "Mistral.rs",
    path = "/v1/responses",
    request_body = ResponsesCreateRequest,
    responses((status = 200, description = "Response created"))
)]
pub async fn create_response(
    State(state): ExtractedMistralRsState,
    Json(oairequest): Json<ResponsesCreateRequest>,
) -> ResponsesResponder {
    let (tx, mut rx) = create_response_channel(None);
    let request_id = format!("resp_{}", Uuid::new_v4());
    let metadata = oairequest.metadata.clone();
    let store = oairequest.store.unwrap_or(true);

    let model_id = if oairequest.model == "default" {
        None
    } else {
        Some(oairequest.model.clone())
    };

    let (request, is_streaming, conversation_history) =
        match parse_responses_request(oairequest, state.clone(), tx).await {
            Ok(x) => x,
            Err(e) => return handle_error(state, e.into()),
        };

    if let Err(e) = send_request_with_model(&state, request, model_id.as_deref()).await {
        return handle_error(state, e.into());
    }

    if is_streaming {
        let streamer = ResponsesStreamer {
            rx,
            done_state: DoneState::Running,
            state: state.clone(),
            on_chunk: None,
            on_done: None,
            chunks: Vec::new(),
            store_chunks: store,
        };

        if store {
            let cache = get_response_cache();
            let id = request_id.clone();
            let chunks_cache = cache.clone();

            let history_for_streaming = conversation_history.clone();
            let on_done: OnDoneCallback<ResponsesChunk> = Box::new(move |chunks| {
                let _ = chunks_cache.store_chunks(id.clone(), chunks.to_vec());

                if let Some(history) = history_for_streaming.clone() {
                    let mut history = history;
                    let mut assistant_message = String::new();

                    for chunk in chunks {
                        if let Some(delta) = &chunk.delta {
                            if let Some(outputs) = &delta.output {
                                for output in outputs {
                                    if let Some(contents) = &output.content {
                                        for content in contents {
                                            if let Some(text) = &content.text {
                                                assistant_message.push_str(text);
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }

                    if !assistant_message.is_empty() {
                        history.push(Message {
                            content: Some(MessageContent::from_text(assistant_message)),
                            role: "assistant".to_string(),
                            name: None,
                            tool_calls: None,
                        });
                    }

                    let _ = chunks_cache.store_conversation_history(id.clone(), history);
                }
            });

            ResponsesResponder::Sse(create_streamer(streamer, Some(on_done)))
        } else {
            ResponsesResponder::Sse(create_streamer(streamer, None))
        }
    } else {
        match rx.recv().await {
            Some(Response::Done(chat_resp)) => {
                let response_obj =
                    chat_response_to_responses_object(&chat_resp, request_id.clone(), metadata);

                if store {
                    let cache = get_response_cache();
                    let _ = cache.store_response(request_id.clone(), response_obj.clone());

                    if let Some(mut history) = conversation_history.clone() {
                        for choice in &chat_resp.choices {
                            if let Some(content) = &choice.message.content {
                                history.push(Message {
                                    content: Some(MessageContent::from_text(content.clone())),
                                    role: choice.message.role.clone(),
                                    name: None,
                                    tool_calls: None,
                                });
                            }
                        }
                        let _ = cache.store_conversation_history(request_id, history);
                    }
                }

                ResponsesResponder::Json(response_obj)
            }
            Some(Response::ModelError(msg, partial_resp)) => {
                let mut response_obj =
                    chat_response_to_responses_object(&partial_resp, request_id.clone(), metadata);
                response_obj.error = Some(ResponsesError {
                    error_type: "model_error".to_string(),
                    message: msg.to_string(),
                });
                response_obj.status = "failed".to_string();

                if store {
                    let cache = get_response_cache();
                    let _ = cache.store_response(request_id.clone(), response_obj.clone());

                    if let Some(mut history) = conversation_history.clone() {
                        for choice in &partial_resp.choices {
                            if let Some(content) = &choice.message.content {
                                history.push(Message {
                                    content: Some(MessageContent::from_text(content.clone())),
                                    role: choice.message.role.clone(),
                                    name: None,
                                    tool_calls: None,
                                });
                            }
                        }
                        let _ = cache.store_conversation_history(request_id, history);
                    }
                }
                ResponsesResponder::ModelError(msg.to_string(), response_obj)
            }
            Some(Response::ValidationError(e)) => ResponsesResponder::ValidationError(e),
            Some(Response::InternalError(e)) => ResponsesResponder::InternalError(e),
            _ => ResponsesResponder::InternalError(
                anyhow::anyhow!("Unexpected response type").into(),
            ),
        }
    }
}

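/// Handler for `GET /v1/responses/{response_id}`: retrieves a previously stored response.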
#[utoipa::path(
    get,
    tag = "Mistral.rs",
    path = "/v1/responses/{response_id}",
    params(("response_id" = String, Path, description = "The ID of the response to retrieve")),
    responses((status = 200, description = "Response object"))
)]
pub async fn get_response(
    State(_state): ExtractedMistralRsState,
    Path(response_id): Path<String>,
) -> impl IntoResponse {
    let cache = get_response_cache();

    match cache.get_response(&response_id) {
        Ok(Some(response)) => (StatusCode::OK, Json(response)).into_response(),
        Ok(None) => JsonError::new(format!("Response with ID '{response_id}' not found"))
            .to_response(StatusCode::NOT_FOUND),
        Err(e) => JsonError::new(format!(
            "Error retrieving response: {}",
            sanitize_error_message(&*e)
        ))
        .to_response(StatusCode::INTERNAL_SERVER_ERROR),
    }
}

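/// Handler for `DELETE /v1/responses/{response_id}`: removes a stored response from the cache.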
#[utoipa::path(
    delete,
    tag = "Mistral.rs",
    path = "/v1/responses/{response_id}",
    params(("response_id" = String, Path, description = "The ID of the response to delete")),
    responses((status = 200, description = "Response deleted"))
)]
pub async fn delete_response(
    State(_state): ExtractedMistralRsState,
    Path(response_id): Path<String>,
) -> impl IntoResponse {
    let cache = get_response_cache();

    match cache.delete_response(&response_id) {
        Ok(true) => (
            StatusCode::OK,
            Json(serde_json::json!({
                "deleted": true,
                "id": response_id,
                "object": "response.deleted"
            })),
        )
            .into_response(),
        Ok(false) => JsonError::new(format!("Response with ID '{response_id}' not found"))
            .to_response(StatusCode::NOT_FOUND),
        Err(e) => JsonError::new(format!(
            "Error deleting response: {}",
            sanitize_error_message(&*e)
        ))
        .to_response(StatusCode::INTERNAL_SERVER_ERROR),
    }
}

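/// Maps an error into a [`ResponsesResponder`] via the shared completion error handling.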
fn handle_error(
    state: SharedMistralRsState,
    e: Box<dyn std::error::Error + Send + Sync + 'static>,
) -> ResponsesResponder {
    handle_completion_error(state, e)
}

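/// Wraps a [`ResponsesStreamer`] in an SSE response with the configured keep-alive interval.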
fn create_streamer(
    streamer: ResponsesStreamer,
    on_done: Option<OnDoneCallback<ResponsesChunk>>,
) -> Sse<KeepAliveStream<ResponsesStreamer>> {
    let keep_alive_interval = get_keep_alive_interval();

    let streamer_with_callback = ResponsesStreamer {
        on_done,
        ..streamer
    };

    Sse::new(streamer_with_callback)
        .keep_alive(KeepAlive::new().interval(Duration::from_millis(keep_alive_interval)))
}