mistralrs_server_core/responses.rs

//! ## Responses API functionality and route handlers.

use std::{pin::Pin, task::Poll, time::Duration};

use anyhow::Result;
use axum::{
    extract::{Json, Path, State},
    http::{self, StatusCode},
    response::{
        sse::{Event, KeepAlive},
        IntoResponse, Sse,
    },
};
use either::Either;
use mistralrs_core::{ChatCompletionResponse, MistralRs, Request, Response};
use serde_json::Value;
use tokio::sync::mpsc::Sender;
use uuid::Uuid;

use crate::{
    cached_responses::get_response_cache,
    chat_completion::parse_request as parse_chat_request,
    completion_core::{handle_completion_error, BaseCompletionResponder},
    handler_core::{
        create_response_channel, send_request_with_model, BaseJsonModelError, ErrorToResponse,
        JsonError, ModelErrorMessage,
    },
    openai::{
        ChatCompletionRequest, Message, MessageContent, ResponsesChunk, ResponsesContent,
        ResponsesCreateRequest, ResponsesDelta, ResponsesDeltaContent, ResponsesDeltaOutput,
        ResponsesError, ResponsesObject, ResponsesOutput, ResponsesUsage,
    },
    streaming::{get_keep_alive_interval, BaseStreamer, DoneState},
    types::{ExtractedMistralRsState, OnChunkCallback, OnDoneCallback, SharedMistralRsState},
    util::sanitize_error_message,
};

/// Response streamer for the Responses API
pub type ResponsesStreamer =
    BaseStreamer<ResponsesChunk, OnChunkCallback<ResponsesChunk>, OnDoneCallback<ResponsesChunk>>;

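// The stream yields one SSE `data:` frame per generated chunk and terminates
// with a literal `[DONE]` payload, following the same wire convention as the
// chat completions streamer.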
impl futures::Stream for ResponsesStreamer {
    type Item = Result<Event, axum::Error>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        match self.done_state {
            DoneState::SendingDone => {
                self.done_state = DoneState::Done;
                return Poll::Ready(Some(Ok(Event::default().data("[DONE]"))));
            }
            DoneState::Done => {
                if let Some(on_done) = &self.on_done {
                    on_done(&self.chunks);
                }
                return Poll::Ready(None);
            }
            DoneState::Running => (),
        }

        match self.rx.poll_recv(cx) {
            Poll::Ready(Some(resp)) => match resp {
                Response::ModelError(msg, _) => {
                    MistralRs::maybe_log_error(
                        self.state.clone(),
                        &ModelErrorMessage(msg.to_string()),
                    );
                    self.done_state = DoneState::SendingDone;
                    Poll::Ready(Some(Ok(Event::default().data(msg))))
                }
                Response::ValidationError(e) => Poll::Ready(Some(Ok(
                    Event::default().data(sanitize_error_message(e.as_ref()))
                ))),
                Response::InternalError(e) => {
                    MistralRs::maybe_log_error(self.state.clone(), &*e);
                    Poll::Ready(Some(Ok(
                        Event::default().data(sanitize_error_message(e.as_ref()))
                    )))
                }
                Response::Chunk(chat_chunk) => {
                    // Convert ChatCompletionChunkResponse to ResponsesChunk
                    let mut delta_outputs = vec![];

                    // Check if all choices are finished
                    let all_finished = chat_chunk.choices.iter().all(|c| c.finish_reason.is_some());

                    for choice in &chat_chunk.choices {
                        let mut delta_content_items = Vec::new();

                        // Handle text content in delta
                        if let Some(content) = &choice.delta.content {
                            delta_content_items.push(ResponsesDeltaContent {
                                content_type: "output_text".to_string(),
                                text: Some(content.clone()),
                            });
                        }

                        // Handle tool calls in delta
                        if let Some(tool_calls) = &choice.delta.tool_calls {
                            for tool_call in tool_calls {
                                let tool_text = format!(
                                    "Tool: {} args: {}",
                                    tool_call.function.name, tool_call.function.arguments
                                );
                                delta_content_items.push(ResponsesDeltaContent {
                                    content_type: "tool_use".to_string(),
                                    text: Some(tool_text),
                                });
                            }
                        }

                        if !delta_content_items.is_empty() {
                            delta_outputs.push(ResponsesDeltaOutput {
                                id: format!("msg_{}", Uuid::new_v4()),
                                output_type: "message".to_string(),
                                content: Some(delta_content_items),
                            });
                        }
                    }

                    let mut response_chunk = ResponsesChunk {
                        id: chat_chunk.id.clone(),
                        object: "response.chunk",
                        created_at: chat_chunk.created as f64,
                        model: chat_chunk.model.clone(),
                        chunk_type: "delta".to_string(),
                        delta: Some(ResponsesDelta {
                            output: if delta_outputs.is_empty() {
                                None
                            } else {
                                Some(delta_outputs)
                            },
                            status: if all_finished {
                                Some("completed".to_string())
                            } else {
                                None
                            },
                        }),
                        usage: None,
                        metadata: None,
                    };

                    if all_finished {
                        self.done_state = DoneState::SendingDone;
                    }

                    MistralRs::maybe_log_response(self.state.clone(), &chat_chunk);

                    if let Some(on_chunk) = &self.on_chunk {
                        response_chunk = on_chunk(response_chunk);
                    }

                    if self.store_chunks {
                        self.chunks.push(response_chunk.clone());
                    }

                    Poll::Ready(Some(Event::default().json_data(response_chunk)))
                }
                _ => unreachable!(),
            },
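            // Note: a closed channel (`Poll::Ready(None)`) is reported as `Pending`;
            // stream termination is driven by `done_state` once a finish reason arrives.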
            Poll::Pending | Poll::Ready(None) => Poll::Pending,
        }
    }
}

/// Responder type for the Responses API
pub type ResponsesResponder = BaseCompletionResponder<ResponsesObject, ResponsesStreamer>;

type JsonModelError = BaseJsonModelError<ResponsesObject>;
impl ErrorToResponse for JsonModelError {}

impl IntoResponse for ResponsesResponder {
    fn into_response(self) -> axum::response::Response {
        match self {
            ResponsesResponder::Sse(s) => s.into_response(),
            ResponsesResponder::Json(s) => Json(s).into_response(),
            ResponsesResponder::InternalError(e) => {
                JsonError::new(sanitize_error_message(e.as_ref()))
                    .to_response(http::StatusCode::INTERNAL_SERVER_ERROR)
            }
            ResponsesResponder::ValidationError(e) => {
                JsonError::new(sanitize_error_message(e.as_ref()))
                    .to_response(http::StatusCode::UNPROCESSABLE_ENTITY)
            }
            ResponsesResponder::ModelError(msg, response) => JsonModelError::new(msg, response)
                .to_response(http::StatusCode::INTERNAL_SERVER_ERROR),
        }
    }
}

/// Convert a chat completion response to a Responses API object
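///
/// Each chat choice with content becomes one "message" output; the text parts
/// of all choices are additionally joined with spaces into `output_text`.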
fn chat_response_to_responses_object(
    chat_resp: &ChatCompletionResponse,
    request_id: String,
    metadata: Option<Value>,
) -> ResponsesObject {
    let mut outputs = Vec::new();
    let mut output_text_parts = Vec::new();

    for choice in &chat_resp.choices {
        let mut content_items = Vec::new();
        let mut has_content = false;

        // Handle text content
        if let Some(text) = &choice.message.content {
            output_text_parts.push(text.clone());
            content_items.push(ResponsesContent {
                content_type: "output_text".to_string(),
                text: Some(text.clone()),
                annotations: None,
            });
            has_content = true;
        }

        // Handle tool calls
        if let Some(tool_calls) = &choice.message.tool_calls {
            for tool_call in tool_calls {
                let tool_text = format!(
                    "Tool call: {} with args: {}",
                    tool_call.function.name, tool_call.function.arguments
                );
                content_items.push(ResponsesContent {
                    content_type: "tool_use".to_string(),
                    text: Some(tool_text),
                    annotations: None,
                });
                has_content = true;
            }
        }

        // Only add output if we have content
        if has_content {
            outputs.push(ResponsesOutput {
                id: format!("msg_{}", Uuid::new_v4()),
                output_type: "message".to_string(),
                role: choice.message.role.clone(),
                status: None,
                content: content_items,
            });
        }
    }

    ResponsesObject {
        id: request_id,
        object: "response",
        created_at: chat_resp.created as f64,
        model: chat_resp.model.clone(),
        status: "completed".to_string(),
        output: outputs,
        output_text: if output_text_parts.is_empty() {
            None
        } else {
            Some(output_text_parts.join(" "))
        },
        usage: Some(ResponsesUsage {
            input_tokens: chat_resp.usage.prompt_tokens,
            output_tokens: chat_resp.usage.completion_tokens,
            total_tokens: chat_resp.usage.total_tokens,
            input_tokens_details: None,
            output_tokens_details: None,
        }),
        error: None,
        metadata,
        instructions: None,
        incomplete_details: None,
    }
}

/// Parse a Responses API request into the internal request format
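///
/// If `previous_response_id` is set, the cached conversation history for that
/// response is prepended so the model sees the full dialogue. Illustrative
/// follow-up request (hypothetical values):
///
/// ```ignore
/// let body = serde_json::json!({
///     "model": "default",
///     "input": "And in French?",
///     "previous_response_id": "resp_<uuid>"
/// });
/// ```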
async fn parse_responses_request(
    oairequest: ResponsesCreateRequest,
    state: SharedMistralRsState,
    tx: Sender<Response>,
) -> Result<(Request, bool, Option<Vec<Message>>)> {
    if oairequest.instructions.is_some() {
        return Err(anyhow::anyhow!(
            "The 'instructions' field is not supported in the Responses API"
        ));
    }

    // If previous_response_id is provided, get the full conversation history from cache
    let previous_messages = if let Some(prev_id) = &oairequest.previous_response_id {
        let cache = get_response_cache();
        cache.get_conversation_history(prev_id)?
    } else {
        None
    };

    // Get messages from either messages or input field
    let messages = oairequest.input.into_either();

    // Convert to ChatCompletionRequest for reuse
    let mut chat_request = ChatCompletionRequest {
        messages: messages.clone(),
        model: oairequest.model,
        logit_bias: oairequest.logit_bias,
        logprobs: oairequest.logprobs,
        top_logprobs: oairequest.top_logprobs,
        max_tokens: oairequest.max_tokens,
        n_choices: oairequest.n_choices,
        presence_penalty: oairequest.presence_penalty,
        frequency_penalty: oairequest.frequency_penalty,
        repetition_penalty: oairequest.repetition_penalty,
        stop_seqs: oairequest.stop_seqs,
        temperature: oairequest.temperature,
        top_p: oairequest.top_p,
        stream: oairequest.stream,
        tools: oairequest.tools,
        tool_choice: oairequest.tool_choice,
        response_format: oairequest.response_format,
        web_search_options: oairequest.web_search_options,
        top_k: oairequest.top_k,
        grammar: oairequest.grammar,
        min_p: oairequest.min_p,
        dry_multiplier: oairequest.dry_multiplier,
        dry_base: oairequest.dry_base,
        dry_allowed_length: oairequest.dry_allowed_length,
        dry_sequence_breakers: oairequest.dry_sequence_breakers,
        enable_thinking: oairequest.enable_thinking,
    };

    // Prepend previous messages if available
    if let Some(prev_msgs) = previous_messages {
        match &mut chat_request.messages {
            Either::Left(msgs) => {
                let mut combined = prev_msgs;
                combined.extend(msgs.clone());
                chat_request.messages = Either::Left(combined);
            }
            Either::Right(prompt) => {
                // If it's a prompt string, convert it to a user message appended
                // after the previous conversation history
                let prompt = prompt.clone();
                let mut combined = prev_msgs;
                combined.push(Message {
                    content: Some(MessageContent::from_text(prompt)),
                    role: "user".to_string(),
                    name: None,
                    tool_calls: None,
                });
                chat_request.messages = Either::Left(combined);
            }
        }
    }

    // Get all messages for prompt_details
    let all_messages = match &chat_request.messages {
        Either::Left(msgs) => msgs.clone(),
        Either::Right(prompt) => vec![Message {
            content: Some(MessageContent::from_text(prompt.clone())),
            role: "user".to_string(),
            name: None,
            tool_calls: None,
        }],
    };

    let (request, is_streaming) = parse_chat_request(chat_request, state, tx).await?;
    Ok((request, is_streaming, Some(all_messages)))
}

/// Create response endpoint
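///
/// Streams SSE chunks when `stream` is true; otherwise returns a complete
/// `ResponsesObject`. When `store` is true (the default), the result is cached
/// so it can be fetched later via `GET /v1/responses/{response_id}`.
/// Illustrative request body (hypothetical values):
///
/// ```ignore
/// let body = serde_json::json!({
///     "model": "default",
///     "input": "Hello!",
///     "stream": false,
///     "store": true
/// });
/// ```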
#[utoipa::path(
    post,
    tag = "Mistral.rs",
    path = "/v1/responses",
    request_body = ResponsesCreateRequest,
    responses((status = 200, description = "Response created"))
)]
pub async fn create_response(
    State(state): ExtractedMistralRsState,
    Json(oairequest): Json<ResponsesCreateRequest>,
) -> ResponsesResponder {
    let (tx, mut rx) = create_response_channel(None);
    let request_id = format!("resp_{}", Uuid::new_v4());
    let metadata = oairequest.metadata.clone();
    let store = oairequest.store.unwrap_or(true);

    // Extract model_id for routing
    let model_id = if oairequest.model == "default" {
        None
    } else {
        Some(oairequest.model.clone())
    };

    let (request, is_streaming, conversation_history) =
        match parse_responses_request(oairequest, state.clone(), tx).await {
            Ok(x) => x,
            Err(e) => return handle_error(state, e.into()),
        };

    if let Err(e) = send_request_with_model(&state, request, model_id.as_deref()).await {
        return handle_error(state, e.into());
    }

    if is_streaming {
        let streamer = ResponsesStreamer {
            rx,
            done_state: DoneState::Running,
            state: state.clone(),
            on_chunk: None,
            on_done: None,
            chunks: Vec::new(),
            store_chunks: store,
        };

        // Store chunks for later retrieval if requested
        if store {
            let cache = get_response_cache();
            let id = request_id.clone();
            let chunks_cache = cache.clone();

            // Create a wrapper that stores chunks and conversation history
            let history_for_streaming = conversation_history.clone();
            let on_done: OnDoneCallback<ResponsesChunk> = Box::new(move |chunks| {
                let _ = chunks_cache.store_chunks(id.clone(), chunks.to_vec());

                // Reconstruct the assistant's message from chunks and store conversation history
                if let Some(mut history) = history_for_streaming.clone() {
                    let mut assistant_message = String::new();

                    // Collect all text from chunks
                    for chunk in chunks {
                        if let Some(delta) = &chunk.delta {
                            if let Some(outputs) = &delta.output {
                                for output in outputs {
                                    if let Some(contents) = &output.content {
                                        for content in contents {
                                            if let Some(text) = &content.text {
                                                assistant_message.push_str(text);
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }

                    // Add the complete assistant message to history
                    if !assistant_message.is_empty() {
                        history.push(Message {
                            content: Some(MessageContent::from_text(assistant_message)),
                            role: "assistant".to_string(),
                            name: None,
                            tool_calls: None,
                        });
                    }

                    let _ = chunks_cache.store_conversation_history(id.clone(), history);
                }
            });

            ResponsesResponder::Sse(create_streamer(streamer, Some(on_done)))
        } else {
            ResponsesResponder::Sse(create_streamer(streamer, None))
        }
    } else {
        // Non-streaming response
        match rx.recv().await {
            Some(Response::Done(chat_resp)) => {
                let response_obj =
                    chat_response_to_responses_object(&chat_resp, request_id.clone(), metadata);

                // Store if requested
                if store {
                    let cache = get_response_cache();
                    let _ = cache.store_response(request_id.clone(), response_obj.clone());

                    // Create complete conversation history including the assistant's response
                    if let Some(mut history) = conversation_history.clone() {
                        // Add the assistant's response to the conversation history
                        for choice in &chat_resp.choices {
                            if let Some(content) = &choice.message.content {
                                history.push(Message {
                                    content: Some(MessageContent::from_text(content.clone())),
                                    role: choice.message.role.clone(),
                                    name: None,
                                    tool_calls: None, // TODO: Convert ToolCallResponse to ToolCall if needed
                                });
                            }
                        }
                        let _ = cache.store_conversation_history(request_id, history);
                    }
                }

                ResponsesResponder::Json(response_obj)
            }
            Some(Response::ModelError(msg, partial_resp)) => {
                let mut response_obj =
                    chat_response_to_responses_object(&partial_resp, request_id.clone(), metadata);
                response_obj.error = Some(ResponsesError {
                    error_type: "model_error".to_string(),
                    message: msg.to_string(),
                });
                response_obj.status = "failed".to_string();

                if store {
                    let cache = get_response_cache();
                    let _ = cache.store_response(request_id.clone(), response_obj.clone());

                    // Even on error, store conversation history with partial response
                    if let Some(mut history) = conversation_history.clone() {
                        // Add any partial response to the conversation history
                        for choice in &partial_resp.choices {
                            if let Some(content) = &choice.message.content {
                                history.push(Message {
                                    content: Some(MessageContent::from_text(content.clone())),
                                    role: choice.message.role.clone(),
                                    name: None,
                                    tool_calls: None, // TODO: Convert ToolCallResponse to ToolCall if needed
                                });
                            }
                        }
                        let _ = cache.store_conversation_history(request_id, history);
                    }
                }
                ResponsesResponder::ModelError(msg.to_string(), response_obj)
            }
            Some(Response::ValidationError(e)) => ResponsesResponder::ValidationError(e),
            Some(Response::InternalError(e)) => ResponsesResponder::InternalError(e),
            _ => ResponsesResponder::InternalError(
                anyhow::anyhow!("Unexpected response type").into(),
            ),
        }
    }
}

/// Get response by ID endpoint
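///
/// Returns a `ResponsesObject` previously cached by `create_response` when the
/// request was made with `store` enabled.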
#[utoipa::path(
    get,
    tag = "Mistral.rs",
    path = "/v1/responses/{response_id}",
    params(("response_id" = String, Path, description = "The ID of the response to retrieve")),
    responses((status = 200, description = "Response object"))
)]
pub async fn get_response(
    State(_state): ExtractedMistralRsState,
    Path(response_id): Path<String>,
) -> impl IntoResponse {
    let cache = get_response_cache();

    match cache.get_response(&response_id) {
        Ok(Some(response)) => (StatusCode::OK, Json(response)).into_response(),
        Ok(None) => JsonError::new(format!("Response with ID '{response_id}' not found"))
            .to_response(StatusCode::NOT_FOUND),
        Err(e) => JsonError::new(format!(
            "Error retrieving response: {}",
            sanitize_error_message(&*e)
        ))
        .to_response(StatusCode::INTERNAL_SERVER_ERROR),
    }
}

/// Delete response by ID endpoint
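///
/// Removes the stored response from the cache; responds with 404 if no
/// response with the given ID exists.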
#[utoipa::path(
    delete,
    tag = "Mistral.rs",
    path = "/v1/responses/{response_id}",
    params(("response_id" = String, Path, description = "The ID of the response to delete")),
    responses((status = 200, description = "Response deleted"))
)]
pub async fn delete_response(
    State(_state): ExtractedMistralRsState,
    Path(response_id): Path<String>,
) -> impl IntoResponse {
    let cache = get_response_cache();

    match cache.delete_response(&response_id) {
        Ok(true) => (
            StatusCode::OK,
            Json(serde_json::json!({
                "deleted": true,
                "id": response_id,
                "object": "response.deleted"
            })),
        )
            .into_response(),
        Ok(false) => JsonError::new(format!("Response with ID '{response_id}' not found"))
            .to_response(StatusCode::NOT_FOUND),
        Err(e) => JsonError::new(format!(
            "Error deleting response: {}",
            sanitize_error_message(&*e)
        ))
        .to_response(StatusCode::INTERNAL_SERVER_ERROR),
    }
}

/// Handle errors by delegating to the shared completion error handler
fn handle_error(
    state: SharedMistralRsState,
    e: Box<dyn std::error::Error + Send + Sync + 'static>,
) -> ResponsesResponder {
    handle_completion_error(state, e)
}

/// Create SSE streamer
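///
/// Attaches the optional `on_done` callback and configures the SSE keep-alive
/// interval so idle connections stay open between chunks.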
fn create_streamer(
    streamer: ResponsesStreamer,
    on_done: Option<OnDoneCallback<ResponsesChunk>>,
) -> Sse<ResponsesStreamer> {
    let keep_alive_interval = get_keep_alive_interval();

    let streamer_with_callback = ResponsesStreamer {
        on_done,
        ..streamer
    };

    Sse::new(streamer_with_callback)
        .keep_alive(KeepAlive::new().interval(Duration::from_millis(keep_alive_interval)))
}