// mistralrs_server_core/completions.rs

1//! ## Completions functionality and route handler.
2
3use std::{
4    pin::Pin,
5    sync::Arc,
6    task::{Context, Poll},
7    time::Duration,
8};
9
10use anyhow::Result;
11use axum::{
12    extract::{Json, State},
13    http::{self},
14    response::{
15        sse::{Event, KeepAlive, KeepAliveStream},
16        IntoResponse, Sse,
17    },
18};
19use mistralrs_core::{
20    CompletionChunkResponse, CompletionResponse, Constraint, MistralRs, NormalRequest, Request,
21    RequestMessage, Response, SamplingParams,
22};
23use tokio::sync::mpsc::{Receiver, Sender};
24use tracing::warn;
25
26use crate::{
27    completion_core::{
28        convert_stop_tokens, get_dry_sampling_params, handle_completion_error,
29        BaseCompletionResponder,
30    },
31    handler_core::{
32        base_process_non_streaming_response, create_response_channel, send_request,
33        BaseJsonModelError, ErrorToResponse, JsonError, ModelErrorMessage,
34    },
35    openai::{CompletionRequest, Grammar},
36    streaming::{base_create_streamer, get_keep_alive_interval, BaseStreamer, DoneState},
37    types::{ExtractedMistralRsState, OnChunkCallback, OnDoneCallback, SharedMistralRsState},
38    util::{sanitize_error_message, validate_model_name},
39};
40
/// A callback function that processes streaming response chunks before they are sent to the client.
///
/// This hook allows modification of each chunk in the streaming response, enabling features like
/// content filtering, transformation, or logging. The callback receives a chunk and must return
/// a (potentially modified) chunk; it is invoked once per chunk, before the chunk is serialized
/// into an SSE event.
///
/// ### Examples
///
/// ```no_run
/// use mistralrs_server_core::completions::CompletionOnChunkCallback;
///
/// let on_chunk: CompletionOnChunkCallback = Box::new(|mut chunk| {
///     // Log the chunk or modify its content
///     println!("Processing chunk: {:?}", chunk);
///     chunk
/// });
/// ```
pub type CompletionOnChunkCallback = OnChunkCallback<CompletionChunkResponse>;
59
/// A callback function that is executed when the streaming response completes.
///
/// This hook receives all chunks that were streamed during the response, allowing for
/// post-processing, analytics, or cleanup operations after the stream finishes.
/// Chunks are only collected (and therefore only passed here) when chunk storage is
/// enabled on the streamer.
///
/// ### Examples
///
/// ```no_run
/// use mistralrs_server_core::completions::CompletionOnDoneCallback;
///
/// let on_done: CompletionOnDoneCallback = Box::new(|chunks| {
///     println!("Stream completed with {} chunks", chunks.len());
///     // Process all chunks for analytics
/// });
/// ```
pub type CompletionOnDoneCallback = OnDoneCallback<CompletionChunkResponse>;
76
/// A streaming response handler.
///
/// It processes incoming response chunks from a model and converts them
/// into Server-Sent Events (SSE) format for real-time streaming to clients.
/// Specializes [`BaseStreamer`] for completion chunks with the completion
/// on-chunk / on-done callback types.
pub type CompletionStreamer =
    BaseStreamer<CompletionChunkResponse, CompletionOnChunkCallback, CompletionOnDoneCallback>;
83
impl futures::Stream for CompletionStreamer {
    type Item = Result<Event, axum::Error>;

    /// Polls the stream for the next Server-Sent Event.
    ///
    /// This method implements the core streaming logic:
    /// 1. Handles stream completion by sending `[DONE]` and executing callbacks
    /// 2. Processes incoming model responses and converts them to SSE events
    /// 3. Applies chunk modifications if a callback is provided
    /// 4. Stores chunks if completion callback is configured
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // Terminal-state machine: first emit the `[DONE]` sentinel, then on the
        // next poll fire the on-done callback and end the stream.
        match self.done_state {
            DoneState::SendingDone => {
                // https://platform.openai.com/docs/api-reference/completions/create
                // If true, returns a stream of events that happen during the Run as server-sent events, terminating when the Run enters a terminal state with a data: [DONE] message.
                self.done_state = DoneState::Done;
                return Poll::Ready(Some(Ok(Event::default().data("[DONE]"))));
            }
            DoneState::Done => {
                // `[DONE]` has already been emitted; run the completion
                // callback (if any) over the stored chunks, then terminate.
                if let Some(on_done) = &self.on_done {
                    on_done(&self.chunks);
                }
                return Poll::Ready(None);
            }
            DoneState::Running => (),
        }

