mistralrs_server_core/
completions.rs

1//! ## Completions functionality and route handler.
2
3use std::{
4    pin::Pin,
5    sync::Arc,
6    task::{Context, Poll},
7    time::Duration,
8};
9
10use anyhow::Result;
11use axum::{
12    extract::{Json, State},
13    http::{self},
14    response::{
15        sse::{Event, KeepAlive},
16        IntoResponse, Sse,
17    },
18};
19use mistralrs_core::{
20    CompletionChunkResponse, CompletionResponse, Constraint, MistralRs, NormalRequest, Request,
21    RequestMessage, Response, SamplingParams,
22};
23use tokio::sync::mpsc::{Receiver, Sender};
24use tracing::warn;
25
26use crate::{
27    completion_core::{
28        convert_stop_tokens, get_dry_sampling_params, handle_completion_error,
29        BaseCompletionResponder,
30    },
31    handler_core::{
32        base_process_non_streaming_response, create_response_channel, send_request,
33        BaseJsonModelError, ErrorToResponse, JsonError, ModelErrorMessage,
34    },
35    openai::{CompletionRequest, Grammar},
36    streaming::{base_create_streamer, get_keep_alive_interval, BaseStreamer, DoneState},
37    types::{ExtractedMistralRsState, OnChunkCallback, OnDoneCallback, SharedMistralRsState},
38    util::{sanitize_error_message, validate_model_name},
39};
40
/// A callback function that processes streaming response chunks before they are sent to the client.
///
/// This hook allows modification of each chunk in the streaming response, enabling features like
/// content filtering, transformation, or logging. The callback receives a chunk and must return
/// a (potentially modified) chunk.
///
/// ### Examples
///
/// ```no_run
/// use mistralrs_server_core::completions::CompletionOnChunkCallback;
///
/// let on_chunk: CompletionOnChunkCallback = Box::new(|mut chunk| {
///     // Log the chunk or modify its content
///     println!("Processing chunk: {:?}", chunk);
///     chunk
/// });
/// ```
pub type CompletionOnChunkCallback = OnChunkCallback<CompletionChunkResponse>;

/// A callback function that is executed when the streaming response completes.
///
/// This hook receives all chunks that were streamed during the response, allowing for
/// post-processing, analytics, or cleanup operations after the stream finishes.
/// Chunks are only collected (and thus passed here) when the streamer is created
/// with an `on_done` callback.
///
/// ### Examples
///
/// ```no_run
/// use mistralrs_server_core::completions::CompletionOnDoneCallback;
///
/// let on_done: CompletionOnDoneCallback = Box::new(|chunks| {
///     println!("Stream completed with {} chunks", chunks.len());
///     // Process all chunks for analytics
/// });
/// ```
pub type CompletionOnDoneCallback = OnDoneCallback<CompletionChunkResponse>;

/// A streaming response handler.
///
/// It processes incoming response chunks from a model and converts them
/// into Server-Sent Events (SSE) format for real-time streaming to clients.
pub type CompletionStreamer =
    BaseStreamer<CompletionChunkResponse, CompletionOnChunkCallback, CompletionOnDoneCallback>;
83
impl futures::Stream for CompletionStreamer {
    type Item = Result<Event, axum::Error>;

    /// Polls the stream for the next Server-Sent Event.
    ///
    /// This method implements the core streaming logic:
    /// 1. Handles stream completion by sending `[DONE]` and executing callbacks
    /// 2. Processes incoming model responses and converts them to SSE events
    /// 3. Applies chunk modifications if a callback is provided
    /// 4. Stores chunks if completion callback is configured
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        match self.done_state {
            DoneState::SendingDone => {
                // https://platform.openai.com/docs/api-reference/completions/create
                // The OpenAI streaming protocol terminates the event stream with a
                // `data: [DONE]` message once generation reaches a terminal state.
                self.done_state = DoneState::Done;
                return Poll::Ready(Some(Ok(Event::default().data("[DONE]"))));
            }
            DoneState::Done => {
                // [DONE] has already been emitted; run the completion callback (if any)
                // over the collected chunks, then terminate the stream.
                if let Some(on_done) = &self.on_done {
                    on_done(&self.chunks);
                }
                return Poll::Ready(None);
            }
            // Still streaming: fall through to poll the response channel.
            DoneState::Running => (),
        }

