mistralrs_server_core/
responses.rs

//! ## Responses API functionality and route handlers.
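//!
//! Implements the OpenAI-style `/v1/responses` endpoints: response creation
//! (streaming and non-streaming), retrieval by ID, and deletion. Requests are
//! translated into chat completion requests internally, and stored responses
//! can be chained via `previous_response_id`.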

use std::{pin::Pin, task::Poll, time::Duration};

use anyhow::Result;
use axum::{
    extract::{Json, Path, State},
    http::{self, StatusCode},
    response::{
        sse::{Event, KeepAlive},
        IntoResponse, Sse,
    },
};
use either::Either;
use mistralrs_core::{ChatCompletionResponse, MistralRs, Request, Response};
use serde_json::Value;
use tokio::sync::mpsc::Sender;
use uuid::Uuid;

use crate::{
    cached_responses::get_response_cache,
    chat_completion::parse_request as parse_chat_request,
    completion_core::{handle_completion_error, BaseCompletionResponder},
    handler_core::{
        create_response_channel, send_request_with_model, BaseJsonModelError, ErrorToResponse,
        JsonError, ModelErrorMessage,
    },
    openai::{
        ChatCompletionRequest, Message, MessageContent, ResponsesChunk, ResponsesContent,
        ResponsesCreateRequest, ResponsesDelta, ResponsesDeltaContent, ResponsesDeltaOutput,
        ResponsesError, ResponsesObject, ResponsesOutput, ResponsesUsage,
    },
    streaming::{get_keep_alive_interval, BaseStreamer, DoneState},
    types::{ExtractedMistralRsState, OnChunkCallback, OnDoneCallback, SharedMistralRsState},
    util::sanitize_error_message,
};

/// Response streamer for the Responses API
pub type ResponsesStreamer =
    BaseStreamer<ResponsesChunk, OnChunkCallback<ResponsesChunk>, OnDoneCallback<ResponsesChunk>>;

impl futures::Stream for ResponsesStreamer {
    type Item = Result<Event, axum::Error>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        match self.done_state {
            DoneState::SendingDone => {
                self.done_state = DoneState::Done;
                return Poll::Ready(Some(Ok(Event::default().data("[DONE]"))));
            }
            DoneState::Done => {
                if let Some(on_done) = &self.on_done {
                    on_done(&self.chunks);
                }
                return Poll::Ready(None);
            }
            DoneState::Running => (),
        }

        match self.rx.poll_recv(cx) {
            Poll::Ready(Some(resp)) => match resp {
                Response::ModelError(msg, _) => {
                    MistralRs::maybe_log_error(
                        self.state.clone(),
                        &ModelErrorMessage(msg.to_string()),
                    );
                    self.done_state = DoneState::SendingDone;
                    Poll::Ready(Some(Ok(Event::default().data(msg))))
                }
                Response::ValidationError(e) => Poll::Ready(Some(Ok(
                    Event::default().data(sanitize_error_message(e.as_ref()))
                ))),
                Response::InternalError(e) => {
                    MistralRs::maybe_log_error(self.state.clone(), &*e);
                    Poll::Ready(Some(Ok(
                        Event::default().data(sanitize_error_message(e.as_ref()))
                    )))
                }
                Response::Chunk(chat_chunk) => {
                    // Convert ChatCompletionChunkResponse to ResponsesChunk
                    let mut delta_outputs = vec![];

                    // Check if all choices are finished
                    let all_finished = chat_chunk.choices.iter().all(|c| c.finish_reason.is_some());

                    for choice in &chat_chunk.choices {
                        let mut delta_content_items = Vec::new();

                        // Handle text content in delta
                        if let Some(content) = &choice.delta.content {
                            delta_content_items.push(ResponsesDeltaContent {
                                content_type: "output_text".to_string(),
                                text: Some(content.clone()),
                            });
                        }

                        // Handle tool calls in delta
                        if let Some(tool_calls) = &choice.delta.tool_calls {
                            for tool_call in tool_calls {
                                let tool_text = format!(
                                    "Tool: {} args: {}",
                                    tool_call.function.name, tool_call.function.arguments
                                );
                                delta_content_items.push(ResponsesDeltaContent {
                                    content_type: "tool_use".to_string(),
                                    text: Some(tool_text),
                                });
                            }
                        }

                        if !delta_content_items.is_empty() {
                            delta_outputs.push(ResponsesDeltaOutput {
                                id: format!("msg_{}", Uuid::new_v4()),
                                output_type: "message".to_string(),
                                content: Some(delta_content_items),
                            });
                        }
                    }

                    let mut response_chunk = ResponsesChunk {
                        id: chat_chunk.id.clone(),
                        object: "response.chunk",
                        created_at: chat_chunk.created as f64,
                        model: chat_chunk.model.clone(),
                        chunk_type: "delta".to_string(),
                        delta: Some(ResponsesDelta {
                            output: if delta_outputs.is_empty() {
                                None
                            } else {
                                Some(delta_outputs)
                            },
                            status: if all_finished {
                                Some("completed".to_string())
                            } else {
                                None
                            },
                        }),
                        usage: None,
                        metadata: None,
                    };

                    if all_finished {
                        self.done_state = DoneState::SendingDone;
                    }

                    MistralRs::maybe_log_response(self.state.clone(), &chat_chunk);

                    if let Some(on_chunk) = &self.on_chunk {
                        response_chunk = on_chunk(response_chunk);
                    }

                    if self.store_chunks {
                        self.chunks.push(response_chunk.clone());
                    }

                    Poll::Ready(Some(Event::default().json_data(response_chunk)))
                }
                _ => unreachable!(),
            },
            // The channel closing (`Ready(None)`) is treated like `Pending`;
            // the stream terminates through the `done_state` machine above.
            Poll::Pending | Poll::Ready(None) => Poll::Pending,
        }
    }
}
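
// A rough sketch of the resulting SSE wire format (illustrative field subset
// only; the exact JSON keys depend on the serde attributes on `ResponsesChunk`):
//
//   data: {"object":"response.chunk","chunk_type":"delta","delta":{...}}
//   data: {"object":"response.chunk","chunk_type":"delta","delta":{"status":"completed"}}
//   data: [DONE]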

/// Responder type for the Responses API
pub type ResponsesResponder = BaseCompletionResponder<ResponsesObject, ResponsesStreamer>;

type JsonModelError = BaseJsonModelError<ResponsesObject>;
impl ErrorToResponse for JsonModelError {}

impl IntoResponse for ResponsesResponder {
    fn into_response(self) -> axum::response::Response {
        match self {
            ResponsesResponder::Sse(s) => s.into_response(),
            ResponsesResponder::Json(s) => Json(s).into_response(),
            ResponsesResponder::InternalError(e) => {
                JsonError::new(sanitize_error_message(e.as_ref()))
                    .to_response(http::StatusCode::INTERNAL_SERVER_ERROR)
            }
            ResponsesResponder::ValidationError(e) => {
                JsonError::new(sanitize_error_message(e.as_ref()))
                    .to_response(http::StatusCode::UNPROCESSABLE_ENTITY)
            }
            ResponsesResponder::ModelError(msg, response) => JsonModelError::new(msg, response)
                .to_response(http::StatusCode::INTERNAL_SERVER_ERROR),
        }
    }
}

/// Convert a chat completion response into a Responses API object
fn chat_response_to_responses_object(
    chat_resp: &ChatCompletionResponse,
    request_id: String,
    metadata: Option<Value>,
) -> ResponsesObject {
    let mut outputs = Vec::new();
    let mut output_text_parts = Vec::new();

    for choice in &chat_resp.choices {
        let mut content_items = Vec::new();
        let mut has_content = false;

        // Handle text content
        if let Some(text) = &choice.message.content {
            output_text_parts.push(text.clone());
            content_items.push(ResponsesContent {
                content_type: "output_text".to_string(),
                text: Some(text.clone()),
                annotations: None,
            });
            has_content = true;
        }

        // Handle tool calls
        if let Some(tool_calls) = &choice.message.tool_calls {
            for tool_call in tool_calls {
                let tool_text = format!(
                    "Tool call: {} with args: {}",
                    tool_call.function.name, tool_call.function.arguments
                );
                content_items.push(ResponsesContent {
                    content_type: "tool_use".to_string(),
                    text: Some(tool_text),
                    annotations: None,
                });
                has_content = true;
            }
        }

        // Only add output if we have content
        if has_content {
            outputs.push(ResponsesOutput {
                id: format!("msg_{}", Uuid::new_v4()),
                output_type: "message".to_string(),
                role: choice.message.role.clone(),
                status: None,
                content: content_items,
            });
        }
    }

    ResponsesObject {
        id: request_id,
        object: "response",
        created_at: chat_resp.created as f64,
        model: chat_resp.model.clone(),
        status: "completed".to_string(),
        output: outputs,
        output_text: if output_text_parts.is_empty() {
            None
        } else {
            Some(output_text_parts.join(" "))
        },
        usage: Some(ResponsesUsage {
            input_tokens: chat_resp.usage.prompt_tokens,
            output_tokens: chat_resp.usage.completion_tokens,
            total_tokens: chat_resp.usage.total_tokens,
            input_tokens_details: None,
            output_tokens_details: None,
        }),
        error: None,
        metadata,
        instructions: None,
        incomplete_details: None,
    }
}
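
// A rough sketch of the object built above, shown with the Rust field names
// (the serialized JSON keys depend on the serde attributes on `ResponsesObject`):
//
//   ResponsesObject {
//       id: "resp_<uuid>",
//       object: "response",
//       status: "completed",
//       output: [message { content: [output_text { text: "..." }] }],
//       output_text: Some("<all text parts joined with spaces>"),
//       usage: Some(input/output/total token counts),
//       ..
//   }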

/// Parse a Responses API request into the internal request format, prepending
/// any cached conversation history referenced by `previous_response_id`
async fn parse_responses_request(
    oairequest: ResponsesCreateRequest,
    state: SharedMistralRsState,
    tx: Sender<Response>,
) -> Result<(Request, bool, Option<Vec<Message>>)> {
    if oairequest.instructions.is_some() {
        return Err(anyhow::anyhow!(
            "The 'instructions' field is not supported by this Responses API implementation"
        ));
    }

    // If `previous_response_id` is provided, fetch the full conversation history from the cache
    let previous_messages = if let Some(prev_id) = &oairequest.previous_response_id {
        let cache = get_response_cache();
        cache.get_conversation_history(prev_id)?
    } else {
        None
    };
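
    // e.g. a follow-up request carrying `previous_response_id: "resp_<uuid>"`
    // resolves the cached history here, which is then prepended to the new
    // messages below.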

    // Get the new messages from the `input` field (message list or prompt string)
    let messages = oairequest.input.into_either();

    // Convert to ChatCompletionRequest for reuse
    let mut chat_request = ChatCompletionRequest {
        messages: messages.clone(),
        model: oairequest.model,
        logit_bias: oairequest.logit_bias,
        logprobs: oairequest.logprobs,
        top_logprobs: oairequest.top_logprobs,
        max_tokens: oairequest.max_tokens,
        n_choices: oairequest.n_choices,
        presence_penalty: oairequest.presence_penalty,
        frequency_penalty: oairequest.frequency_penalty,
        stop_seqs: oairequest.stop_seqs,
        temperature: oairequest.temperature,
        top_p: oairequest.top_p,
        stream: oairequest.stream,
        tools: oairequest.tools,
        tool_choice: oairequest.tool_choice,
        response_format: oairequest.response_format,
        web_search_options: oairequest.web_search_options,
        top_k: oairequest.top_k,
        grammar: oairequest.grammar,
        min_p: oairequest.min_p,
        dry_multiplier: oairequest.dry_multiplier,
        dry_base: oairequest.dry_base,
        dry_allowed_length: oairequest.dry_allowed_length,
        dry_sequence_breakers: oairequest.dry_sequence_breakers,
        enable_thinking: oairequest.enable_thinking,
    };

    // Prepend previous messages if available
    if let Some(prev_msgs) = previous_messages {
        match &mut chat_request.messages {
            Either::Left(msgs) => {
                let mut combined = prev_msgs;
                combined.extend(msgs.clone());
                chat_request.messages = Either::Left(combined);
            }
            Either::Right(_) => {
                // If it's a prompt string, convert to messages and prepend
                let prompt = chat_request.messages.as_ref().right().unwrap().clone();
                let mut combined = prev_msgs;
                combined.push(Message {
                    content: Some(MessageContent::from_text(prompt)),
                    role: "user".to_string(),
                    name: None,
                    tool_calls: None,
                });
                chat_request.messages = Either::Left(combined);
            }
        }
    }

    // Collect all messages so they can be returned as the conversation history
    let all_messages = match &chat_request.messages {
        Either::Left(msgs) => msgs.clone(),
        Either::Right(prompt) => vec![Message {
            content: Some(MessageContent::from_text(prompt.clone())),
            role: "user".to_string(),
            name: None,
            tool_calls: None,
        }],
    };

    let (request, is_streaming) = parse_chat_request(chat_request, state, tx).await?;
    Ok((request, is_streaming, Some(all_messages)))
}

/// Create response endpoint
#[utoipa::path(
    post,
    tag = "Mistral.rs",
    path = "/v1/responses",
    request_body = ResponsesCreateRequest,
    responses((status = 200, description = "Response created"))
)]
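/// A minimal client-side sketch (assumptions: a server listening on
/// `localhost:1234`, and the `reqwest` and `serde_json` crates available on
/// the client; URL, port, and prompt are illustrative):
///
/// ```ignore
/// let body = serde_json::json!({
///     "model": "default",
///     "input": "Say hello",
///     "store": true
/// });
/// let response = reqwest::Client::new()
///     .post("http://localhost:1234/v1/responses")
///     .json(&body)
///     .send()
///     .await?;
/// println!("{}", response.text().await?);
/// ```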
pub async fn create_response(
    State(state): ExtractedMistralRsState,
    Json(oairequest): Json<ResponsesCreateRequest>,
) -> ResponsesResponder {
    let (tx, mut rx) = create_response_channel(None);
    let request_id = format!("resp_{}", Uuid::new_v4());
    let metadata = oairequest.metadata.clone();
    let store = oairequest.store.unwrap_or(true);

    // Extract model_id for routing
    let model_id = if oairequest.model == "default" {
        None
    } else {
        Some(oairequest.model.clone())
    };

    let (request, is_streaming, conversation_history) =
        match parse_responses_request(oairequest, state.clone(), tx).await {
            Ok(x) => x,
            Err(e) => return handle_error(state, e.into()),
        };

    if let Err(e) = send_request_with_model(&state, request, model_id.as_deref()).await {
        return handle_error(state, e.into());
    }

    if is_streaming {
        let streamer = ResponsesStreamer {
            rx,
            done_state: DoneState::Running,
            state: state.clone(),
            on_chunk: None,
            on_done: None,
            chunks: Vec::new(),
            store_chunks: store,
        };

        // Store chunks for later retrieval if requested
        if store {
            let cache = get_response_cache();
            let id = request_id.clone();
            let chunks_cache = cache.clone();

            // Create a wrapper that stores chunks and conversation history
            let history_for_streaming = conversation_history.clone();
            let on_done: OnDoneCallback<ResponsesChunk> = Box::new(move |chunks| {
                let _ = chunks_cache.store_chunks(id.clone(), chunks.to_vec());

                // Reconstruct the assistant's message from chunks and store conversation history
                if let Some(history) = history_for_streaming.clone() {
                    let mut history = history;
                    let mut assistant_message = String::new();

                    // Collect all text from chunks
                    for chunk in chunks {
                        if let Some(delta) = &chunk.delta {
                            if let Some(outputs) = &delta.output {
                                for output in outputs {
                                    if let Some(contents) = &output.content {
                                        for content in contents {
                                            if let Some(text) = &content.text {
                                                assistant_message.push_str(text);
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }

                    // Add the complete assistant message to history
                    if !assistant_message.is_empty() {
                        history.push(Message {
                            content: Some(MessageContent::from_text(assistant_message)),
                            role: "assistant".to_string(),
                            name: None,
                            tool_calls: None,
                        });
                    }

                    let _ = chunks_cache.store_conversation_history(id.clone(), history);
                }
            });

            ResponsesResponder::Sse(create_streamer(streamer, Some(on_done)))
        } else {
            ResponsesResponder::Sse(create_streamer(streamer, None))
        }
    } else {
        // Non-streaming response
        match rx.recv().await {
            Some(Response::Done(chat_resp)) => {
                let response_obj =
                    chat_response_to_responses_object(&chat_resp, request_id.clone(), metadata);

                // Store if requested
                if store {
                    let cache = get_response_cache();
                    let _ = cache.store_response(request_id.clone(), response_obj.clone());

                    // Create complete conversation history including the assistant's response
                    if let Some(mut history) = conversation_history.clone() {
                        // Add the assistant's response to the conversation history
                        for choice in &chat_resp.choices {
                            if let Some(content) = &choice.message.content {
                                history.push(Message {
                                    content: Some(MessageContent::from_text(content.clone())),
                                    role: choice.message.role.clone(),
                                    name: None,
                                    tool_calls: None, // TODO: Convert ToolCallResponse to ToolCall if needed
                                });
                            }
                        }
                        let _ = cache.store_conversation_history(request_id, history);
                    }
                }

                ResponsesResponder::Json(response_obj)
            }
            Some(Response::ModelError(msg, partial_resp)) => {
                let mut response_obj =
                    chat_response_to_responses_object(&partial_resp, request_id.clone(), metadata);
                response_obj.error = Some(ResponsesError {
                    error_type: "model_error".to_string(),
                    message: msg.to_string(),
                });
                response_obj.status = "failed".to_string();

                if store {
                    let cache = get_response_cache();
                    let _ = cache.store_response(request_id.clone(), response_obj.clone());

                    // Even on error, store conversation history with partial response
                    if let Some(mut history) = conversation_history.clone() {
                        // Add any partial response to the conversation history
                        for choice in &partial_resp.choices {
                            if let Some(content) = &choice.message.content {
                                history.push(Message {
                                    content: Some(MessageContent::from_text(content.clone())),
                                    role: choice.message.role.clone(),
                                    name: None,
                                    tool_calls: None, // TODO: Convert ToolCallResponse to ToolCall if needed
                                });
                            }
                        }
                        let _ = cache.store_conversation_history(request_id, history);
                    }
                }
                ResponsesResponder::ModelError(msg.to_string(), response_obj)
            }
            Some(Response::ValidationError(e)) => ResponsesResponder::ValidationError(e),
            Some(Response::InternalError(e)) => ResponsesResponder::InternalError(e),
            _ => ResponsesResponder::InternalError(
                anyhow::anyhow!("Unexpected response type").into(),
            ),
        }
    }
}

/// Get response by ID endpoint
#[utoipa::path(
    get,
    tag = "Mistral.rs",
    path = "/v1/responses/{response_id}",
    params(("response_id" = String, Path, description = "The ID of the response to retrieve")),
    responses((status = 200, description = "Response object"))
)]
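/// Returns the stored response object as JSON, or `404 Not Found` if no
/// stored response matches `response_id`.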
pub async fn get_response(
    State(_state): ExtractedMistralRsState,
    Path(response_id): Path<String>,
) -> impl IntoResponse {
    let cache = get_response_cache();

    match cache.get_response(&response_id) {
        Ok(Some(response)) => (StatusCode::OK, Json(response)).into_response(),
        Ok(None) => JsonError::new(format!("Response with ID '{response_id}' not found"))
            .to_response(StatusCode::NOT_FOUND),
        Err(e) => JsonError::new(format!(
            "Error retrieving response: {}",
            sanitize_error_message(&*e)
        ))
        .to_response(StatusCode::INTERNAL_SERVER_ERROR),
    }
}

/// Delete response by ID endpoint
#[utoipa::path(
    delete,
    tag = "Mistral.rs",
    path = "/v1/responses/{response_id}",
    params(("response_id" = String, Path, description = "The ID of the response to delete")),
    responses((status = 200, description = "Response deleted"))
)]
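/// Returns a `response.deleted` confirmation object, or `404 Not Found` if
/// no stored response matches `response_id`.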
pub async fn delete_response(
    State(_state): ExtractedMistralRsState,
    Path(response_id): Path<String>,
) -> impl IntoResponse {
    let cache = get_response_cache();

    match cache.delete_response(&response_id) {
        Ok(true) => (
            StatusCode::OK,
            Json(serde_json::json!({
                "deleted": true,
                "id": response_id,
                "object": "response.deleted"
            })),
        )
            .into_response(),
        Ok(false) => JsonError::new(format!("Response with ID '{response_id}' not found"))
            .to_response(StatusCode::NOT_FOUND),
        Err(e) => JsonError::new(format!(
            "Error deleting response: {}",
            sanitize_error_message(&*e)
        ))
        .to_response(StatusCode::INTERNAL_SERVER_ERROR),
    }
}

/// Handle errors by delegating to the shared completion error handler
fn handle_error(
    state: SharedMistralRsState,
    e: Box<dyn std::error::Error + Send + Sync + 'static>,
) -> ResponsesResponder {
    handle_completion_error(state, e)
}

/// Wrap the streamer in an SSE response with keep-alive, attaching an optional completion callback
fn create_streamer(
    streamer: ResponsesStreamer,
    on_done: Option<OnDoneCallback<ResponsesChunk>>,
) -> Sse<ResponsesStreamer> {
    let keep_alive_interval = get_keep_alive_interval();

    let streamer_with_callback = ResponsesStreamer {
        on_done,
        ..streamer
    };

    Sse::new(streamer_with_callback)
        .keep_alive(KeepAlive::new().interval(Duration::from_millis(keep_alive_interval)))
}