use std::{
    collections::HashMap,
    pin::Pin,
    task::Poll,
    time::{Duration, SystemTime, UNIX_EPOCH},
};

use anyhow::Result;
use axum::{
    extract::{Json, Path, State},
    http::{self, StatusCode},
    response::{
        sse::{Event, KeepAlive, KeepAliveStream},
        IntoResponse, Sse,
    },
};
use either::Either;
use mistralrs_core::{ChatCompletionResponse, MistralRs, Request, Response};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tokio::sync::mpsc::{Receiver, Sender};
use utoipa::{
    openapi::{schema::SchemaType, ArrayBuilder, ObjectBuilder, OneOfBuilder, RefOr, Schema, Type},
    PartialSchema, ToSchema,
};
use uuid::Uuid;

use crate::{
    background_tasks::get_background_task_manager,
    cached_responses::get_response_cache,
    chat_completion::parse_request as parse_chat_request,
    completion_core::{handle_completion_error, BaseCompletionResponder},
    handler_core::{
        create_response_channel, send_request_with_model, BaseJsonModelError, ErrorToResponse,
        JsonError, ModelErrorMessage,
    },
    openai::{ChatCompletionRequest, Message, MessageContent, ToolCall},
    responses_types::{
        content::OutputContent,
        enums::{ItemStatus, ResponseStatus},
        events::StreamingState,
        items::{InputItem, MessageContentParam, OutputItem},
        resource::{ResponseError, ResponseResource, ResponseUsage},
    },
    streaming::{get_keep_alive_interval, DoneState},
    types::{ExtractedMistralRsState, OnDoneCallback, SharedMistralRsState},
    util::sanitize_error_message,
};

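/// Input accepted by the Responses API: either a plain text prompt or a list of
/// structured input items (OpenResponses format).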
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum OpenResponsesInput {
    Text(String),
    Items(Vec<InputItem>),
}

impl PartialSchema for OpenResponsesInput {
    fn schema() -> RefOr<Schema> {
        RefOr::T(Schema::OneOf(
            OneOfBuilder::new()
                .item(Schema::Object(
                    ObjectBuilder::new()
                        .schema_type(SchemaType::Type(Type::String))
                        .description(Some("Simple text input"))
                        .build(),
                ))
                .item(Schema::Array(
                    ArrayBuilder::new()
                        .items(InputItem::schema())
                        .description(Some("Array of input items (OpenResponses format)"))
                        .build(),
                ))
                .build(),
        ))
    }
}

impl ToSchema for OpenResponsesInput {
    fn schemas(
        schemas: &mut Vec<(
            String,
            utoipa::openapi::RefOr<utoipa::openapi::schema::Schema>,
        )>,
    ) {
        schemas.push((
            OpenResponsesInput::name().into(),
            OpenResponsesInput::schema(),
        ));
    }
}

impl OpenResponsesInput {
    pub fn into_either(self) -> Either<Vec<Message>, String> {
        match self {
            OpenResponsesInput::Text(s) => Either::Right(s),
            OpenResponsesInput::Items(items) => {
                let messages = convert_input_items_to_messages(items);
                Either::Left(messages)
            }
        }
    }
}

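/// Convert OpenResponses input items into OpenAI-style chat [`Message`]s that the
/// chat completion pipeline understands.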
fn convert_input_items_to_messages(items: Vec<InputItem>) -> Vec<Message> {
    use crate::responses_types::content::NormalizedInputContent;
    use crate::responses_types::items::TaggedInputItem;

    let mut messages = Vec::new();

    for item in items {
        match item.into_tagged() {
            TaggedInputItem::Message(msg_param) => {
                let content = match msg_param.content {
                    MessageContentParam::Text(text) => Some(MessageContent::from_text(text)),
                    MessageContentParam::Parts(parts) => {
                        let mut content_parts = Vec::new();
                        let mut has_non_text_content = false;

                        for part in parts {
                            match part.into_normalized() {
                                NormalizedInputContent::Text { text } => {
                                    content_parts.push(MessageContent::text_part(text));
                                }
                                NormalizedInputContent::Image {
                                    image_url,
                                    image_data,
                                    detail,
                                } => {
                                    has_non_text_content = true;
                                    let url = if let Some(url) = image_url {
                                        url
                                    } else if let Some(data) = image_data {
                                        format!("data:image/png;base64,{}", data)
                                    } else {
                                        continue;
                                    };

                                    let image_part = if let Some(detail_level) = detail {
                                        let detail_str = match detail_level {
                                            crate::responses_types::enums::ImageDetail::Auto => {
                                                "auto"
                                            }
                                            crate::responses_types::enums::ImageDetail::Low => {
                                                "low"
                                            }
                                            crate::responses_types::enums::ImageDetail::High => {
                                                "high"
                                            }
                                        };
                                        MessageContent::image_url_part_with_detail(
                                            url,
                                            detail_str.to_string(),
                                        )
                                    } else {
                                        MessageContent::image_url_part(url)
                                    };
                                    content_parts.push(image_part);
                                }
                                NormalizedInputContent::Audio { data, format } => {
                                    has_non_text_content = true;
                                    let mime_type = match format.as_str() {
                                        "wav" => "audio/wav",
                                        "mp3" => "audio/mpeg",
                                        "flac" => "audio/flac",
                                        "ogg" => "audio/ogg",
                                        _ => "audio/wav",
                                    };
                                    let audio_url = format!("data:{};base64,{}", mime_type, data);
                                    let mut audio_part = std::collections::HashMap::new();
                                    audio_part.insert(
                                        "type".to_string(),
                                        crate::openai::MessageInnerContent(Either::Left(
                                            "input_audio".to_string(),
                                        )),
                                    );
                                    let mut audio_obj = std::collections::HashMap::new();
                                    audio_obj.insert("data".to_string(), data);
                                    audio_obj.insert("format".to_string(), format);
                                    audio_part.insert(
                                        "input_audio".to_string(),
                                        crate::openai::MessageInnerContent(Either::Right(
                                            audio_obj,
                                        )),
                                    );
                                    content_parts.push(audio_part);
                                    content_parts.push(MessageContent::text_part(format!(
                                        "[Audio content: {}]",
                                        audio_url
                                    )));
                                }
                                NormalizedInputContent::File {
                                    file_id,
                                    file_data,
                                    filename,
                                } => {
                                    has_non_text_content = true;
                                    let file_ref = if let Some(id) = file_id {
                                        format!("[File reference: {}]", id)
                                    } else if let Some(data) = file_data {
                                        let name =
                                            filename.unwrap_or_else(|| "unnamed_file".to_string());
                                        format!(
                                            "[File: {} (base64 data: {} bytes)]",
                                            name,
                                            data.len()
                                        )
                                    } else if let Some(name) = filename {
                                        format!("[File: {}]", name)
                                    } else {
                                        "[File reference]".to_string()
                                    };
                                    content_parts.push(MessageContent::text_part(file_ref));
                                }
                            }
                        }

                        if content_parts.is_empty() {
                            None
                        } else if !has_non_text_content && content_parts.len() == 1 {
                            let first = &content_parts[0];
                            if let Some(text_value) = first.get("text") {
                                if let Either::Left(text) = &**text_value {
                                    Some(MessageContent::from_text(text.clone()))
                                } else {
                                    Some(MessageContent::from_parts(content_parts))
                                }
                            } else {
                                Some(MessageContent::from_parts(content_parts))
                            }
                        } else {
                            Some(MessageContent::from_parts(content_parts))
                        }
                    }
                };

                messages.push(Message {
                    content,
                    role: msg_param.role,
                    name: msg_param.name,
                    tool_calls: None,
                    tool_call_id: None,
                });
            }
            TaggedInputItem::ItemReference { id: _ } => {
                // Item references cannot be resolved against stored items here; they are
                // skipped when building the chat messages.
            }
            TaggedInputItem::FunctionCall {
                call_id,
                name,
                arguments,
            } => {
                messages.push(Message {
                    content: None,
                    role: "assistant".to_string(),
                    name: None,
                    tool_calls: Some(vec![ToolCall {
                        id: Some(call_id),
                        tp: mistralrs_core::ToolType::Function,
                        function: crate::openai::FunctionCalled { name, arguments },
                    }]),
                    tool_call_id: None,
                });
            }
            TaggedInputItem::FunctionCallOutput { call_id, output } => {
                messages.push(Message {
                    content: Some(MessageContent::from_text(output)),
                    role: "tool".to_string(),
                    name: None,
                    tool_calls: None,
                    tool_call_id: Some(call_id),
                });
            }
        }
    }

    messages
}

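/// Reasoning options for the request (effort level and summary style).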
#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
pub struct ReasoningConfig {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub effort: Option<crate::responses_types::enums::ReasoningEffort>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub summary: Option<ReasoningSummary>,
}

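/// Requested style for reasoning summaries.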
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Serialize, ToSchema)]
#[serde(rename_all = "lowercase")]
pub enum ReasoningSummary {
    Concise,
    Detailed,
    Auto,
}

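/// Text output configuration (output format).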
#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
pub struct TextConfig {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub format: Option<TextFormat>,
}

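/// Output text format: plain text, a named JSON schema, or a generic JSON object.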
#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
#[serde(tag = "type")]
pub enum TextFormat {
    #[serde(rename = "text")]
    Text,
    #[serde(rename = "json_schema")]
    JsonSchema {
        name: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        schema: Option<Value>,
        #[serde(skip_serializing_if = "Option::is_none")]
        strict: Option<bool>,
    },
    #[serde(rename = "json_object")]
    JsonObject,
}

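/// Options controlling streaming behaviour.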
#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
pub struct StreamOptions {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub include_usage: Option<bool>,
}

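/// Request parameters echoed back on every [`ResponseResource`] built for this
/// request, on both the streaming and non-streaming paths.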
#[derive(Debug, Clone, Default)]
pub struct RequestContext {
    pub tools: Option<Vec<mistralrs_core::Tool>>,
    pub tool_choice: Option<mistralrs_core::ToolChoice>,
    pub parallel_tool_calls: Option<bool>,
    pub text: Option<TextConfig>,
    pub temperature: Option<f64>,
    pub top_p: Option<f64>,
    pub presence_penalty: Option<f32>,
    pub frequency_penalty: Option<f32>,
    pub top_logprobs: Option<usize>,
    pub max_output_tokens: Option<usize>,
    pub max_tool_calls: Option<usize>,
    pub store: Option<bool>,
    pub background: Option<bool>,
}

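/// Extra output data the client can ask to include in the response.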
#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize, ToSchema)]
#[serde(rename_all = "snake_case")]
pub enum IncludeOption {
    #[serde(rename = "file_search_call.results")]
    FileSearchCallResults,
    #[serde(rename = "message.input_image.image_url")]
    MessageInputImageUrl,
    #[serde(rename = "computer_call_output.output.image_url")]
    ComputerCallOutputImageUrl,
    #[serde(rename = "reasoning.encrypted_content")]
    ReasoningEncryptedContent,
}

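/// Parsed `include` options from the request.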
#[derive(Debug, Clone, Default)]
pub struct IncludeConfig {
    pub options: Vec<IncludeOption>,
}

impl IncludeConfig {
    pub fn new(options: Option<Vec<IncludeOption>>) -> Self {
        Self {
            options: options.unwrap_or_default(),
        }
    }

    pub fn has(&self, option: &IncludeOption) -> bool {
        self.options.contains(option)
    }

    pub fn include_reasoning(&self) -> bool {
        // Reasoning content is always attached to the response when the model
        // produces it, regardless of the `include` list.
        true
    }
}

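/// Request body for `POST /v1/responses`.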
#[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]
pub struct OpenResponsesCreateRequest {
    #[serde(default = "default_model")]
    pub model: String,

    pub input: OpenResponsesInput,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub instructions: Option<String>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub previous_response_id: Option<String>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream_options: Option<StreamOptions>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub background: Option<bool>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub store: Option<bool>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<Value>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub include: Option<Vec<IncludeOption>>,

    #[serde(
        alias = "max_tokens",
        alias = "max_completion_tokens",
        skip_serializing_if = "Option::is_none"
    )]
    pub max_output_tokens: Option<usize>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f64>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f64>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_logprobs: Option<usize>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<mistralrs_core::Tool>>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<mistralrs_core::ToolChoice>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub parallel_tool_calls: Option<bool>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tool_calls: Option<usize>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning: Option<ReasoningConfig>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub text: Option<TextConfig>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub truncation: Option<crate::responses_types::enums::TruncationStrategy>,

    #[serde(rename = "stop", skip_serializing_if = "Option::is_none")]
    pub stop_seqs: Option<crate::openai::StopTokens>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<crate::openai::ResponseFormat>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub logit_bias: Option<HashMap<u32, f32>>,

    #[serde(default)]
    pub logprobs: bool,

    #[serde(rename = "n", default = "default_1usize")]
    pub n_choices: usize,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub repetition_penalty: Option<f32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_k: Option<usize>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub grammar: Option<crate::openai::Grammar>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub min_p: Option<f64>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub dry_multiplier: Option<f32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub dry_base: Option<f32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub dry_allowed_length: Option<usize>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub dry_sequence_breakers: Option<Vec<String>>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub web_search_options: Option<mistralrs_core::WebSearchOptions>,
}

fn default_model() -> String {
    "default".to_string()
}

fn default_1usize() -> usize {
    1
}

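/// Server-sent events emitted while streaming a response, tagged with their
/// `type` field (e.g. `response.created`, `response.output_text.delta`).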
#[derive(Debug, Clone, Serialize)]
#[serde(tag = "type")]
pub enum OpenResponsesStreamEvent {
    #[serde(rename = "response.created")]
    ResponseCreated {
        sequence_number: u64,
        response: ResponseResource,
    },
    #[serde(rename = "response.in_progress")]
    ResponseInProgress {
        sequence_number: u64,
        response: ResponseResource,
    },
    #[serde(rename = "response.output_item.added")]
    OutputItemAdded {
        sequence_number: u64,
        output_index: usize,
        item: OutputItem,
    },
    #[serde(rename = "response.content_part.added")]
    ContentPartAdded {
        sequence_number: u64,
        output_index: usize,
        content_index: usize,
        part: OutputContent,
    },
    #[serde(rename = "response.output_text.delta")]
    OutputTextDelta {
        sequence_number: u64,
        output_index: usize,
        content_index: usize,
        delta: String,
    },
    #[serde(rename = "response.content_part.done")]
    ContentPartDone {
        sequence_number: u64,
        output_index: usize,
        content_index: usize,
        part: OutputContent,
    },
    #[serde(rename = "response.output_item.done")]
    OutputItemDone {
        sequence_number: u64,
        output_index: usize,
        item: OutputItem,
    },
    #[serde(rename = "response.function_call_arguments.delta")]
    FunctionCallArgumentsDelta {
        sequence_number: u64,
        output_index: usize,
        call_id: String,
        delta: String,
    },
    #[serde(rename = "response.function_call_arguments.done")]
    FunctionCallArgumentsDone {
        sequence_number: u64,
        output_index: usize,
        call_id: String,
        arguments: String,
    },
    #[serde(rename = "response.completed")]
    ResponseCompleted {
        sequence_number: u64,
        response: ResponseResource,
    },
    #[serde(rename = "response.failed")]
    ResponseFailed {
        sequence_number: u64,
        response: ResponseResource,
    },
    #[serde(rename = "response.incomplete")]
    ResponseIncomplete {
        sequence_number: u64,
        response: ResponseResource,
    },
    #[serde(rename = "error")]
    Error {
        sequence_number: u64,
        error: ResponseError,
    },
}

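/// SSE streamer that converts the model's [`Response`] channel into
/// OpenResponses streaming events.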
pub struct OpenResponsesStreamer {
    rx: Receiver<Response>,
    done_state: DoneState,
    state: SharedMistralRsState,
    streaming_state: StreamingState,
    metadata: Option<Value>,
    pending_events: Vec<OpenResponsesStreamEvent>,
    accumulated_text: String,
    accumulated_reasoning: String,
    content_part_added: bool,
    output_item_added: bool,
    store: bool,
    conversation_history: Option<Vec<Message>>,
    on_done: Option<OnDoneCallback<OpenResponsesStreamEvent>>,
    events: Vec<OpenResponsesStreamEvent>,
    request_context: RequestContext,
}

impl OpenResponsesStreamer {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        rx: Receiver<Response>,
        state: SharedMistralRsState,
        response_id: String,
        model: String,
        metadata: Option<Value>,
        store: bool,
        conversation_history: Option<Vec<Message>>,
        request_context: RequestContext,
    ) -> Self {
        let created_at = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs();

        Self {
            rx,
            done_state: DoneState::Running,
            state,
            streaming_state: StreamingState::new(response_id, model, created_at),
            metadata,
            pending_events: Vec::new(),
            accumulated_text: String::new(),
            accumulated_reasoning: String::new(),
            content_part_added: false,
            output_item_added: false,
            store,
            conversation_history,
            on_done: None,
            events: Vec::new(),
            request_context,
        }
    }

    fn build_response_resource(&self, status: ResponseStatus) -> ResponseResource {
        let mut resource = ResponseResource::new(
            self.streaming_state.response_id.clone(),
            self.streaming_state.model.clone(),
            self.streaming_state.created_at,
        );
        resource.status = status;
        resource.metadata = self.metadata.clone();

        resource.tools = self.request_context.tools.clone();
        resource.tool_choice = self.request_context.tool_choice.clone();
        resource.parallel_tool_calls = self.request_context.parallel_tool_calls;
        resource.text = self.request_context.text.clone();
        resource.temperature = self.request_context.temperature;
        resource.top_p = self.request_context.top_p;
        resource.presence_penalty = self.request_context.presence_penalty;
        resource.frequency_penalty = self.request_context.frequency_penalty;
        resource.top_logprobs = self.request_context.top_logprobs;
        resource.max_output_tokens = self.request_context.max_output_tokens;
        resource.max_tool_calls = self.request_context.max_tool_calls;
        resource.store = self.request_context.store;
        resource.background = self.request_context.background;

        resource
    }

    fn build_current_response(&self, status: ResponseStatus) -> ResponseResource {
        let mut resource = self.build_response_resource(status);

        if !self.accumulated_text.is_empty() {
            let content = vec![OutputContent::text(self.accumulated_text.clone())];
            let item = OutputItem::message(
                format!("msg_{}", Uuid::new_v4()),
                content,
                if status == ResponseStatus::Completed {
                    ItemStatus::Completed
                } else {
                    ItemStatus::InProgress
                },
            );
            resource.output = vec![item];
            resource.output_text = Some(self.accumulated_text.clone());
        }

        if !self.accumulated_reasoning.is_empty() {
            resource.reasoning = Some(self.accumulated_reasoning.clone());
        }

        resource
    }
}

impl futures::Stream for OpenResponsesStreamer {
    type Item = Result<Event, axum::Error>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        if !self.pending_events.is_empty() {
            let event = self.pending_events.remove(0);
            self.events.push(event.clone());
            return Poll::Ready(Some(
                Event::default()
                    .event(get_event_type(&event))
                    .json_data(event),
            ));
        }

        match self.done_state {
            DoneState::SendingDone => {
                self.done_state = DoneState::Done;
                return Poll::Ready(Some(Ok(Event::default().data("[DONE]"))));
            }
            DoneState::Done => {
                if self.store {
                    if let Some(history) = self.conversation_history.take() {
                        let cache = get_response_cache();
                        let mut history = history;

                        if !self.accumulated_text.is_empty() {
                            history.push(Message {
                                content: Some(MessageContent::from_text(
                                    self.accumulated_text.clone(),
                                )),
                                role: "assistant".to_string(),
                                name: None,
                                tool_calls: None,
                                tool_call_id: None,
                            });
                        }

                        let _ = cache.store_conversation_history(
                            self.streaming_state.response_id.clone(),
                            history,
                        );
                    }
                }

                if let Some(on_done) = &self.on_done {
                    on_done(&self.events);
                }
                return Poll::Ready(None);
            }
            DoneState::Running => (),
        }

        if !self.streaming_state.created_sent {
            self.streaming_state.created_sent = true;
            let seq = self.streaming_state.next_sequence_number();
            let response = self.build_response_resource(ResponseStatus::Queued);
            let event = OpenResponsesStreamEvent::ResponseCreated {
                sequence_number: seq,
                response,
            };
            self.events.push(event.clone());
            return Poll::Ready(Some(
                Event::default().event("response.created").json_data(event),
            ));
        }

        match self.rx.poll_recv(cx) {
            Poll::Ready(Some(resp)) => match resp {
                Response::ModelError(msg, _) => {
                    MistralRs::maybe_log_error(
                        self.state.clone(),
                        &ModelErrorMessage(msg.to_string()),
                    );

                    let seq = self.streaming_state.next_sequence_number();
                    let mut response = self.build_current_response(ResponseStatus::Failed);
                    response.error = Some(ResponseError::new("model_error", msg.to_string()));

                    let event = OpenResponsesStreamEvent::ResponseFailed {
                        sequence_number: seq,
                        response,
                    };

                    self.done_state = DoneState::SendingDone;
                    self.events.push(event.clone());
                    Poll::Ready(Some(
                        Event::default().event("response.failed").json_data(event),
                    ))
                }
                Response::ValidationError(e) => {
                    let seq = self.streaming_state.next_sequence_number();
                    let event = OpenResponsesStreamEvent::Error {
                        sequence_number: seq,
                        error: ResponseError::new(
                            "validation_error",
                            sanitize_error_message(e.as_ref()),
                        ),
                    };
                    self.done_state = DoneState::SendingDone;
                    self.events.push(event.clone());
                    Poll::Ready(Some(Event::default().event("error").json_data(event)))
                }
                Response::InternalError(e) => {
                    MistralRs::maybe_log_error(self.state.clone(), &*e);
                    let seq = self.streaming_state.next_sequence_number();
                    let event = OpenResponsesStreamEvent::Error {
                        sequence_number: seq,
                        error: ResponseError::new(
                            "internal_error",
                            sanitize_error_message(e.as_ref()),
                        ),
                    };
                    self.done_state = DoneState::SendingDone;
                    self.events.push(event.clone());
                    Poll::Ready(Some(Event::default().event("error").json_data(event)))
                }
                Response::Chunk(chat_chunk) => {
                    let mut events_to_emit = Vec::new();

                    if !self.streaming_state.in_progress_sent {
                        self.streaming_state.in_progress_sent = true;
                        let seq = self.streaming_state.next_sequence_number();
                        let response = self.build_response_resource(ResponseStatus::InProgress);
                        events_to_emit.push(OpenResponsesStreamEvent::ResponseInProgress {
                            sequence_number: seq,
                            response,
                        });
                    }

                    let all_finished =
                        chat_chunk.choices.iter().all(|c| c.finish_reason.is_some());

                    for choice in &chat_chunk.choices {
                        if let Some(reasoning) = &choice.delta.reasoning_content {
                            self.accumulated_reasoning.push_str(reasoning);
                        }

                        if let Some(content) = &choice.delta.content {
                            if !self.output_item_added {
                                self.output_item_added = true;
                                let seq = self.streaming_state.next_sequence_number();
                                let item = OutputItem::message(
                                    format!("msg_{}", Uuid::new_v4()),
                                    vec![],
                                    ItemStatus::InProgress,
                                );
                                events_to_emit.push(OpenResponsesStreamEvent::OutputItemAdded {
                                    sequence_number: seq,
                                    output_index: 0,
                                    item,
                                });
                            }

                            if !self.content_part_added {
                                self.content_part_added = true;
                                let seq = self.streaming_state.next_sequence_number();
                                let part = OutputContent::text(String::new());
                                events_to_emit.push(OpenResponsesStreamEvent::ContentPartAdded {
                                    sequence_number: seq,
                                    output_index: 0,
                                    content_index: 0,
                                    part,
                                });
                            }

                            self.accumulated_text.push_str(content);

                            let seq = self.streaming_state.next_sequence_number();
                            events_to_emit.push(OpenResponsesStreamEvent::OutputTextDelta {
                                sequence_number: seq,
                                output_index: 0,
                                content_index: 0,
                                delta: content.clone(),
                            });
                        }

                        if let Some(tool_calls) = &choice.delta.tool_calls {
                            for tool_call in tool_calls {
                                let seq = self.streaming_state.next_sequence_number();
                                events_to_emit.push(
                                    OpenResponsesStreamEvent::FunctionCallArgumentsDelta {
                                        sequence_number: seq,
                                        output_index: 0,
                                        call_id: tool_call.id.clone(),
                                        delta: tool_call.function.arguments.clone(),
                                    },
                                );
                            }
                        }
                    }

                    if all_finished {
                        if self.content_part_added {
                            let seq = self.streaming_state.next_sequence_number();
                            let part = OutputContent::text(self.accumulated_text.clone());
                            events_to_emit.push(OpenResponsesStreamEvent::ContentPartDone {
                                sequence_number: seq,
                                output_index: 0,
                                content_index: 0,
                                part,
                            });
                        }

                        if self.output_item_added {
                            let seq = self.streaming_state.next_sequence_number();
                            let content =
                                vec![OutputContent::text(self.accumulated_text.clone())];
                            let item = OutputItem::message(
                                format!("msg_{}", Uuid::new_v4()),
                                content,
                                ItemStatus::Completed,
                            );
                            events_to_emit.push(OpenResponsesStreamEvent::OutputItemDone {
                                sequence_number: seq,
                                output_index: 0,
                                item,
                            });
                        }

                        let seq = self.streaming_state.next_sequence_number();
                        let mut response =
                            self.build_current_response(ResponseStatus::Completed);
                        response.completed_at = Some(
                            SystemTime::now()
                                .duration_since(UNIX_EPOCH)
                                .unwrap()
                                .as_secs(),
                        );

                        if let Some(usage) = &chat_chunk.usage {
                            response.usage = Some(ResponseUsage::new(
                                usage.prompt_tokens,
                                usage.completion_tokens,
                            ));
                        }

                        events_to_emit.push(OpenResponsesStreamEvent::ResponseCompleted {
                            sequence_number: seq,
                            response,
                        });

                        self.done_state = DoneState::SendingDone;
                    }

                    MistralRs::maybe_log_response(self.state.clone(), &chat_chunk);

                    if !events_to_emit.is_empty() {
                        let first_event = events_to_emit.remove(0);
                        self.pending_events.extend(events_to_emit);
                        self.events.push(first_event.clone());
                        Poll::Ready(Some(
                            Event::default()
                                .event(get_event_type(&first_event))
                                .json_data(first_event),
                        ))
                    } else {
                        // No event was produced for this chunk; wake the task so the
                        // channel is polled again instead of stalling.
                        cx.waker().wake_by_ref();
                        Poll::Pending
                    }
                }
                Response::Done(chat_resp) => {
                    let seq = self.streaming_state.next_sequence_number();
                    let response = chat_response_to_response_resource(
                        &chat_resp,
                        self.streaming_state.response_id.clone(),
                        self.metadata.clone(),
                        &self.request_context,
                    );
                    let event = OpenResponsesStreamEvent::ResponseCompleted {
                        sequence_number: seq,
                        response,
                    };
                    self.done_state = DoneState::SendingDone;
                    self.events.push(event.clone());
                    Poll::Ready(Some(
                        Event::default()
                            .event("response.completed")
                            .json_data(event),
                    ))
                }
                _ => {
                    // Other response variants are not expected while streaming; wake the
                    // task so the channel is polled again instead of stalling.
                    cx.waker().wake_by_ref();
                    Poll::Pending
                }
            },
            Poll::Pending | Poll::Ready(None) => Poll::Pending,
        }
    }
}

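/// SSE event name for a given streaming event variant.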
fn get_event_type(event: &OpenResponsesStreamEvent) -> &'static str {
    match event {
        OpenResponsesStreamEvent::ResponseCreated { .. } => "response.created",
        OpenResponsesStreamEvent::ResponseInProgress { .. } => "response.in_progress",
        OpenResponsesStreamEvent::OutputItemAdded { .. } => "response.output_item.added",
        OpenResponsesStreamEvent::ContentPartAdded { .. } => "response.content_part.added",
        OpenResponsesStreamEvent::OutputTextDelta { .. } => "response.output_text.delta",
        OpenResponsesStreamEvent::ContentPartDone { .. } => "response.content_part.done",
        OpenResponsesStreamEvent::OutputItemDone { .. } => "response.output_item.done",
        OpenResponsesStreamEvent::FunctionCallArgumentsDelta { .. } => {
            "response.function_call_arguments.delta"
        }
        OpenResponsesStreamEvent::FunctionCallArgumentsDone { .. } => {
            "response.function_call_arguments.done"
        }
        OpenResponsesStreamEvent::ResponseCompleted { .. } => "response.completed",
        OpenResponsesStreamEvent::ResponseFailed { .. } => "response.failed",
        OpenResponsesStreamEvent::ResponseIncomplete { .. } => "response.incomplete",
        OpenResponsesStreamEvent::Error { .. } => "error",
    }
}

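/// Responder returned by the `/v1/responses` handlers: JSON body, SSE stream, or error.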
pub type OpenResponsesResponder =
    BaseCompletionResponder<ResponseResource, KeepAliveStream<OpenResponsesStreamer>>;

type JsonModelError = BaseJsonModelError<ResponseResource>;
impl ErrorToResponse for JsonModelError {}

impl IntoResponse for OpenResponsesResponder {
    fn into_response(self) -> axum::response::Response {
        match self {
            OpenResponsesResponder::Sse(s) => s.into_response(),
            OpenResponsesResponder::Json(s) => Json(s).into_response(),
            OpenResponsesResponder::InternalError(e) => {
                JsonError::new(sanitize_error_message(e.as_ref()))
                    .to_response(http::StatusCode::INTERNAL_SERVER_ERROR)
            }
            OpenResponsesResponder::ValidationError(e) => {
                JsonError::new(sanitize_error_message(e.as_ref()))
                    .to_response(http::StatusCode::UNPROCESSABLE_ENTITY)
            }
            OpenResponsesResponder::ModelError(msg, response) => JsonModelError::new(msg, response)
                .to_response(http::StatusCode::INTERNAL_SERVER_ERROR),
        }
    }
}

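/// Build a [`ResponseResource`] from a finished chat completion, carrying over the
/// request parameters stored in the [`RequestContext`].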
fn chat_response_to_response_resource(
    chat_resp: &ChatCompletionResponse,
    request_id: String,
    metadata: Option<Value>,
    request_ctx: &RequestContext,
) -> ResponseResource {
    let created_at = chat_resp.created;
    let mut resource = ResponseResource::new(request_id, chat_resp.model.clone(), created_at);

    let mut output_items = Vec::new();
    let mut output_text_parts = Vec::new();
    let mut reasoning_parts = Vec::new();

    for choice in &chat_resp.choices {
        let mut content_items = Vec::new();

        if let Some(text) = &choice.message.content {
            output_text_parts.push(text.clone());
            content_items.push(OutputContent::text(text.clone()));
        }

        if let Some(reasoning) = &choice.message.reasoning_content {
            reasoning_parts.push(reasoning.clone());
        }

        if let Some(tool_calls) = &choice.message.tool_calls {
            for tool_call in tool_calls {
                let item = OutputItem::function_call(
                    format!("fc_{}", Uuid::new_v4()),
                    tool_call.id.clone(),
                    tool_call.function.name.clone(),
                    tool_call.function.arguments.clone(),
                    ItemStatus::Completed,
                );
                output_items.push(item);
            }
        }

        if !content_items.is_empty() {
            let item = OutputItem::message(
                format!("msg_{}", Uuid::new_v4()),
                content_items,
                ItemStatus::Completed,
            );
            output_items.push(item);
        }
    }

    resource.status = ResponseStatus::Completed;
    resource.output = output_items;
    resource.output_text = if output_text_parts.is_empty() {
        None
    } else {
        Some(output_text_parts.join(""))
    };
    resource.reasoning = if reasoning_parts.is_empty() {
        None
    } else {
        Some(reasoning_parts.join(""))
    };
    resource.usage = Some(ResponseUsage::new(
        chat_resp.usage.prompt_tokens,
        chat_resp.usage.completion_tokens,
    ));
    resource.metadata = metadata;
    resource.completed_at = Some(
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs(),
    );

    resource.tools = request_ctx.tools.clone();
    resource.tool_choice = request_ctx.tool_choice.clone();
    resource.parallel_tool_calls = request_ctx.parallel_tool_calls;
    resource.text = request_ctx.text.clone();
    resource.temperature = request_ctx.temperature;
    resource.top_p = request_ctx.top_p;
    resource.presence_penalty = request_ctx.presence_penalty;
    resource.frequency_penalty = request_ctx.frequency_penalty;
    resource.top_logprobs = request_ctx.top_logprobs;
    resource.max_output_tokens = request_ctx.max_output_tokens;
    resource.max_tool_calls = request_ctx.max_tool_calls;
    resource.store = request_ctx.store;
    resource.background = request_ctx.background;

    resource
}

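/// Translate an OpenResponses request into an internal [`Request`], returning the
/// streaming flag, the conversation history to cache, the parsed `include`
/// options, and the request context used to build response resources.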
async fn parse_openresponses_request(
    oairequest: OpenResponsesCreateRequest,
    state: SharedMistralRsState,
    tx: Sender<Response>,
) -> Result<(
    Request,
    bool,
    Option<Vec<Message>>,
    IncludeConfig,
    RequestContext,
)> {
    if let Some(false) = oairequest.parallel_tool_calls {
        anyhow::bail!(
            "parallel_tool_calls=false is not supported. \
             mistral.rs does not currently support disabling parallel tool calls."
        );
    }

    if oairequest.max_tool_calls.is_some() {
        anyhow::bail!(
            "max_tool_calls is not supported. \
             mistral.rs does not currently support limiting the number of tool calls."
        );
    }

    let request_context = RequestContext {
        tools: oairequest.tools.clone(),
        tool_choice: oairequest.tool_choice.clone(),
        parallel_tool_calls: oairequest.parallel_tool_calls,
        text: oairequest.text.clone(),
        temperature: oairequest.temperature,
        top_p: oairequest.top_p,
        presence_penalty: oairequest.presence_penalty,
        frequency_penalty: oairequest.frequency_penalty,
        top_logprobs: oairequest.top_logprobs,
        max_output_tokens: oairequest.max_output_tokens,
        max_tool_calls: oairequest.max_tool_calls,
        store: oairequest.store,
        background: oairequest.background,
    };

    let include_config = IncludeConfig::new(oairequest.include.clone());

    let previous_messages = if let Some(prev_id) = &oairequest.previous_response_id {
        let cache = get_response_cache();
        cache.get_conversation_history(prev_id)?
    } else {
        None
    };

    let messages = oairequest.input.into_either();

    let mut final_messages = Vec::new();
    if let Some(instructions) = &oairequest.instructions {
        final_messages.push(Message {
            content: Some(MessageContent::from_text(instructions.clone())),
            role: "system".to_string(),
            name: None,
            tool_calls: None,
            tool_call_id: None,
        });
    }

    if let Some(prev_msgs) = previous_messages {
        final_messages.extend(prev_msgs);
    }

    match messages {
        Either::Left(msgs) => final_messages.extend(msgs),
        Either::Right(prompt) => {
            final_messages.push(Message {
                content: Some(MessageContent::from_text(prompt)),
                role: "user".to_string(),
                name: None,
                tool_calls: None,
                tool_call_id: None,
            });
        }
    }

    let (enable_thinking, reasoning_effort) = if let Some(ref reasoning) = oairequest.reasoning {
        let effort = reasoning.effort.map(|e| match e {
            crate::responses_types::enums::ReasoningEffort::None => "none".to_string(),
            crate::responses_types::enums::ReasoningEffort::Low => "low".to_string(),
            crate::responses_types::enums::ReasoningEffort::Medium => "medium".to_string(),
            crate::responses_types::enums::ReasoningEffort::High => "high".to_string(),
        });
        let thinking = reasoning
            .effort
            .map(|e| !matches!(e, crate::responses_types::enums::ReasoningEffort::None));
        (thinking, effort)
    } else {
        (None, None)
    };

    let truncate_sequence = oairequest
        .truncation
        .map(|t| matches!(t, crate::responses_types::enums::TruncationStrategy::Auto));

    let response_format = if let Some(text_config) = oairequest.text {
        text_config.format.map(|fmt| match fmt {
            TextFormat::Text => crate::openai::ResponseFormat::Text,
            TextFormat::JsonSchema {
                name,
                schema,
                strict: _,
            } => crate::openai::ResponseFormat::JsonSchema {
                json_schema: crate::openai::JsonSchemaResponseFormat {
                    name,
                    schema: schema.unwrap_or(serde_json::Value::Object(Default::default())),
                },
            },
            TextFormat::JsonObject => {
                crate::openai::ResponseFormat::JsonSchema {
                    json_schema: crate::openai::JsonSchemaResponseFormat {
                        name: "json_object".to_string(),
                        schema: serde_json::json!({"type": "object"}),
                    },
                }
            }
        })
    } else {
        oairequest.response_format
    };

    let chat_request = ChatCompletionRequest {
        messages: Either::Left(final_messages.clone()),
        model: oairequest.model,
        logit_bias: oairequest.logit_bias,
        logprobs: oairequest.logprobs,
        top_logprobs: oairequest.top_logprobs,
        max_tokens: oairequest.max_output_tokens,
        n_choices: oairequest.n_choices,
        presence_penalty: oairequest.presence_penalty,
        frequency_penalty: oairequest.frequency_penalty,
        repetition_penalty: oairequest.repetition_penalty,
        stop_seqs: oairequest.stop_seqs,
        temperature: oairequest.temperature,
        top_p: oairequest.top_p,
        stream: oairequest.stream,
        tools: oairequest.tools,
        tool_choice: oairequest.tool_choice,
        response_format,
        web_search_options: oairequest.web_search_options,
        top_k: oairequest.top_k,
        grammar: oairequest.grammar,
        min_p: oairequest.min_p,
        dry_multiplier: oairequest.dry_multiplier,
        dry_base: oairequest.dry_base,
        dry_allowed_length: oairequest.dry_allowed_length,
        dry_sequence_breakers: oairequest.dry_sequence_breakers,
        enable_thinking,
        truncate_sequence,
        reasoning_effort,
    };

    let (request, is_streaming) = parse_chat_request(chat_request, state, tx).await?;
    Ok((
        request,
        is_streaming,
        Some(final_messages),
        include_config,
        request_context,
    ))
}

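/// Handler for `POST /v1/responses`.
///
/// An illustrative request body (values are examples only; the fields mirror
/// [`OpenResponsesCreateRequest`]):
///
/// ```json
/// {
///   "model": "default",
///   "input": "Write a haiku about rust",
///   "stream": false,
///   "store": true
/// }
/// ```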
#[utoipa::path(
    post,
    tag = "Mistral.rs",
    path = "/v1/responses",
    request_body = OpenResponsesCreateRequest,
    responses((status = 200, description = "Response created"))
)]
pub async fn create_response(
    State(state): ExtractedMistralRsState,
    Json(oairequest): Json<OpenResponsesCreateRequest>,
) -> OpenResponsesResponder {
    let (tx, rx) = create_response_channel(None);
    let request_id = format!("resp_{}", Uuid::new_v4());
    let metadata = oairequest.metadata.clone();
    let store = oairequest.store.unwrap_or(true);
    let background = oairequest.background.unwrap_or(false);

    let model_id = if oairequest.model == "default" {
        None
    } else {
        Some(oairequest.model.clone())
    };

    let model_name = oairequest.model.clone();

    if background {
        let task_manager = get_background_task_manager();
        let task_id = task_manager.create_task(model_name.clone());

        let response = ResponseResource::new(
            task_id.clone(),
            model_name,
            SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .unwrap()
                .as_secs(),
        )
        .with_status(ResponseStatus::Queued)
        .with_metadata(metadata.clone().unwrap_or(Value::Null));

        let state_clone = state.clone();
        let metadata_clone = metadata.clone();
        tokio::spawn(async move {
            let (bg_tx, mut bg_rx) = create_response_channel(None);

            let (request, _, conversation_history, _include_config, request_context) =
                match parse_openresponses_request(oairequest, state_clone.clone(), bg_tx).await {
                    Ok(x) => x,
                    Err(e) => {
                        task_manager.mark_failed(
                            &task_id,
                            ResponseError::new("parse_error", e.to_string()),
                        );
                        return;
                    }
                };

            task_manager.mark_in_progress(&task_id);

            if let Err(e) =
                send_request_with_model(&state_clone, request, model_id.as_deref()).await
            {
                task_manager.mark_failed(&task_id, ResponseError::new("send_error", e.to_string()));
                return;
            }

            match bg_rx.recv().await {
                Some(Response::Done(chat_resp)) => {
                    let response = chat_response_to_response_resource(
                        &chat_resp,
                        task_id.clone(),
                        metadata_clone,
                        &request_context,
                    );

                    if store {
                        let cache = get_response_cache();
                        let _ = cache.store_response(task_id.clone(), response.clone());

                        if let Some(mut history) = conversation_history {
                            for choice in &chat_resp.choices {
                                if let Some(content) = &choice.message.content {
                                    history.push(Message {
                                        content: Some(MessageContent::from_text(content.clone())),
                                        role: choice.message.role.clone(),
                                        name: None,
                                        tool_calls: None,
                                        tool_call_id: None,
                                    });
                                }
                            }
                            let _ = cache.store_conversation_history(task_id.clone(), history);
                        }
                    }

                    task_manager.mark_completed(&task_id, response);
                }
                Some(Response::ModelError(msg, _partial_resp)) => {
                    task_manager
                        .mark_failed(&task_id, ResponseError::new("model_error", msg.to_string()));
                }
                Some(Response::ValidationError(e)) => {
                    task_manager.mark_failed(
                        &task_id,
                        ResponseError::new("validation_error", e.to_string()),
                    );
                }
                Some(Response::InternalError(e)) => {
                    task_manager.mark_failed(
                        &task_id,
                        ResponseError::new("internal_error", e.to_string()),
                    );
                }
                _ => {
                    task_manager.mark_failed(
                        &task_id,
                        ResponseError::new("unknown_error", "Unexpected response type"),
                    );
                }
            }
        });

        return OpenResponsesResponder::Json(response);
    }

    let (request, is_streaming, conversation_history, _include_config, request_context) =
        match parse_openresponses_request(oairequest, state.clone(), tx).await {
            Ok(x) => x,
            Err(e) => return handle_error(state, e.into()),
        };

    if let Err(e) = send_request_with_model(&state, request, model_id.as_deref()).await {
        return handle_error(state, e.into());
    }

    if is_streaming {
        let streamer = OpenResponsesStreamer::new(
            rx,
            state.clone(),
            request_id.clone(),
            model_name,
            metadata,
            store,
            conversation_history,
            request_context,
        );

        let keep_alive_interval = get_keep_alive_interval();
        let sse = Sse::new(streamer)
            .keep_alive(KeepAlive::new().interval(Duration::from_millis(keep_alive_interval)));

        OpenResponsesResponder::Sse(sse)
    } else {
        let mut rx = rx;
        match rx.recv().await {
            Some(Response::Done(chat_resp)) => {
                let response = chat_response_to_response_resource(
                    &chat_resp,
                    request_id.clone(),
                    metadata,
                    &request_context,
                );

                if store {
                    let cache = get_response_cache();
                    let _ = cache.store_response(request_id.clone(), response.clone());

                    if let Some(mut history) = conversation_history {
                        for choice in &chat_resp.choices {
                            if let Some(content) = &choice.message.content {
                                history.push(Message {
                                    content: Some(MessageContent::from_text(content.clone())),
                                    role: choice.message.role.clone(),
                                    name: None,
                                    tool_calls: None,
                                    tool_call_id: None,
                                });
                            }
                        }
                        let _ = cache.store_conversation_history(request_id, history);
                    }
                }

                OpenResponsesResponder::Json(response)
            }
            Some(Response::ModelError(msg, partial_resp)) => {
                let mut response = chat_response_to_response_resource(
                    &partial_resp,
                    request_id.clone(),
                    metadata,
                    &request_context,
                );
                response.error = Some(ResponseError::new("model_error", msg.to_string()));
                response.status = ResponseStatus::Failed;

                if store {
                    let cache = get_response_cache();
                    let _ = cache.store_response(request_id.clone(), response.clone());
                }

                OpenResponsesResponder::ModelError(msg.to_string(), response)
            }
            Some(Response::ValidationError(e)) => OpenResponsesResponder::ValidationError(e),
            Some(Response::InternalError(e)) => OpenResponsesResponder::InternalError(e),
            _ => OpenResponsesResponder::InternalError(
                anyhow::anyhow!("Unexpected response type").into(),
            ),
        }
    }
}

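/// Handler for `GET /v1/responses/{response_id}`: checks background tasks first,
/// then the response cache.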
#[utoipa::path(
    get,
    tag = "Mistral.rs",
    path = "/v1/responses/{response_id}",
    params(("response_id" = String, Path, description = "The ID of the response to retrieve")),
    responses((status = 200, description = "Response object"))
)]
pub async fn get_response(
    State(_state): ExtractedMistralRsState,
    Path(response_id): Path<String>,
) -> impl IntoResponse {
    let task_manager = get_background_task_manager();
    if let Some(response) = task_manager.get_response(&response_id) {
        return (StatusCode::OK, Json(response)).into_response();
    }

    let cache = get_response_cache();
    match cache.get_response(&response_id) {
        Ok(Some(response)) => (StatusCode::OK, Json(response)).into_response(),
        Ok(None) => JsonError::new(format!("Response with ID '{response_id}' not found"))
            .to_response(StatusCode::NOT_FOUND),
        Err(e) => JsonError::new(format!(
            "Error retrieving response: {}",
            sanitize_error_message(&*e)
        ))
        .to_response(StatusCode::INTERNAL_SERVER_ERROR),
    }
}

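/// Handler for `DELETE /v1/responses/{response_id}`: removes the background task
/// and any cached response.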
#[utoipa::path(
    delete,
    tag = "Mistral.rs",
    path = "/v1/responses/{response_id}",
    params(("response_id" = String, Path, description = "The ID of the response to delete")),
    responses((status = 200, description = "Response deleted"))
)]
pub async fn delete_response(
    State(_state): ExtractedMistralRsState,
    Path(response_id): Path<String>,
) -> impl IntoResponse {
    let task_manager = get_background_task_manager();
    let task_deleted = task_manager.delete_task(&response_id);

    let cache = get_response_cache();
    match cache.delete_response(&response_id) {
        Ok(cache_deleted) => {
            if task_deleted || cache_deleted {
                (
                    StatusCode::OK,
                    Json(serde_json::json!({
                        "deleted": true,
                        "id": response_id,
                        "object": "response.deleted"
                    })),
                )
                    .into_response()
            } else {
                JsonError::new(format!("Response with ID '{response_id}' not found"))
                    .to_response(StatusCode::NOT_FOUND)
            }
        }
        Err(e) => JsonError::new(format!(
            "Error deleting response: {}",
            sanitize_error_message(&*e)
        ))
        .to_response(StatusCode::INTERNAL_SERVER_ERROR),
    }
}

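/// Handler for `POST /v1/responses/{response_id}/cancel`: cancels an in-flight
/// background response if possible.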
#[utoipa::path(
    post,
    tag = "Mistral.rs",
    path = "/v1/responses/{response_id}/cancel",
    params(("response_id" = String, Path, description = "The ID of the response to cancel")),
    responses((status = 200, description = "Response cancelled"))
)]
pub async fn cancel_response(
    State(_state): ExtractedMistralRsState,
    Path(response_id): Path<String>,
) -> impl IntoResponse {
    let task_manager = get_background_task_manager();

    if task_manager.request_cancel(&response_id) {
        task_manager.mark_cancelled(&response_id);

        if let Some(response) = task_manager.get_response(&response_id) {
            return (StatusCode::OK, Json(response)).into_response();
        }
    }

    JsonError::new(format!(
        "Response with ID '{response_id}' not found or cannot be cancelled"
    ))
    .to_response(StatusCode::NOT_FOUND)
}

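/// Map an error into the responder type shared with the other completion handlers.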
fn handle_error(
    state: SharedMistralRsState,
    e: Box<dyn std::error::Error + Send + Sync + 'static>,
) -> OpenResponsesResponder {
    handle_completion_error(state, e)
}