        match self.rx.poll_recv(cx) {
            Poll::Ready(Some(resp)) => match resp {
                Response::CompletionModelError(msg, _) => {
                    MistralRs::maybe_log_error(
                        self.state.clone(),
                        &ModelErrorMessage(msg.to_string()),
                    );
                    // Done now, just need to send the [DONE]
                    self.done_state = DoneState::SendingDone;
                    Poll::Ready(Some(Ok(Event::default().data(msg))))
                }
                // Validation/internal errors are forwarded to the client as
                // plain data events (sanitized to avoid leaking internals).
                Response::ValidationError(e) => Poll::Ready(Some(Ok(
                    Event::default().data(sanitize_error_message(e.as_ref()))
                ))),
                Response::InternalError(e) => {
                    MistralRs::maybe_log_error(self.state.clone(), &*e);
                    Poll::Ready(Some(Ok(
                        Event::default().data(sanitize_error_message(e.as_ref()))
                    )))
                }
                Response::CompletionChunk(mut response) => {
                    // When every choice carries a finish reason this is the last
                    // chunk; the next poll will send the trailing [DONE].
                    if response.choices.iter().all(|x| x.finish_reason.is_some()) {
                        self.done_state = DoneState::SendingDone;
                    }
                    MistralRs::maybe_log_response(self.state.clone(), &response);

                    // Let the caller-supplied hook transform the chunk first.
                    if let Some(on_chunk) = &self.on_chunk {
                        response = on_chunk(response);
                    }

                    // Keep a copy when chunk retention is enabled.
                    if self.store_chunks {
                        self.chunks.push(response.clone());
                    }

                    Poll::Ready(Some(Event::default().json_data(response)))
                }
                // These variants belong to other request kinds (chat, image,
                // speech, …) and are never delivered on a completion channel.
                Response::Done(_) => unreachable!(),
                Response::CompletionDone(_) => unreachable!(),
                Response::Chunk(_) => unreachable!(),
                Response::ImageGeneration(_) => unreachable!(),
                Response::ModelError(_, _) => unreachable!(),
                Response::Speech { .. } => unreachable!(),
                Response::Raw { .. } => unreachable!(),
                Response::Embeddings { .. } => unreachable!(),
            },
            // NOTE(review): a closed channel (`Poll::Ready(None)`) is mapped to
            // `Pending`. This assumes the sender always delivers a terminal
            // response before being dropped — verify against the engine side,
            // otherwise a dropped sender would leave the stream hanging.
            Poll::Pending | Poll::Ready(None) => Poll::Pending,
        }
    }
}
161
/// Represents different types of completion responses.
///
/// Either a JSON body (non-streaming), an SSE stream with keep-alive
/// (streaming), or one of the error variants mapped to HTTP errors by
/// the [`IntoResponse`] impl below.
pub type CompletionResponder =
    BaseCompletionResponder<CompletionResponse, KeepAliveStream<CompletionStreamer>>;
