mistralrs_server_core/responses.rs

//! ## Responses API functionality and route handlers.

use std::{pin::Pin, task::Poll, time::Duration};

use anyhow::Result;
use axum::{
    extract::{Json, Path, State},
    http::{self, StatusCode},
    response::{
        sse::{Event, KeepAlive, KeepAliveStream},
        IntoResponse, Sse,
    },
};
use either::Either;
use mistralrs_core::{ChatCompletionResponse, MistralRs, Request, Response};
use serde_json::Value;
use tokio::sync::mpsc::Sender;
use uuid::Uuid;

use crate::{
    cached_responses::get_response_cache,
    chat_completion::parse_request as parse_chat_request,
    completion_core::{handle_completion_error, BaseCompletionResponder},
    handler_core::{
        create_response_channel, send_request_with_model, BaseJsonModelError, ErrorToResponse,
        JsonError, ModelErrorMessage,
    },
    openai::{
        ChatCompletionRequest, Message, MessageContent, ResponsesChunk, ResponsesContent,
        ResponsesCreateRequest, ResponsesDelta, ResponsesDeltaContent, ResponsesDeltaOutput,
        ResponsesError, ResponsesObject, ResponsesOutput, ResponsesUsage,
    },
    streaming::{get_keep_alive_interval, BaseStreamer, DoneState},
    types::{ExtractedMistralRsState, OnChunkCallback, OnDoneCallback, SharedMistralRsState},
    util::sanitize_error_message,
};

/// Response streamer for the Responses API
pub type ResponsesStreamer =
    BaseStreamer<ResponsesChunk, OnChunkCallback<ResponsesChunk>, OnDoneCallback<ResponsesChunk>>;
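
// Illustrative SSE framing for this streamer (a sketch; the exact serialized
// field names follow the serde attributes on `ResponsesChunk` in
// `crate::openai`): each generated chunk becomes one `data:` event, and the
// stream ends with a literal `[DONE]` marker:
//
//   data: {"id":"...","object":"response.chunk","chunk_type":"delta",...}
//   data: {"id":"...","object":"response.chunk","chunk_type":"delta",...}
//   data: [DONE]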

impl futures::Stream for ResponsesStreamer {
    type Item = Result<Event, axum::Error>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        match self.done_state {
            DoneState::SendingDone => {
                self.done_state = DoneState::Done;
                return Poll::Ready(Some(Ok(Event::default().data("[DONE]"))));
            }
            DoneState::Done => {
                if let Some(on_done) = &self.on_done {
                    on_done(&self.chunks);
                }
                return Poll::Ready(None);
            }
            DoneState::Running => (),
        }

        match self.rx.poll_recv(cx) {
            Poll::Ready(Some(resp)) => match resp {
                Response::ModelError(msg, _) => {
                    MistralRs::maybe_log_error(
                        self.state.clone(),
                        &ModelErrorMessage(msg.to_string()),
                    );
                    self.done_state = DoneState::SendingDone;
                    Poll::Ready(Some(Ok(Event::default().data(msg))))
                }
                Response::ValidationError(e) => Poll::Ready(Some(Ok(
                    Event::default().data(sanitize_error_message(e.as_ref()))
                ))),
                Response::InternalError(e) => {
                    MistralRs::maybe_log_error(self.state.clone(), &*e);
                    Poll::Ready(Some(Ok(
                        Event::default().data(sanitize_error_message(e.as_ref()))
                    )))
                }
                Response::Chunk(chat_chunk) => {
                    // Convert ChatCompletionChunkResponse to ResponsesChunk
                    let mut delta_outputs = vec![];

                    // Check if all choices are finished
                    let all_finished = chat_chunk.choices.iter().all(|c| c.finish_reason.is_some());

                    for choice in &chat_chunk.choices {
                        let mut delta_content_items = Vec::new();

                        // Handle text content in delta
                        if let Some(content) = &choice.delta.content {
                            delta_content_items.push(ResponsesDeltaContent {
                                content_type: "output_text".to_string(),
                                text: Some(content.clone()),
                            });
                        }

                        // Handle tool calls in delta
                        if let Some(tool_calls) = &choice.delta.tool_calls {
                            for tool_call in tool_calls {
                                let tool_text = format!(
                                    "Tool: {} args: {}",
                                    tool_call.function.name, tool_call.function.arguments
                                );
                                delta_content_items.push(ResponsesDeltaContent {
                                    content_type: "tool_use".to_string(),
                                    text: Some(tool_text),
                                });
                            }
                        }

                        if !delta_content_items.is_empty() {
                            delta_outputs.push(ResponsesDeltaOutput {
                                id: format!("msg_{}", Uuid::new_v4()),
                                output_type: "message".to_string(),
                                content: Some(delta_content_items),
                            });
                        }
                    }

                    let mut response_chunk = ResponsesChunk {
                        id: chat_chunk.id.clone(),
                        object: "response.chunk",
                        created_at: chat_chunk.created as f64,
                        model: chat_chunk.model.clone(),
                        chunk_type: "delta".to_string(),
                        delta: Some(ResponsesDelta {
                            output: if delta_outputs.is_empty() {
                                None
                            } else {
                                Some(delta_outputs)
                            },
                            status: if all_finished {
                                Some("completed".to_string())
                            } else {
                                None
                            },
                        }),
                        usage: None,
                        metadata: None,
                    };

                    if all_finished {
                        self.done_state = DoneState::SendingDone;
                    }

                    MistralRs::maybe_log_response(self.state.clone(), &chat_chunk);

                    if let Some(on_chunk) = &self.on_chunk {
                        response_chunk = on_chunk(response_chunk);
                    }

                    if self.store_chunks {
                        self.chunks.push(response_chunk.clone());
                    }

                    Poll::Ready(Some(Event::default().json_data(response_chunk)))
                }
                _ => unreachable!(),
            },
            // Treat a closed channel the same as no data yet; termination is
            // signaled through `done_state` once the final chunk arrives.
            Poll::Pending | Poll::Ready(None) => Poll::Pending,
        }
    }
}

/// Responder type for the Responses API
pub type ResponsesResponder =
    BaseCompletionResponder<ResponsesObject, KeepAliveStream<ResponsesStreamer>>;

type JsonModelError = BaseJsonModelError<ResponsesObject>;
impl ErrorToResponse for JsonModelError {}

impl IntoResponse for ResponsesResponder {
    fn into_response(self) -> axum::response::Response {
        match self {
            ResponsesResponder::Sse(s) => s.into_response(),
            ResponsesResponder::Json(s) => Json(s).into_response(),
            ResponsesResponder::InternalError(e) => {
                JsonError::new(sanitize_error_message(e.as_ref()))
                    .to_response(http::StatusCode::INTERNAL_SERVER_ERROR)
            }
            ResponsesResponder::ValidationError(e) => {
                JsonError::new(sanitize_error_message(e.as_ref()))
                    .to_response(http::StatusCode::UNPROCESSABLE_ENTITY)
            }
            ResponsesResponder::ModelError(msg, response) => JsonModelError::new(msg, response)
                .to_response(http::StatusCode::INTERNAL_SERVER_ERROR),
        }
    }
}

/// Convert a chat completion response into a Responses API object.
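///
/// Token accounting maps directly: `prompt_tokens` becomes `input_tokens`
/// and `completion_tokens` becomes `output_tokens`. As a convenience, the
/// text parts of all choices are also joined with single spaces into the
/// top-level `output_text` field.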
fn chat_response_to_responses_object(
    chat_resp: &ChatCompletionResponse,
    request_id: String,
    metadata: Option<Value>,
) -> ResponsesObject {
    let mut outputs = Vec::new();
    let mut output_text_parts = Vec::new();

    for choice in &chat_resp.choices {
        let mut content_items = Vec::new();
        let mut has_content = false;

        // Handle text content
        if let Some(text) = &choice.message.content {
            output_text_parts.push(text.clone());
            content_items.push(ResponsesContent {
                content_type: "output_text".to_string(),
                text: Some(text.clone()),
                annotations: None,
            });
            has_content = true;
        }

        // Handle tool calls
        if let Some(tool_calls) = &choice.message.tool_calls {
            for tool_call in tool_calls {
                let tool_text = format!(
                    "Tool call: {} with args: {}",
                    tool_call.function.name, tool_call.function.arguments
                );
                content_items.push(ResponsesContent {
                    content_type: "tool_use".to_string(),
                    text: Some(tool_text),
                    annotations: None,
                });
                has_content = true;
            }
        }

        // Only add output if we have content
        if has_content {
            outputs.push(ResponsesOutput {
                id: format!("msg_{}", Uuid::new_v4()),
                output_type: "message".to_string(),
                role: choice.message.role.clone(),
                status: None,
                content: content_items,
            });
        }
    }

    ResponsesObject {
        id: request_id,
        object: "response",
        created_at: chat_resp.created as f64,
        model: chat_resp.model.clone(),
        status: "completed".to_string(),
        output: outputs,
        output_text: if output_text_parts.is_empty() {
            None
        } else {
            Some(output_text_parts.join(" "))
        },
        usage: Some(ResponsesUsage {
            input_tokens: chat_resp.usage.prompt_tokens,
            output_tokens: chat_resp.usage.completion_tokens,
            total_tokens: chat_resp.usage.total_tokens,
            input_tokens_details: None,
            output_tokens_details: None,
        }),
        error: None,
        metadata,
        instructions: None,
        incomplete_details: None,
    }
}

/// Parse a Responses API request into the internal request format.
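///
/// When `previous_response_id` is set, the cached conversation history for
/// that response is prepended to the incoming messages, letting clients
/// chain turns without resending the transcript, e.g. (hypothetical ID):
///
/// ```json
/// {"model": "default", "input": "And in Rust?", "previous_response_id": "resp_1234"}
/// ```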
async fn parse_responses_request(
    oairequest: ResponsesCreateRequest,
    state: SharedMistralRsState,
    tx: Sender<Response>,
) -> Result<(Request, bool, Option<Vec<Message>>)> {
    if oairequest.instructions.is_some() {
        return Err(anyhow::anyhow!(
            "The 'instructions' field is not supported in the Responses API"
        ));
    }
    // If previous_response_id is provided, get the full conversation history from cache
    let previous_messages = if let Some(prev_id) = &oairequest.previous_response_id {
        let cache = get_response_cache();
        cache.get_conversation_history(prev_id)?
    } else {
        None
    };

    // Get messages from either messages or input field
    let messages = oairequest.input.into_either();

    // Convert to ChatCompletionRequest for reuse
    let mut chat_request = ChatCompletionRequest {
        messages: messages.clone(),
        model: oairequest.model,
        logit_bias: oairequest.logit_bias,
        logprobs: oairequest.logprobs,
        top_logprobs: oairequest.top_logprobs,
        max_tokens: oairequest.max_tokens,
        n_choices: oairequest.n_choices,
        presence_penalty: oairequest.presence_penalty,
        frequency_penalty: oairequest.frequency_penalty,
        repetition_penalty: oairequest.repetition_penalty,
        stop_seqs: oairequest.stop_seqs,
        temperature: oairequest.temperature,
        top_p: oairequest.top_p,
        stream: oairequest.stream,
        tools: oairequest.tools,
        tool_choice: oairequest.tool_choice,
        response_format: oairequest.response_format,
        web_search_options: oairequest.web_search_options,
        top_k: oairequest.top_k,
        grammar: oairequest.grammar,
        min_p: oairequest.min_p,
        dry_multiplier: oairequest.dry_multiplier,
        dry_base: oairequest.dry_base,
        dry_allowed_length: oairequest.dry_allowed_length,
        dry_sequence_breakers: oairequest.dry_sequence_breakers,
        enable_thinking: oairequest.enable_thinking,
        truncate_sequence: oairequest.truncate_sequence,
        reasoning_effort: oairequest.reasoning_effort,
    };

    // Prepend previous messages if available
    if let Some(prev_msgs) = previous_messages {
        match &mut chat_request.messages {
            Either::Left(msgs) => {
                let mut combined = prev_msgs;
                combined.extend(msgs.clone());
                chat_request.messages = Either::Left(combined);
            }
            Either::Right(prompt) => {
                // If it's a plain prompt string, wrap it in a user message and
                // append it after the previous conversation history
                let prompt = prompt.clone();
                let mut combined = prev_msgs;
                combined.push(Message {
                    content: Some(MessageContent::from_text(prompt)),
                    role: "user".to_string(),
                    name: None,
                    tool_calls: None,
                    tool_call_id: None,
                });
                chat_request.messages = Either::Left(combined);
            }
        }
    }

    // Collect the full message list so it can be stored as conversation history
    let all_messages = match &chat_request.messages {
        Either::Left(msgs) => msgs.clone(),
        Either::Right(prompt) => vec![Message {
            content: Some(MessageContent::from_text(prompt.clone())),
            role: "user".to_string(),
            name: None,
            tool_calls: None,
            tool_call_id: None,
        }],
    };

    let (request, is_streaming) = parse_chat_request(chat_request, state, tx).await?;
    Ok((request, is_streaming, Some(all_messages)))
}

/// Create response endpoint
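///
/// The request's `input` carries either a list of messages or a plain prompt
/// string. A minimal non-streaming request might look like this (hypothetical
/// values):
///
/// ```json
/// {"model": "default", "input": "Hello", "store": true}
/// ```
///
/// With `"stream": true` the endpoint returns an SSE stream of
/// `response.chunk` events terminated by `data: [DONE]`. With `store` enabled
/// (the default), the final response and conversation history are cached for
/// retrieval via `GET /v1/responses/{id}` and for chaining via
/// `previous_response_id`.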
#[utoipa::path(
    post,
    tag = "Mistral.rs",
    path = "/v1/responses",
    request_body = ResponsesCreateRequest,
    responses((status = 200, description = "Response created"))
)]
pub async fn create_response(
    State(state): ExtractedMistralRsState,
    Json(oairequest): Json<ResponsesCreateRequest>,
) -> ResponsesResponder {
    let (tx, mut rx) = create_response_channel(None);
    let request_id = format!("resp_{}", Uuid::new_v4());
    let metadata = oairequest.metadata.clone();
    let store = oairequest.store.unwrap_or(true);

    // Extract model_id for routing
    let model_id = if oairequest.model == "default" {
        None
    } else {
        Some(oairequest.model.clone())
    };

    let (request, is_streaming, conversation_history) =
        match parse_responses_request(oairequest, state.clone(), tx).await {
            Ok(x) => x,
            Err(e) => return handle_error(state, e.into()),
        };

    if let Err(e) = send_request_with_model(&state, request, model_id.as_deref()).await {
        return handle_error(state, e.into());
    }

    if is_streaming {
        let streamer = ResponsesStreamer {
            rx,
            done_state: DoneState::Running,
            state: state.clone(),
            on_chunk: None,
            on_done: None,
            chunks: Vec::new(),
            store_chunks: store,
        };

        // Store chunks for later retrieval if requested
        if store {
            let cache = get_response_cache();
            let id = request_id.clone();
            let chunks_cache = cache.clone();

            // Create a wrapper that stores chunks and conversation history
            let history_for_streaming = conversation_history.clone();
            let on_done: OnDoneCallback<ResponsesChunk> = Box::new(move |chunks| {
                let _ = chunks_cache.store_chunks(id.clone(), chunks.to_vec());

                // Reconstruct the assistant's message from chunks and store conversation history
                if let Some(mut history) = history_for_streaming.clone() {
                    let mut assistant_message = String::new();

                    // Collect all text from chunks
                    for chunk in chunks {
                        if let Some(delta) = &chunk.delta {
                            if let Some(outputs) = &delta.output {
                                for output in outputs {
                                    if let Some(contents) = &output.content {
                                        for content in contents {
                                            if let Some(text) = &content.text {
                                                assistant_message.push_str(text);
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }

                    // Add the complete assistant message to history
                    if !assistant_message.is_empty() {
                        history.push(Message {
                            content: Some(MessageContent::from_text(assistant_message)),
                            role: "assistant".to_string(),
                            name: None,
                            tool_calls: None,
                            tool_call_id: None,
                        });
                    }

                    let _ = chunks_cache.store_conversation_history(id.clone(), history);
                }
            });

            ResponsesResponder::Sse(create_streamer(streamer, Some(on_done)))
        } else {
            ResponsesResponder::Sse(create_streamer(streamer, None))
        }
    } else {
        // Non-streaming response
        match rx.recv().await {
            Some(Response::Done(chat_resp)) => {
                let response_obj =
                    chat_response_to_responses_object(&chat_resp, request_id.clone(), metadata);

                // Store if requested
                if store {
                    let cache = get_response_cache();
                    let _ = cache.store_response(request_id.clone(), response_obj.clone());

                    // Create complete conversation history including the assistant's response
                    if let Some(mut history) = conversation_history.clone() {
                        // Add the assistant's response to the conversation history
                        for choice in &chat_resp.choices {
                            if let Some(content) = &choice.message.content {
                                history.push(Message {
                                    content: Some(MessageContent::from_text(content.clone())),
                                    role: choice.message.role.clone(),
                                    name: None,
                                    tool_calls: None, // TODO: Convert ToolCallResponse to ToolCall if needed
                                    tool_call_id: None,
                                });
                            }
                        }
                        let _ = cache.store_conversation_history(request_id, history);
                    }
                }

                ResponsesResponder::Json(response_obj)
            }
            Some(Response::ModelError(msg, partial_resp)) => {
                let mut response_obj =
                    chat_response_to_responses_object(&partial_resp, request_id.clone(), metadata);
                response_obj.error = Some(ResponsesError {
                    error_type: "model_error".to_string(),
                    message: msg.to_string(),
                });
                response_obj.status = "failed".to_string();

                if store {
                    let cache = get_response_cache();
                    let _ = cache.store_response(request_id.clone(), response_obj.clone());

                    // Even on error, store conversation history with partial response
                    if let Some(mut history) = conversation_history.clone() {
                        // Add any partial response to the conversation history
                        for choice in &partial_resp.choices {
                            if let Some(content) = &choice.message.content {
                                history.push(Message {
                                    content: Some(MessageContent::from_text(content.clone())),
                                    role: choice.message.role.clone(),
                                    name: None,
                                    tool_calls: None, // TODO: Convert ToolCallResponse to ToolCall if needed
                                    tool_call_id: None,
                                });
                            }
                        }
                        let _ = cache.store_conversation_history(request_id, history);
                    }
                }
                ResponsesResponder::ModelError(msg.to_string(), response_obj)
            }
            Some(Response::ValidationError(e)) => ResponsesResponder::ValidationError(e),
            Some(Response::InternalError(e)) => ResponsesResponder::InternalError(e),
            _ => ResponsesResponder::InternalError(
                anyhow::anyhow!("Unexpected response type").into(),
            ),
        }
    }
}

/// Get response by ID endpoint
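///
/// Looks up a response previously cached by [`create_response`] (responses
/// are cached only when created with `store` enabled), e.g.
/// `GET /v1/responses/resp_1234` (hypothetical ID).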
#[utoipa::path(
    get,
    tag = "Mistral.rs",
    path = "/v1/responses/{response_id}",
    params(("response_id" = String, Path, description = "The ID of the response to retrieve")),
    responses((status = 200, description = "Response object"))
)]
pub async fn get_response(
    State(_state): ExtractedMistralRsState,
    Path(response_id): Path<String>,
) -> impl IntoResponse {
    let cache = get_response_cache();

    match cache.get_response(&response_id) {
        Ok(Some(response)) => (StatusCode::OK, Json(response)).into_response(),
        Ok(None) => JsonError::new(format!("Response with ID '{response_id}' not found"))
            .to_response(StatusCode::NOT_FOUND),
        Err(e) => JsonError::new(format!(
            "Error retrieving response: {}",
            sanitize_error_message(&*e)
        ))
        .to_response(StatusCode::INTERNAL_SERVER_ERROR),
    }
}

/// Delete response by ID endpoint
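///
/// Removes a stored response from the cache and returns a `response.deleted`
/// confirmation object, e.g. `DELETE /v1/responses/resp_1234`
/// (hypothetical ID).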
#[utoipa::path(
    delete,
    tag = "Mistral.rs",
    path = "/v1/responses/{response_id}",
    params(("response_id" = String, Path, description = "The ID of the response to delete")),
    responses((status = 200, description = "Response deleted"))
)]
pub async fn delete_response(
    State(_state): ExtractedMistralRsState,
    Path(response_id): Path<String>,
) -> impl IntoResponse {
    let cache = get_response_cache();

    match cache.delete_response(&response_id) {
        Ok(true) => (
            StatusCode::OK,
            Json(serde_json::json!({
                "deleted": true,
                "id": response_id,
                "object": "response.deleted"
            })),
        )
            .into_response(),
        Ok(false) => JsonError::new(format!("Response with ID '{response_id}' not found"))
            .to_response(StatusCode::NOT_FOUND),
        Err(e) => JsonError::new(format!(
            "Error deleting response: {}",
            sanitize_error_message(&*e)
        ))
        .to_response(StatusCode::INTERNAL_SERVER_ERROR),
    }
}

/// Handle errors by delegating to the shared completion error handler.
fn handle_error(
    state: SharedMistralRsState,
    e: Box<dyn std::error::Error + Send + Sync + 'static>,
) -> ResponsesResponder {
    handle_completion_error(state, e)
}

/// Create an SSE response with keep-alive from the given streamer and an
/// optional completion callback.
fn create_streamer(
    streamer: ResponsesStreamer,
    on_done: Option<OnDoneCallback<ResponsesChunk>>,
) -> Sse<KeepAliveStream<ResponsesStreamer>> {
    let keep_alive_interval = get_keep_alive_interval();

    let streamer_with_callback = ResponsesStreamer {
        on_done,
        ..streamer
    };

    Sse::new(streamer_with_callback)
        .keep_alive(KeepAlive::new().interval(Duration::from_millis(keep_alive_interval)))
}