mistralrs_server_core/completions.rs

//! ## Completions functionality and route handler.

use std::{
    pin::Pin,
    sync::Arc,
    task::{Context, Poll},
    time::Duration,
};

use anyhow::Result;
use axum::{
    extract::{Json, State},
    http::{self},
    response::{
        sse::{Event, KeepAlive},
        IntoResponse, Sse,
    },
};
use mistralrs_core::{
    CompletionChunkResponse, CompletionResponse, Constraint, MistralRs, NormalRequest, Request,
    RequestMessage, Response, SamplingParams,
};
use tokio::sync::mpsc::{Receiver, Sender};
use tracing::warn;

use crate::{
    completion_core::{
        convert_stop_tokens, get_dry_sampling_params, handle_completion_error,
        BaseCompletionResponder,
    },
    handler_core::{
        base_process_non_streaming_response, create_response_channel, send_request,
        BaseJsonModelError, ErrorToResponse, JsonError, ModelErrorMessage,
    },
    openai::{CompletionRequest, Grammar},
    streaming::{base_create_streamer, get_keep_alive_interval, BaseStreamer, DoneState},
    types::{ExtractedMistralRsState, OnChunkCallback, OnDoneCallback, SharedMistralRsState},
    util::validate_model_name,
};

/// A callback function that processes streaming response chunks before they are sent to the client.
///
/// This hook allows modification of each chunk in the streaming response, enabling features like
/// content filtering, transformation, or logging. The callback receives a chunk and must return
/// a (potentially modified) chunk.
///
/// ### Examples
///
/// ```no_run
/// use mistralrs_server_core::completions::CompletionOnChunkCallback;
///
/// let on_chunk: CompletionOnChunkCallback = Box::new(|mut chunk| {
///     // Log the chunk or modify its content
///     println!("Processing chunk: {:?}", chunk);
///     chunk
/// });
/// ```
pub type CompletionOnChunkCallback = OnChunkCallback<CompletionChunkResponse>;

/// A callback function that is executed when the streaming response completes.
///
/// This hook receives all chunks that were streamed during the response, allowing for
/// post-processing, analytics, or cleanup operations after the stream finishes.
///
/// ### Examples
///
/// ```no_run
/// use mistralrs_server_core::completions::CompletionOnDoneCallback;
///
/// let on_done: CompletionOnDoneCallback = Box::new(|chunks| {
///     println!("Stream completed with {} chunks", chunks.len());
///     // Process all chunks for analytics
/// });
/// ```
pub type CompletionOnDoneCallback = OnDoneCallback<CompletionChunkResponse>;

/// A streaming response handler.
///
/// It processes incoming response chunks from a model and converts them
/// into Server-Sent Events (SSE) format for real-time streaming to clients.
pub type CompletionStreamer =
    BaseStreamer<CompletionChunkResponse, CompletionOnChunkCallback, CompletionOnDoneCallback>;

impl futures::Stream for CompletionStreamer {
    type Item = Result<Event, axum::Error>;

    /// Polls the stream for the next Server-Sent Event.
    ///
    /// This method implements the core streaming logic:
    /// 1. Handles stream completion by sending `[DONE]` and executing callbacks
    /// 2. Processes incoming model responses and converts them to SSE events
    /// 3. Applies chunk modifications if a callback is provided
    /// 4. Stores chunks if a completion callback is configured
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        match self.done_state {
            DoneState::SendingDone => {
                // Per the OpenAI API reference, a completions stream terminates
                // with a `data: [DONE]` message:
                // https://platform.openai.com/docs/api-reference/completions/create
                self.done_state = DoneState::Done;
                return Poll::Ready(Some(Ok(Event::default().data("[DONE]"))));
            }
            DoneState::Done => {
                if let Some(on_done) = &self.on_done {
                    on_done(&self.chunks);
                }
                return Poll::Ready(None);
            }
            DoneState::Running => (),
        }

        match self.rx.poll_recv(cx) {
            Poll::Ready(Some(resp)) => match resp {
                Response::CompletionModelError(msg, _) => {
                    MistralRs::maybe_log_error(
                        self.state.clone(),
                        &ModelErrorMessage(msg.to_string()),
                    );
                    // Done now, just need to send the [DONE]
                    self.done_state = DoneState::SendingDone;
                    Poll::Ready(Some(Ok(Event::default().data(msg))))
                }
                Response::ValidationError(e) => {
                    Poll::Ready(Some(Ok(Event::default().data(e.to_string()))))
                }
                Response::InternalError(e) => {
                    MistralRs::maybe_log_error(self.state.clone(), &*e);
                    Poll::Ready(Some(Ok(Event::default().data(e.to_string()))))
                }
                Response::CompletionChunk(mut response) => {
                    if response.choices.iter().all(|x| x.finish_reason.is_some()) {
                        // Every choice has finished: emit this chunk, then the [DONE]
                        self.done_state = DoneState::SendingDone;
                    }

                    MistralRs::maybe_log_response(self.state.clone(), &response);

                    if let Some(on_chunk) = &self.on_chunk {
                        response = on_chunk(response);
                    }

                    if self.store_chunks {
                        self.chunks.push(response.clone());
                    }

                    Poll::Ready(Some(Event::default().json_data(response)))
                }
                Response::Done(_) => unreachable!(),
                Response::CompletionDone(_) => unreachable!(),
                Response::Chunk(_) => unreachable!(),
                Response::ImageGeneration(_) => unreachable!(),
                Response::ModelError(_, _) => unreachable!(),
                Response::Speech { .. } => unreachable!(),
                Response::Raw { .. } => unreachable!(),
            },
            Poll::Pending | Poll::Ready(None) => Poll::Pending,
        }
    }
}

/// Represents different types of completion responses.
pub type CompletionResponder = BaseCompletionResponder<CompletionResponse, CompletionStreamer>;

/// JSON error response structure for model errors.
type JsonModelError = BaseJsonModelError<CompletionResponse>;
impl ErrorToResponse for JsonModelError {}

impl IntoResponse for CompletionResponder {
    /// Converts the completion responder into an HTTP response.
    fn into_response(self) -> axum::response::Response {
        match self {
            CompletionResponder::Sse(s) => s.into_response(),
            CompletionResponder::Json(s) => Json(s).into_response(),
            CompletionResponder::InternalError(e) => {
                JsonError::new(e.to_string()).to_response(http::StatusCode::INTERNAL_SERVER_ERROR)
            }
            CompletionResponder::ValidationError(e) => {
                JsonError::new(e.to_string()).to_response(http::StatusCode::UNPROCESSABLE_ENTITY)
            }
            CompletionResponder::ModelError(msg, response) => JsonModelError::new(msg, response)
                .to_response(http::StatusCode::INTERNAL_SERVER_ERROR),
        }
    }
}

/// Parses and validates a completion request.
///
/// This function transforms an OpenAI-compatible completion request into the
/// request format used by mistral.rs.
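///
/// ### Examples
///
/// A minimal sketch; the JSON body and channel capacity are illustrative, and
/// it assumes `model` and `prompt` are enough to deserialize a
/// [`CompletionRequest`]:
///
/// ```no_run
/// use std::sync::Arc;
///
/// use mistralrs_core::MistralRs;
/// use mistralrs_server_core::completions::parse_request;
///
/// fn demo(state: Arc<MistralRs>) -> anyhow::Result<()> {
///     // Channel on which the engine will send `Response`s back.
///     let (tx, _rx) = tokio::sync::mpsc::channel(10_000);
///
///     // Illustrative OpenAI-style request body.
///     let oairequest = serde_json::from_str(
///         r#"{"model": "default", "prompt": "Say this is a test"}"#,
///     )?;
///
///     let (_request, _is_streaming) = parse_request(oairequest, state, tx)?;
///     Ok(())
/// }
/// ```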
pub fn parse_request(
    oairequest: CompletionRequest,
    state: Arc<MistralRs>,
    tx: Sender<Response>,
) -> Result<(Request, bool)> {
    let repr = serde_json::to_string(&oairequest).expect("Serialization of request failed.");
    MistralRs::maybe_log_request(state.clone(), repr);

    // Validate that the requested model matches the loaded model
    validate_model_name(&oairequest.model, state.clone())?;

    let stop_toks = convert_stop_tokens(oairequest.stop_seqs);

    if oairequest.logprobs.is_some() {
        warn!("Completion requests do not support logprobs.");
    }

    let is_streaming = oairequest.stream.unwrap_or(false);

    let dry_params = get_dry_sampling_params(
        oairequest.dry_multiplier,
        oairequest.dry_sequence_breakers,
        oairequest.dry_base,
        oairequest.dry_allowed_length,
    )?;

    Ok((
        Request::Normal(Box::new(NormalRequest {
            id: state.next_request_id(),
            messages: RequestMessage::Completion {
                text: oairequest.prompt,
                echo_prompt: oairequest.echo_prompt,
                best_of: oairequest.best_of,
            },
            sampling_params: SamplingParams {
                temperature: oairequest.temperature,
                top_k: oairequest.top_k,
                top_p: oairequest.top_p,
                min_p: oairequest.min_p,
                top_n_logprobs: 1,
                frequency_penalty: oairequest.frequency_penalty,
                presence_penalty: oairequest.presence_penalty,
                max_len: oairequest.max_tokens,
                stop_toks,
                logits_bias: oairequest.logit_bias,
                n_choices: oairequest.n_choices,
                dry_params,
            },
            response: tx,
            return_logprobs: false,
            is_streaming,
            suffix: oairequest.suffix,
            constraint: match oairequest.grammar {
                Some(Grammar::Regex(regex)) => Constraint::Regex(regex),
                Some(Grammar::Lark(lark)) => Constraint::Lark(lark),
                Some(Grammar::JsonSchema(schema)) => Constraint::JsonSchema(schema),
                Some(Grammar::Llguidance(llguidance)) => Constraint::Llguidance(llguidance),
                None => Constraint::None,
            },
            tool_choice: oairequest.tool_choice,
            tools: oairequest.tools,
            logits_processors: None,
            return_raw_logits: false,
            web_search_options: None,
            model_id: if oairequest.model == "default" {
                None
            } else {
                Some(oairequest.model.clone())
            },
        })),
        is_streaming,
    ))
}

/// OpenAI-compatible completions endpoint handler.
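///
/// ### Examples
///
/// A sketch of mounting this handler on an axum router; the path mirrors the
/// `utoipa` annotation below, and `state` is assumed to be an already-built
/// `SharedMistralRsState`:
///
/// ```no_run
/// use axum::{routing::post, Router};
///
/// use mistralrs_server_core::completions::completions;
/// use mistralrs_server_core::types::SharedMistralRsState;
///
/// fn completion_router(state: SharedMistralRsState) -> Router {
///     Router::new()
///         .route("/v1/completions", post(completions))
///         .with_state(state)
/// }
/// ```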
#[utoipa::path(
    post,
    tag = "Mistral.rs",
    path = "/v1/completions",
    request_body = CompletionRequest,
    responses((status = 200, description = "Completions"))
)]
pub async fn completions(
    State(state): ExtractedMistralRsState,
    Json(oairequest): Json<CompletionRequest>,
) -> CompletionResponder {
    let (tx, mut rx) = create_response_channel(None);

    let (request, is_streaming) = match parse_request(oairequest, state.clone(), tx) {
        Ok(x) => x,
        Err(e) => return handle_error(state, e.into()),
    };

    if let Err(e) = send_request(&state, request).await {
        return handle_error(state, e.into());
    }

    if is_streaming {
        CompletionResponder::Sse(create_streamer(rx, state, None, None))
    } else {
        process_non_streaming_response(&mut rx, state).await
    }
}

/// Handles route and generation errors, and logs them.
pub fn handle_error(
    state: SharedMistralRsState,
    e: Box<dyn std::error::Error + Send + Sync + 'static>,
) -> CompletionResponder {
    handle_completion_error(state, e)
}

/// Creates an SSE streamer for completions with optional callbacks.
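///
/// ### Examples
///
/// A sketch wiring up both callbacks; the channel capacity is illustrative
/// (in the handler above, `rx` comes from `create_response_channel`):
///
/// ```no_run
/// use mistralrs_server_core::completions::{
///     create_streamer, CompletionOnChunkCallback, CompletionOnDoneCallback,
/// };
/// use mistralrs_server_core::types::SharedMistralRsState;
///
/// fn demo(state: SharedMistralRsState) {
///     let (_tx, rx) = tokio::sync::mpsc::channel(10_000);
///
///     let on_chunk: CompletionOnChunkCallback = Box::new(|chunk| {
///         println!("Processing chunk: {chunk:?}");
///         chunk
///     });
///     let on_done: CompletionOnDoneCallback = Box::new(|chunks| {
///         println!("Stream completed with {} chunks", chunks.len());
///     });
///
///     let _sse = create_streamer(rx, state, Some(on_chunk), Some(on_done));
/// }
/// ```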
pub fn create_streamer(
    rx: Receiver<Response>,
    state: SharedMistralRsState,
    on_chunk: Option<CompletionOnChunkCallback>,
    on_done: Option<CompletionOnDoneCallback>,
) -> Sse<CompletionStreamer> {
    let streamer = base_create_streamer(rx, state, on_chunk, on_done);
    let keep_alive_interval = get_keep_alive_interval();

    Sse::new(streamer)
        .keep_alive(KeepAlive::new().interval(Duration::from_millis(keep_alive_interval)))
}

/// Processes non-streaming completion responses.
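///
/// ### Examples
///
/// A sketch of draining a response channel once a request has been sent;
/// the channel capacity is illustrative:
///
/// ```no_run
/// use mistralrs_server_core::completions::process_non_streaming_response;
/// use mistralrs_server_core::types::SharedMistralRsState;
///
/// async fn demo(state: SharedMistralRsState) {
///     let (_tx, mut rx) = tokio::sync::mpsc::channel(10_000);
///     let _responder = process_non_streaming_response(&mut rx, state).await;
/// }
/// ```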
pub async fn process_non_streaming_response(
    rx: &mut Receiver<Response>,
    state: SharedMistralRsState,
) -> CompletionResponder {
    base_process_non_streaming_response(rx, state, match_responses, handle_error).await
}

/// Matches and processes different types of model responses into appropriate completion responses.
pub fn match_responses(state: SharedMistralRsState, response: Response) -> CompletionResponder {
    match response {
        Response::InternalError(e) => {
            MistralRs::maybe_log_error(state, &*e);
            CompletionResponder::InternalError(e)
        }
        Response::CompletionModelError(msg, response) => {
            MistralRs::maybe_log_error(state.clone(), &ModelErrorMessage(msg.to_string()));
            MistralRs::maybe_log_response(state, &response);
            CompletionResponder::ModelError(msg, response)
        }
        Response::ValidationError(e) => CompletionResponder::ValidationError(e),
        Response::CompletionDone(response) => {
            MistralRs::maybe_log_response(state, &response);
            CompletionResponder::Json(response)
        }
        Response::CompletionChunk(_) => unreachable!(),
        Response::Chunk(_) => unreachable!(),
        Response::Done(_) => unreachable!(),
        Response::ModelError(_, _) => unreachable!(),
        Response::ImageGeneration(_) => unreachable!(),
        Response::Speech { .. } => unreachable!(),
        Response::Raw { .. } => unreachable!(),
    }
}
347}