165
/// JSON error response structure for model errors.
///
/// Pairs the error message with the partial [`CompletionResponse`]
/// generated before the failure.
type JsonModelError = BaseJsonModelError<CompletionResponse>;
impl ErrorToResponse for JsonModelError {}
169
170impl IntoResponse for CompletionResponder {
171    /// Converts the completion responder into an HTTP response.
172    fn into_response(self) -> axum::response::Response {
173        match self {
174            CompletionResponder::Sse(s) => s.into_response(),
175            CompletionResponder::Json(s) => Json(s).into_response(),
176            CompletionResponder::InternalError(e) => {
177                JsonError::new(sanitize_error_message(e.as_ref()))
178                    .to_response(http::StatusCode::INTERNAL_SERVER_ERROR)
179            }
180            CompletionResponder::ValidationError(e) => {
181                JsonError::new(sanitize_error_message(e.as_ref()))
182                    .to_response(http::StatusCode::UNPROCESSABLE_ENTITY)
183            }
184            CompletionResponder::ModelError(msg, response) => JsonModelError::new(msg, response)
185                .to_response(http::StatusCode::INTERNAL_SERVER_ERROR),
186        }
187    }
188}
189
190/// Parses and validates a completion request.
191///
192/// This function transforms an OpenAI-compatible completion request into the
193/// request format used by mistral.rs.
194pub fn parse_request(
195    oairequest: CompletionRequest,
196    state: Arc<MistralRs>,
197    tx: Sender<Response>,
198) -> Result<(Request, bool)> {
199    let repr = serde_json::to_string(&oairequest).expect("Serialization of request failed.");
200    MistralRs::maybe_log_request(state.clone(), repr);
201
202    // Validate that the requested model matches the loaded model
203    validate_model_name(&oairequest.model, state.clone())?;
204
205    let stop_toks = convert_stop_tokens(oairequest.stop_seqs);
206
207    if oairequest.logprobs.is_some() {
208        warn!("Completion requests do not support logprobs.");
209    }
210
211    let is_streaming = oairequest.stream.unwrap_or(false);
212
213    let dry_params = get_dry_sampling_params(
214        oairequest.dry_multiplier,
215        oairequest.dry_sequence_breakers,
216        oairequest.dry_base,
217        oairequest.dry_allowed_length,
218    )?;
219
220    Ok((
221        Request::Normal(Box::new(NormalRequest {
222            id: state.next_request_id(),
223            messages: RequestMessage::Completion {
224                text: oairequest.prompt,
225                echo_prompt: oairequest.echo_prompt,
226                best_of: oairequest.best_of,
227            },
228            sampling_params: SamplingParams {
229                temperature: oairequest.temperature,
230                top_k: oairequest.top_k,
231                top_p: oairequest.top_p,
232                min_p: oairequest.min_p,
233                top_n_logprobs: 1,
234                frequency_penalty: oairequest.frequency_penalty,
235                presence_penalty: oairequest.presence_penalty,
236                repetition_penalty: oairequest.repetition_penalty,
237                max_len: oairequest.max_tokens,
238                stop_toks,
239                logits_bias: oairequest.logit_bias,
240                n_choices: oairequest.n_choices,
241                dry_params,
242            },
243            response: tx,
244            return_logprobs: false,
245            is_streaming,
246            suffix: oairequest.suffix,
247            constraint: match oairequest.grammar {
248                Some(Grammar::Regex(regex)) => Constraint::Regex(regex),
249                Some(Grammar::Lark(lark)) => Constraint::Lark(lark),
250                Some(Grammar::JsonSchema(schema)) => Constraint::JsonSchema(schema),
251                Some(Grammar::Llguidance(llguidance)) => Constraint::Llguidance(llguidance),
252                None => Constraint::None,
253            },
254            tool_choice: oairequest.tool_choice,
255            tools: oairequest.tools,
256            logits_processors: None,
257            return_raw_logits: false,
258            web_search_options: None,
259            model_id: if oairequest.model == "default" {
260                None
261            } else {
262                Some(oairequest.model.clone())
263            },
264            truncate_sequence: oairequest.truncate_sequence.unwrap_or(false),
265        })),
266        is_streaming,
267    ))
268}
269
270/// OpenAI-compatible completions endpoint handler.
271#[utoipa::path(
272    post,
273    tag = "Mistral.rs",
274    path = "/v1/completions",
275    request_body = CompletionRequest,
276    responses((status = 200, description = "Completions"))
277)]
278pub async fn completions(
279    State(state): ExtractedMistralRsState,
280    Json(oairequest): Json<CompletionRequest>,
281) -> CompletionResponder {
282    let (tx, mut rx) = create_response_channel(None);
283
284    let (request, is_streaming) = match parse_request(oairequest, state.clone(), tx) {
285        Ok(x) => x,
286        Err(e) => return handle_error(state, e.into()),
287    };
288
289    if let Err(e) = send_request(&state, request).await {
290        return handle_error(state, e.into());
291    }
292
293    if is_streaming {
294        CompletionResponder::Sse(create_streamer(rx, state, None, None))
295    } else {
296        process_non_streaming_response(&mut rx, state).await
297    }
298}
299
/// Handles route / generation errors and logs them.
///
/// Thin wrapper delegating to the shared [`handle_completion_error`] helper
/// so that chat and plain completions report errors identically.
pub fn handle_error(
    state: SharedMistralRsState,
    e: Box<dyn std::error::Error + Send + Sync + 'static>,
) -> CompletionResponder {
    handle_completion_error(state, e)
}
307
308/// Creates a SSE streamer for chat completions with optional callbacks.
309pub fn create_streamer(
310    rx: Receiver<Response>,
311    state: SharedMistralRsState,
312    on_chunk: Option<CompletionOnChunkCallback>,
313    on_done: Option<CompletionOnDoneCallback>,
314) -> Sse<KeepAliveStream<CompletionStreamer>> {
315    let streamer = base_create_streamer(rx, state, on_chunk, on_done);
316    let keep_alive_interval = get_keep_alive_interval();
317
318    Sse::new(streamer)
319        .keep_alive(KeepAlive::new().interval(Duration::from_millis(keep_alive_interval)))
320}
321
/// Process non-streaming completion responses.
///
/// Delegates to the shared non-streaming pipeline, supplying this module's
/// response matcher and error handler.
pub async fn process_non_streaming_response(
    rx: &mut Receiver<Response>,
    state: SharedMistralRsState,
) -> CompletionResponder {
    base_process_non_streaming_response(rx, state, match_responses, handle_error).await
}
329
330/// Matches and processes different types of model responses into appropriate completion responses.
331pub fn match_responses(state: SharedMistralRsState, response: Response) -> CompletionResponder {
332    match response {
333        Response::InternalError(e) => {
334            MistralRs::maybe_log_error(state, &*e);
335            CompletionResponder::InternalError(e)
336        }
337        Response::CompletionModelError(msg, response) => {
338            MistralRs::maybe_log_error(state.clone(), &ModelErrorMessage(msg.to_string()));
339            MistralRs::maybe_log_response(state, &response);
340            CompletionResponder::ModelError(msg, response)
341        }
342        Response::ValidationError(e) => CompletionResponder::ValidationError(e),
343        Response::CompletionDone(response) => {
344            MistralRs::maybe_log_response(state, &response);
345            CompletionResponder::Json(response)
346        }
347        Response::CompletionChunk(_) => unreachable!(),
348        Response::Chunk(_) => unreachable!(),
349        Response::Done(_) => unreachable!(),
350        Response::ModelError(_, _) => unreachable!(),
351        Response::ImageGeneration(_) => unreachable!(),
352        Response::Speech { .. } => unreachable!(),
353        Response::Raw { .. } => unreachable!(),
354        Response::Embeddings { .. } => unreachable!(),
355    }
356}