mistralrs_server_core/responses.rs

//! ## Responses API functionality and route handlers.
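//!
//! Implements an OpenAI-style Responses API on top of the chat completion
//! pipeline, with three route handlers:
//!
//! - `POST /v1/responses`: create a response (optionally streamed over SSE)
//! - `GET /v1/responses/{response_id}`: retrieve a stored response
//! - `DELETE /v1/responses/{response_id}`: delete a stored response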

use std::{pin::Pin, task::Poll, time::Duration};

use anyhow::Result;
use axum::{
    extract::{Json, Path, State},
    http::{self, StatusCode},
    response::{
        sse::{Event, KeepAlive, KeepAliveStream},
        IntoResponse, Sse,
    },
};
use either::Either;
use mistralrs_core::{ChatCompletionResponse, MistralRs, Request, Response};
use serde_json::Value;
use tokio::sync::mpsc::Sender;
use uuid::Uuid;

use crate::{
    cached_responses::get_response_cache,
    chat_completion::parse_request as parse_chat_request,
    completion_core::{handle_completion_error, BaseCompletionResponder},
    handler_core::{
        create_response_channel, send_request_with_model, BaseJsonModelError, ErrorToResponse,
        JsonError, ModelErrorMessage,
    },
    openai::{
        ChatCompletionRequest, Message, MessageContent, ResponsesChunk, ResponsesContent,
        ResponsesCreateRequest, ResponsesDelta, ResponsesDeltaContent, ResponsesDeltaOutput,
        ResponsesError, ResponsesObject, ResponsesOutput, ResponsesUsage,
    },
    streaming::{get_keep_alive_interval, BaseStreamer, DoneState},
    types::{ExtractedMistralRsState, OnChunkCallback, OnDoneCallback, SharedMistralRsState},
    util::sanitize_error_message,
};

/// Response streamer for the Responses API
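///
/// Emits OpenAI-style SSE events: one `delta` chunk per received model chunk,
/// then a literal `[DONE]` data event once every choice reports a finish
/// reason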
pub type ResponsesStreamer =
    BaseStreamer<ResponsesChunk, OnChunkCallback<ResponsesChunk>, OnDoneCallback<ResponsesChunk>>;

impl futures::Stream for ResponsesStreamer {
    type Item = Result<Event, axum::Error>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        match self.done_state {
            DoneState::SendingDone => {
                self.done_state = DoneState::Done;
                return Poll::Ready(Some(Ok(Event::default().data("[DONE]"))));
            }
            DoneState::Done => {
                if let Some(on_done) = &self.on_done {
                    on_done(&self.chunks);
                }
                return Poll::Ready(None);
            }
            DoneState::Running => (),
        }

        match self.rx.poll_recv(cx) {
            Poll::Ready(Some(resp)) => match resp {
                Response::ModelError(msg, _) => {
                    MistralRs::maybe_log_error(
                        self.state.clone(),
                        &ModelErrorMessage(msg.to_string()),
                    );
                    self.done_state = DoneState::SendingDone;
                    Poll::Ready(Some(Ok(Event::default().data(msg))))
                }
                Response::ValidationError(e) => Poll::Ready(Some(Ok(
                    Event::default().data(sanitize_error_message(e.as_ref()))
                ))),
                Response::InternalError(e) => {
                    MistralRs::maybe_log_error(self.state.clone(), &*e);
                    Poll::Ready(Some(Ok(
                        Event::default().data(sanitize_error_message(e.as_ref()))
                    )))
                }
                Response::Chunk(chat_chunk) => {
                    // Convert ChatCompletionChunkResponse to ResponsesChunk
                    let mut delta_outputs = vec![];

                    // Check if all choices are finished
                    let all_finished = chat_chunk.choices.iter().all(|c| c.finish_reason.is_some());

                    for choice in &chat_chunk.choices {
                        let mut delta_content_items = Vec::new();

                        // Handle text content in the delta
                        if let Some(content) = &choice.delta.content {
                            delta_content_items.push(ResponsesDeltaContent {
                                content_type: "output_text".to_string(),
                                text: Some(content.clone()),
                            });
                        }

                        // Handle tool calls in the delta
                        if let Some(tool_calls) = &choice.delta.tool_calls {
                            for tool_call in tool_calls {
                                let tool_text = format!(
                                    "Tool: {} args: {}",
                                    tool_call.function.name, tool_call.function.arguments
                                );
                                delta_content_items.push(ResponsesDeltaContent {
                                    content_type: "tool_use".to_string(),
                                    text: Some(tool_text),
                                });
                            }
                        }

                        if !delta_content_items.is_empty() {
                            delta_outputs.push(ResponsesDeltaOutput {
                                id: format!("msg_{}", Uuid::new_v4()),
                                output_type: "message".to_string(),
                                content: Some(delta_content_items),
                            });
                        }
                    }

                    let mut response_chunk = ResponsesChunk {
                        id: chat_chunk.id.clone(),
                        object: "response.chunk",
                        created_at: chat_chunk.created as f64,
                        model: chat_chunk.model.clone(),
                        chunk_type: "delta".to_string(),
                        delta: Some(ResponsesDelta {
                            output: if delta_outputs.is_empty() {
                                None
                            } else {
                                Some(delta_outputs)
                            },
                            status: if all_finished {
                                Some("completed".to_string())
                            } else {
                                None
                            },
                        }),
                        usage: None,
                        metadata: None,
                    };

                    if all_finished {
                        self.done_state = DoneState::SendingDone;
                    }

                    MistralRs::maybe_log_response(self.state.clone(), &chat_chunk);

                    if let Some(on_chunk) = &self.on_chunk {
                        response_chunk = on_chunk(response_chunk);
                    }

                    if self.store_chunks {
                        self.chunks.push(response_chunk.clone());
                    }

                    Poll::Ready(Some(Event::default().json_data(response_chunk)))
                }
                // Only chat-style responses are produced for requests from this handler
                _ => unreachable!(),
            },
            // A closed channel without a terminal message is reported as pending;
            // stream termination is signaled via `done_state` above
            Poll::Pending | Poll::Ready(None) => Poll::Pending,
        }
    }
}

/// Responder type for the Responses API
pub type ResponsesResponder =
    BaseCompletionResponder<ResponsesObject, KeepAliveStream<ResponsesStreamer>>;

type JsonModelError = BaseJsonModelError<ResponsesObject>;
impl ErrorToResponse for JsonModelError {}

impl IntoResponse for ResponsesResponder {
    fn into_response(self) -> axum::response::Response {
        match self {
            ResponsesResponder::Sse(s) => s.into_response(),
            ResponsesResponder::Json(s) => Json(s).into_response(),
            ResponsesResponder::InternalError(e) => {
                JsonError::new(sanitize_error_message(e.as_ref()))
                    .to_response(http::StatusCode::INTERNAL_SERVER_ERROR)
            }
            ResponsesResponder::ValidationError(e) => {
                JsonError::new(sanitize_error_message(e.as_ref()))
                    .to_response(http::StatusCode::UNPROCESSABLE_ENTITY)
            }
            ResponsesResponder::ModelError(msg, response) => JsonModelError::new(msg, response)
                .to_response(http::StatusCode::INTERNAL_SERVER_ERROR),
        }
    }
}

/// Convert a chat completion response into a Responses API object
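///
/// Text content maps to `output_text` items and tool calls to `tool_use`
/// items; `output_text` on the returned object is the space-joined
/// concatenation of all text parts across choices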
fn chat_response_to_responses_object(
    chat_resp: &ChatCompletionResponse,
    request_id: String,
    metadata: Option<Value>,
) -> ResponsesObject {
    let mut outputs = Vec::new();
    let mut output_text_parts = Vec::new();

    for choice in &chat_resp.choices {
        let mut content_items = Vec::new();
        let mut has_content = false;

        // Handle text content
        if let Some(text) = &choice.message.content {
            output_text_parts.push(text.clone());
            content_items.push(ResponsesContent {
                content_type: "output_text".to_string(),
                text: Some(text.clone()),
                annotations: None,
            });
            has_content = true;
        }

        // Handle tool calls
        if let Some(tool_calls) = &choice.message.tool_calls {
            for tool_call in tool_calls {
                let tool_text = format!(
                    "Tool call: {} with args: {}",
                    tool_call.function.name, tool_call.function.arguments
                );
                content_items.push(ResponsesContent {
                    content_type: "tool_use".to_string(),
                    text: Some(tool_text),
                    annotations: None,
                });
                has_content = true;
            }
        }

        // Only add an output if we have content
        if has_content {
            outputs.push(ResponsesOutput {
                id: format!("msg_{}", Uuid::new_v4()),
                output_type: "message".to_string(),
                role: choice.message.role.clone(),
                status: None,
                content: content_items,
            });
        }
    }

    ResponsesObject {
        id: request_id,
        object: "response",
        created_at: chat_resp.created as f64,
        model: chat_resp.model.clone(),
        status: "completed".to_string(),
        output: outputs,
        output_text: if output_text_parts.is_empty() {
            None
        } else {
            Some(output_text_parts.join(" "))
        },
        usage: Some(ResponsesUsage {
            input_tokens: chat_resp.usage.prompt_tokens,
            output_tokens: chat_resp.usage.completion_tokens,
            total_tokens: chat_resp.usage.total_tokens,
            input_tokens_details: None,
            output_tokens_details: None,
        }),
        error: None,
        metadata,
        instructions: None,
        incomplete_details: None,
    }
}

/// Parse a Responses API request into the internal request format
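///
/// When `previous_response_id` is set, the cached conversation history for
/// that response is prepended, enabling chained turns. A hypothetical
/// two-request exchange (the ID is illustrative):
///
/// ```text
/// POST /v1/responses {"model": "default", "input": "Hi", "store": true}
///   -> {"id": "resp_abc", ...}
/// POST /v1/responses {"model": "default", "input": "And then?", "previous_response_id": "resp_abc"}
/// ```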
async fn parse_responses_request(
    oairequest: ResponsesCreateRequest,
    state: SharedMistralRsState,
    tx: Sender<Response>,
) -> Result<(Request, bool, Option<Vec<Message>>)> {
    if oairequest.instructions.is_some() {
        return Err(anyhow::anyhow!(
            "The 'instructions' field is not supported in the Responses API"
        ));
    }

    // If previous_response_id is provided, get the full conversation history from the cache
    let previous_messages = if let Some(prev_id) = &oairequest.previous_response_id {
        let cache = get_response_cache();
        cache.get_conversation_history(prev_id)?
    } else {
        None
    };

    // Get the messages from the `input` field (either a message list or a prompt string)
    let messages = oairequest.input.into_either();

    // Convert to a ChatCompletionRequest so the chat request parsing path can be reused
    let mut chat_request = ChatCompletionRequest {
        messages: messages.clone(),
        model: oairequest.model,
        logit_bias: oairequest.logit_bias,
        logprobs: oairequest.logprobs,
        top_logprobs: oairequest.top_logprobs,
        max_tokens: oairequest.max_tokens,
        n_choices: oairequest.n_choices,
        presence_penalty: oairequest.presence_penalty,
        frequency_penalty: oairequest.frequency_penalty,
        repetition_penalty: oairequest.repetition_penalty,
        stop_seqs: oairequest.stop_seqs,
        temperature: oairequest.temperature,
        top_p: oairequest.top_p,
        stream: oairequest.stream,
        tools: oairequest.tools,
        tool_choice: oairequest.tool_choice,
        response_format: oairequest.response_format,
        web_search_options: oairequest.web_search_options,
        top_k: oairequest.top_k,
        grammar: oairequest.grammar,
        min_p: oairequest.min_p,
        dry_multiplier: oairequest.dry_multiplier,
        dry_base: oairequest.dry_base,
        dry_allowed_length: oairequest.dry_allowed_length,
        dry_sequence_breakers: oairequest.dry_sequence_breakers,
        enable_thinking: oairequest.enable_thinking,
        truncate_sequence: oairequest.truncate_sequence,
    };

    // Prepend previous messages if available
    if let Some(prev_msgs) = previous_messages {
        match &mut chat_request.messages {
            Either::Left(msgs) => {
                let mut combined = prev_msgs;
                combined.extend(msgs.clone());
                chat_request.messages = Either::Left(combined);
            }
            Either::Right(prompt) => {
                // If it's a prompt string, convert it to a user message and append it
                let prompt = prompt.clone();
                let mut combined = prev_msgs;
                combined.push(Message {
                    content: Some(MessageContent::from_text(prompt)),
                    role: "user".to_string(),
                    name: None,
                    tool_calls: None,
                });
                chat_request.messages = Either::Left(combined);
            }
        }
    }

    // Collect all request messages; these seed the stored conversation history
    let all_messages = match &chat_request.messages {
        Either::Left(msgs) => msgs.clone(),
        Either::Right(prompt) => vec![Message {
            content: Some(MessageContent::from_text(prompt.clone())),
            role: "user".to_string(),
            name: None,
            tool_calls: None,
        }],
    };

    let (request, is_streaming) = parse_chat_request(chat_request, state, tx).await?;
    Ok((request, is_streaming, Some(all_messages)))
}

/// Create response endpoint
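///
/// A minimal usage sketch (the host/port are illustrative; field names follow
/// `ResponsesCreateRequest` as consumed by this handler):
///
/// ```text
/// curl http://localhost:8080/v1/responses \
///   -H "Content-Type: application/json" \
///   -d '{"model": "default", "input": "Hello!", "store": true}'
/// ```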
#[utoipa::path(
    post,
    tag = "Mistral.rs",
    path = "/v1/responses",
    request_body = ResponsesCreateRequest,
    responses((status = 200, description = "Response created"))
)]
pub async fn create_response(
    State(state): ExtractedMistralRsState,
    Json(oairequest): Json<ResponsesCreateRequest>,
) -> ResponsesResponder {
    let (tx, mut rx) = create_response_channel(None);
    let request_id = format!("resp_{}", Uuid::new_v4());
    let metadata = oairequest.metadata.clone();
    let store = oairequest.store.unwrap_or(true);

    // Extract the model_id for routing
    let model_id = if oairequest.model == "default" {
        None
    } else {
        Some(oairequest.model.clone())
    };

    let (request, is_streaming, conversation_history) =
        match parse_responses_request(oairequest, state.clone(), tx).await {
            Ok(x) => x,
            Err(e) => return handle_error(state, e.into()),
        };

    if let Err(e) = send_request_with_model(&state, request, model_id.as_deref()).await {
        return handle_error(state, e.into());
    }

    if is_streaming {
        let streamer = ResponsesStreamer {
            rx,
            done_state: DoneState::Running,
            state: state.clone(),
            on_chunk: None,
            on_done: None,
            chunks: Vec::new(),
            store_chunks: store,
        };

        // Store chunks for later retrieval if requested
        if store {
            let cache = get_response_cache();
            let id = request_id.clone();
            let chunks_cache = cache.clone();

            // Create an on-done callback that stores the chunks and the conversation history
            let history_for_streaming = conversation_history.clone();
            let on_done: OnDoneCallback<ResponsesChunk> = Box::new(move |chunks| {
                let _ = chunks_cache.store_chunks(id.clone(), chunks.to_vec());

                // Reconstruct the assistant's message from the chunks and store the conversation history
                if let Some(mut history) = history_for_streaming.clone() {
                    let mut assistant_message = String::new();

                    // Collect all text from the chunks
                    for chunk in chunks {
                        if let Some(delta) = &chunk.delta {
                            if let Some(outputs) = &delta.output {
                                for output in outputs {
                                    if let Some(contents) = &output.content {
                                        for content in contents {
                                            if let Some(text) = &content.text {
                                                assistant_message.push_str(text);
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }

                    // Add the complete assistant message to the history
                    if !assistant_message.is_empty() {
                        history.push(Message {
                            content: Some(MessageContent::from_text(assistant_message)),
                            role: "assistant".to_string(),
                            name: None,
                            tool_calls: None,
                        });
                    }

                    let _ = chunks_cache.store_conversation_history(id.clone(), history);
                }
            });

            ResponsesResponder::Sse(create_streamer(streamer, Some(on_done)))
        } else {
            ResponsesResponder::Sse(create_streamer(streamer, None))
        }
    } else {
        // Non-streaming response
        match rx.recv().await {
            Some(Response::Done(chat_resp)) => {
                let response_obj =
                    chat_response_to_responses_object(&chat_resp, request_id.clone(), metadata);

                // Store if requested
                if store {
                    let cache = get_response_cache();
                    let _ = cache.store_response(request_id.clone(), response_obj.clone());

                    // Create the complete conversation history, including the assistant's response
                    if let Some(mut history) = conversation_history.clone() {
                        // Add the assistant's response to the conversation history
                        for choice in &chat_resp.choices {
                            if let Some(content) = &choice.message.content {
                                history.push(Message {
                                    content: Some(MessageContent::from_text(content.clone())),
                                    role: choice.message.role.clone(),
                                    name: None,
                                    tool_calls: None, // TODO: Convert ToolCallResponse to ToolCall if needed
                                });
                            }
                        }
                        let _ = cache.store_conversation_history(request_id, history);
                    }
                }

                ResponsesResponder::Json(response_obj)
            }
            Some(Response::ModelError(msg, partial_resp)) => {
                let mut response_obj =
                    chat_response_to_responses_object(&partial_resp, request_id.clone(), metadata);
                response_obj.error = Some(ResponsesError {
                    error_type: "model_error".to_string(),
                    message: msg.to_string(),
                });
                response_obj.status = "failed".to_string();

                if store {
                    let cache = get_response_cache();
                    let _ = cache.store_response(request_id.clone(), response_obj.clone());

                    // Even on error, store the conversation history with the partial response
                    if let Some(mut history) = conversation_history.clone() {
                        // Add any partial response to the conversation history
                        for choice in &partial_resp.choices {
                            if let Some(content) = &choice.message.content {
                                history.push(Message {
                                    content: Some(MessageContent::from_text(content.clone())),
                                    role: choice.message.role.clone(),
                                    name: None,
                                    tool_calls: None, // TODO: Convert ToolCallResponse to ToolCall if needed
                                });
                            }
                        }
                        let _ = cache.store_conversation_history(request_id, history);
                    }
                }
                ResponsesResponder::ModelError(msg.to_string(), response_obj)
            }
            Some(Response::ValidationError(e)) => ResponsesResponder::ValidationError(e),
            Some(Response::InternalError(e)) => ResponsesResponder::InternalError(e),
            _ => ResponsesResponder::InternalError(
                anyhow::anyhow!("Unexpected response type").into(),
            ),
        }
    }
}

/// Get response by ID endpoint
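///
/// A minimal usage sketch (the host/port and response ID are illustrative):
///
/// ```text
/// curl http://localhost:8080/v1/responses/resp_1234
/// ```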
#[utoipa::path(
    get,
    tag = "Mistral.rs",
    path = "/v1/responses/{response_id}",
    params(("response_id" = String, Path, description = "The ID of the response to retrieve")),
    responses((status = 200, description = "Response object"))
)]
pub async fn get_response(
    State(_state): ExtractedMistralRsState,
    Path(response_id): Path<String>,
) -> impl IntoResponse {
    let cache = get_response_cache();

    match cache.get_response(&response_id) {
        Ok(Some(response)) => (StatusCode::OK, Json(response)).into_response(),
        Ok(None) => JsonError::new(format!("Response with ID '{response_id}' not found"))
            .to_response(StatusCode::NOT_FOUND),
        Err(e) => JsonError::new(format!(
            "Error retrieving response: {}",
            sanitize_error_message(&*e)
        ))
        .to_response(StatusCode::INTERNAL_SERVER_ERROR),
    }
}

/// Delete response by ID endpoint
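///
/// On success, returns `{"deleted": true, "id": "<response_id>", "object": "response.deleted"}`;
/// unknown IDs produce a 404 JSON error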
#[utoipa::path(
    delete,
    tag = "Mistral.rs",
    path = "/v1/responses/{response_id}",
    params(("response_id" = String, Path, description = "The ID of the response to delete")),
    responses((status = 200, description = "Response deleted"))
)]
pub async fn delete_response(
    State(_state): ExtractedMistralRsState,
    Path(response_id): Path<String>,
) -> impl IntoResponse {
    let cache = get_response_cache();

    match cache.delete_response(&response_id) {
        Ok(true) => (
            StatusCode::OK,
            Json(serde_json::json!({
                "deleted": true,
                "id": response_id,
                "object": "response.deleted"
            })),
        )
            .into_response(),
        Ok(false) => JsonError::new(format!("Response with ID '{response_id}' not found"))
            .to_response(StatusCode::NOT_FOUND),
        Err(e) => JsonError::new(format!(
            "Error deleting response: {}",
            sanitize_error_message(&*e)
        ))
        .to_response(StatusCode::INTERNAL_SERVER_ERROR),
    }
}

/// Handle errors by delegating to the shared completion error handler
fn handle_error(
    state: SharedMistralRsState,
    e: Box<dyn std::error::Error + Send + Sync + 'static>,
) -> ResponsesResponder {
    handle_completion_error(state, e)
}

/// Create SSE streamer
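///
/// Attaches the optional on-done callback and wraps the streamer in an SSE
/// response whose keep-alive interval (in milliseconds) comes from
/// `get_keep_alive_interval`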
fn create_streamer(
    streamer: ResponsesStreamer,
    on_done: Option<OnDoneCallback<ResponsesChunk>>,
) -> Sse<KeepAliveStream<ResponsesStreamer>> {
    let keep_alive_interval = get_keep_alive_interval();

    let streamer_with_callback = ResponsesStreamer {
        on_done,
        ..streamer
    };

    Sse::new(streamer_with_callback)
        .keep_alive(KeepAlive::new().interval(Duration::from_millis(keep_alive_interval)))
}