        match self.rx.poll_recv(cx) {
            Poll::Ready(Some(resp)) => match resp {
                Response::CompletionModelError(msg, _) => {
                    MistralRs::maybe_log_error(
                        self.state.clone(),
                        &ModelErrorMessage(msg.to_string()),
                    );
                    // Done now, just need to send the [DONE] on the next poll.
                    self.done_state = DoneState::SendingDone;
                    Poll::Ready(Some(Ok(Event::default().data(msg))))
                }
                Response::ValidationError(e) => Poll::Ready(Some(Ok(
                    Event::default().data(sanitize_error_message(e.as_ref()))
                ))),
                Response::InternalError(e) => {
                    MistralRs::maybe_log_error(self.state.clone(), &*e);
                    Poll::Ready(Some(Ok(
                        Event::default().data(sanitize_error_message(e.as_ref()))
                    )))
                }
                Response::CompletionChunk(mut response) => {
                    // When every choice carries a finish reason this is the final chunk;
                    // the next poll will emit [DONE].
                    if response.choices.iter().all(|x| x.finish_reason.is_some()) {
                        self.done_state = DoneState::SendingDone;
                    }
                    MistralRs::maybe_log_response(self.state.clone(), &response);

                    // Let the caller-provided hook transform the chunk before it is sent.
                    if let Some(on_chunk) = &self.on_chunk {
                        response = on_chunk(response);
                    }

                    // Chunks are retained only when an on-done callback needs them.
                    if self.store_chunks {
                        self.chunks.push(response.clone());
                    }

                    Poll::Ready(Some(Event::default().json_data(response)))
                }
                // The completion pipeline only produces the variants handled above.
                Response::Done(_) => unreachable!(),
                Response::CompletionDone(_) => unreachable!(),
                Response::Chunk(_) => unreachable!(),
                Response::ImageGeneration(_) => unreachable!(),
                Response::ModelError(_, _) => unreachable!(),
                Response::Speech { .. } => unreachable!(),
                Response::Raw { .. } => unreachable!(),
            },
            // NOTE(review): `Poll::Ready(None)` (sender dropped) is mapped to `Pending`;
            // a closed channel registers no further wakeups, so this presumably relies on
            // the SSE keep-alive to keep the task polled — confirm this is intentional.
            Poll::Pending | Poll::Ready(None) => Poll::Pending,
        }
    }
}
160
/// Represents different types of completion responses
/// (SSE stream, JSON body, or one of the error variants).
pub type CompletionResponder = BaseCompletionResponder<CompletionResponse, CompletionStreamer>;

/// JSON error response structure for model errors, carrying the partial
/// [`CompletionResponse`] produced before the failure.
type JsonModelError = BaseJsonModelError<CompletionResponse>;
impl ErrorToResponse for JsonModelError {}
167
168impl IntoResponse for CompletionResponder {
169    /// Converts the completion responder into an HTTP response.
170    fn into_response(self) -> axum::response::Response {
171        match self {
172            CompletionResponder::Sse(s) => s.into_response(),
173            CompletionResponder::Json(s) => Json(s).into_response(),
174            CompletionResponder::InternalError(e) => {
175                JsonError::new(sanitize_error_message(e.as_ref()))
176                    .to_response(http::StatusCode::INTERNAL_SERVER_ERROR)
177            }
178            CompletionResponder::ValidationError(e) => {
179                JsonError::new(sanitize_error_message(e.as_ref()))
180                    .to_response(http::StatusCode::UNPROCESSABLE_ENTITY)
181            }
182            CompletionResponder::ModelError(msg, response) => JsonModelError::new(msg, response)
183                .to_response(http::StatusCode::INTERNAL_SERVER_ERROR),
184        }
185    }
186}
187
188/// Parses and validates a completion request.
189///
190/// This function transforms an OpenAI-compatible completion request into the
191/// request format used by mistral.rs.
192pub fn parse_request(
193    oairequest: CompletionRequest,
194    state: Arc<MistralRs>,
195    tx: Sender<Response>,
196) -> Result<(Request, bool)> {
197    let repr = serde_json::to_string(&oairequest).expect("Serialization of request failed.");
198    MistralRs::maybe_log_request(state.clone(), repr);
199
200    // Validate that the requested model matches the loaded model
201    validate_model_name(&oairequest.model, state.clone())?;
202
203    let stop_toks = convert_stop_tokens(oairequest.stop_seqs);
204
205    if oairequest.logprobs.is_some() {
206        warn!("Completion requests do not support logprobs.");
207    }
208
209    let is_streaming = oairequest.stream.unwrap_or(false);
210
211    let dry_params = get_dry_sampling_params(
212        oairequest.dry_multiplier,
213        oairequest.dry_sequence_breakers,
214        oairequest.dry_base,
215        oairequest.dry_allowed_length,
216    )?;
217
218    Ok((
219        Request::Normal(Box::new(NormalRequest {
220            id: state.next_request_id(),
221            messages: RequestMessage::Completion {
222                text: oairequest.prompt,
223                echo_prompt: oairequest.echo_prompt,
224                best_of: oairequest.best_of,
225            },
226            sampling_params: SamplingParams {
227                temperature: oairequest.temperature,
228                top_k: oairequest.top_k,
229                top_p: oairequest.top_p,
230                min_p: oairequest.min_p,
231                top_n_logprobs: 1,
232                frequency_penalty: oairequest.frequency_penalty,
233                presence_penalty: oairequest.presence_penalty,
234                max_len: oairequest.max_tokens,
235                stop_toks,
236                logits_bias: oairequest.logit_bias,
237                n_choices: oairequest.n_choices,
238                dry_params,
239            },
240            response: tx,
241            return_logprobs: false,
242            is_streaming,
243            suffix: oairequest.suffix,
244            constraint: match oairequest.grammar {
245                Some(Grammar::Regex(regex)) => Constraint::Regex(regex),
246                Some(Grammar::Lark(lark)) => Constraint::Lark(lark),
247                Some(Grammar::JsonSchema(schema)) => Constraint::JsonSchema(schema),
248                Some(Grammar::Llguidance(llguidance)) => Constraint::Llguidance(llguidance),
249                None => Constraint::None,
250            },
251            tool_choice: oairequest.tool_choice,
252            tools: oairequest.tools,
253            logits_processors: None,
254            return_raw_logits: false,
255            web_search_options: None,
256            model_id: if oairequest.model == "default" {
257                None
258            } else {
259                Some(oairequest.model.clone())
260            },
261        })),
262        is_streaming,
263    ))
264}
265
266/// OpenAI-compatible completions endpoint handler.
267#[utoipa::path(
268    post,
269    tag = "Mistral.rs",
270    path = "/v1/completions",
271    request_body = CompletionRequest,
272    responses((status = 200, description = "Completions"))
273)]
274pub async fn completions(
275    State(state): ExtractedMistralRsState,
276    Json(oairequest): Json<CompletionRequest>,
277) -> CompletionResponder {
278    let (tx, mut rx) = create_response_channel(None);
279
280    let (request, is_streaming) = match parse_request(oairequest, state.clone(), tx) {
281        Ok(x) => x,
282        Err(e) => return handle_error(state, e.into()),
283    };
284
285    if let Err(e) = send_request(&state, request).await {
286        return handle_error(state, e.into());
287    }
288
289    if is_streaming {
290        CompletionResponder::Sse(create_streamer(rx, state, None, None))
291    } else {
292        process_non_streaming_response(&mut rx, state).await
293    }
294}
295
/// Handle route / generation errors and logging them.
///
/// Converts any boxed error into a [`CompletionResponder`] error variant.
pub fn handle_error(
    state: SharedMistralRsState,
    e: Box<dyn std::error::Error + Send + Sync + 'static>,
) -> CompletionResponder {
    // Delegates to the shared error handler used across completion routes.
    handle_completion_error(state, e)
}
303
304/// Creates a SSE streamer for chat completions with optional callbacks.
305pub fn create_streamer(
306    rx: Receiver<Response>,
307    state: SharedMistralRsState,
308    on_chunk: Option<CompletionOnChunkCallback>,
309    on_done: Option<CompletionOnDoneCallback>,
310) -> Sse<CompletionStreamer> {
311    let streamer = base_create_streamer(rx, state, on_chunk, on_done);
312    let keep_alive_interval = get_keep_alive_interval();
313
314    Sse::new(streamer)
315        .keep_alive(KeepAlive::new().interval(Duration::from_millis(keep_alive_interval)))
316}
317
/// Process non-streaming completion responses.
///
/// Awaits a single response on `rx` and converts it into a
/// [`CompletionResponder`] via [`match_responses`], routing failures through
/// [`handle_error`].
pub async fn process_non_streaming_response(
    rx: &mut Receiver<Response>,
    state: SharedMistralRsState,
) -> CompletionResponder {
    base_process_non_streaming_response(rx, state, match_responses, handle_error).await
}
325
326/// Matches and processes different types of model responses into appropriate completion responses.
327pub fn match_responses(state: SharedMistralRsState, response: Response) -> CompletionResponder {
328    match response {
329        Response::InternalError(e) => {
330            MistralRs::maybe_log_error(state, &*e);
331            CompletionResponder::InternalError(e)
332        }
333        Response::CompletionModelError(msg, response) => {
334            MistralRs::maybe_log_error(state.clone(), &ModelErrorMessage(msg.to_string()));
335            MistralRs::maybe_log_response(state, &response);
336            CompletionResponder::ModelError(msg, response)
337        }
338        Response::ValidationError(e) => CompletionResponder::ValidationError(e),
339        Response::CompletionDone(response) => {
340            MistralRs::maybe_log_response(state, &response);
341            CompletionResponder::Json(response)
342        }
343        Response::CompletionChunk(_) => unreachable!(),
344        Response::Chunk(_) => unreachable!(),
345        Response::Done(_) => unreachable!(),
346        Response::ModelError(_, _) => unreachable!(),
347        Response::ImageGeneration(_) => unreachable!(),
348        Response::Speech { .. } => unreachable!(),
349        Response::Raw { .. } => unreachable!(),
350    }
351}