mistralrs_server/
completions.rs

1use anyhow::Result;
2use std::{
3    env,
4    error::Error,
5    pin::Pin,
6    sync::Arc,
7    task::{Context, Poll},
8    time::Duration,
9};
10use tokio::sync::mpsc::{channel, Receiver, Sender};
11
12use crate::openai::{CompletionRequest, Grammar, StopTokens};
13use axum::{
14    extract::{Json, State},
15    http::{self, StatusCode},
16    response::{
17        sse::{Event, KeepAlive},
18        IntoResponse, Sse,
19    },
20};
21use mistralrs_core::{
22    CompletionResponse, Constraint, DrySamplingParams, MistralRs, NormalRequest, Request,
23    RequestMessage, Response, SamplingParams, StopTokens as InternalStopTokens,
24};
25use serde::Serialize;
26use tracing::warn;
27
/// Wraps a plain model-error string so it can be passed where a
/// `std::error::Error` is required (e.g. `MistralRs::maybe_log_error`).
#[derive(Debug)]
struct ModelErrorMessage(String);

impl std::fmt::Display for ModelErrorMessage {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Display is just the raw message text.
        f.write_str(&self.0)
    }
}

impl std::error::Error for ModelErrorMessage {}
36
/// Lifecycle of the SSE stream, ensuring the terminating `data: [DONE]`
/// event is emitted exactly once before the stream ends.
enum DoneState {
    // Still forwarding completion chunks from the model.
    Running,
    // All chunks delivered; the next poll emits the `[DONE]` sentinel.
    SendingDone,
    // `[DONE]` has been sent; the stream yields `None` from now on.
    Done,
}
42
/// Server-sent-events stream for one streaming completion request.
pub struct Streamer {
    // Receives the engine's responses for this request (paired sender is
    // handed to the engine in `parse_request`).
    rx: Receiver<Response>,
    // Drives one-shot emission of the trailing `[DONE]` event.
    done_state: DoneState,
    // Shared engine handle; used here only for logging responses/errors.
    state: Arc<MistralRs>,
}
48
impl futures::Stream for Streamer {
    type Item = Result<Event, axum::Error>;

    /// Drives the SSE stream: first resolves the `[DONE]` state machine, then
    /// polls the engine's response channel and maps each message to an event.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        match self.done_state {
            DoneState::SendingDone => {
                // https://platform.openai.com/docs/api-reference/completions/create
                // If true, returns a stream of events that happen during the Run as server-sent events, terminating when the Run enters a terminal state with a data: [DONE] message.
                self.done_state = DoneState::Done;
                return Poll::Ready(Some(Ok(Event::default().data("[DONE]"))));
            }
            DoneState::Done => {
                // Terminal: signal end-of-stream to the SSE layer.
                return Poll::Ready(None);
            }
            DoneState::Running => (),
        }

        match self.rx.poll_recv(cx) {
            Poll::Ready(Some(resp)) => match resp {
                Response::CompletionModelError(msg, _) => {
                    MistralRs::maybe_log_error(
                        self.state.clone(),
                        &ModelErrorMessage(msg.to_string()),
                    );
                    // Done now, just need to send the [DONE]
                    self.done_state = DoneState::SendingDone;
                    Poll::Ready(Some(Ok(Event::default().data(msg))))
                }
                Response::ValidationError(e) => {
                    // Errors are surfaced to the client as plain data events;
                    // the stream keeps running afterwards.
                    Poll::Ready(Some(Ok(Event::default().data(e.to_string()))))
                }
                Response::InternalError(e) => {
                    MistralRs::maybe_log_error(self.state.clone(), &*e);
                    Poll::Ready(Some(Ok(Event::default().data(e.to_string()))))
                }
                Response::CompletionChunk(response) => {
                    if response.choices.iter().all(|x| x.finish_reason.is_some()) {
                        // Done now, just need to send the [DONE]
                        self.done_state = DoneState::SendingDone;
                    }
                    MistralRs::maybe_log_response(self.state.clone(), &response);
                    Poll::Ready(Some(Event::default().json_data(response)))
                }
                // These variants belong to other endpoints (chat, image, raw)
                // and are never produced for a streaming completion request.
                Response::Done(_) => unreachable!(),
                Response::CompletionDone(_) => unreachable!(),
                Response::Chunk(_) => unreachable!(),
                Response::ImageGeneration(_) => unreachable!(),
                Response::ModelError(_, _) => unreachable!(),
                Response::Raw { .. } => unreachable!(),
            },
            // NOTE(review): `Ready(None)` (sender dropped) is folded into
            // `Pending` with no waker re-registration; this appears to rely on
            // a final chunk with `finish_reason` always arriving before the
            // channel closes — confirm the engine guarantees that.
            Poll::Pending | Poll::Ready(None) => Poll::Pending,
        }
    }
}
103
/// Every possible HTTP reply shape for the completions endpoint.
pub enum CompletionResponder {
    // Streaming reply: server-sent events of completion chunks.
    Sse(Sse<Streamer>),
    // Successful non-streaming completion (HTTP 200).
    Json(CompletionResponse),
    // Model failed mid-generation; carries the message plus partial output
    // (rendered as HTTP 500 with a JSON body, see `IntoResponse` below).
    ModelError(String, CompletionResponse),
    // Server-side failure (HTTP 500).
    InternalError(Box<dyn Error>),
    // Invalid request content (HTTP 422).
    ValidationError(Box<dyn Error>),
}
111
112trait ErrorToResponse: Serialize {
113    fn to_response(&self, code: StatusCode) -> axum::response::Response {
114        let mut r = Json(self).into_response();
115        *r.status_mut() = code;
116        r
117    }
118}
119
120#[derive(Serialize)]
121struct JsonError {
122    message: String,
123}
124
125impl JsonError {
126    fn new(message: String) -> Self {
127        Self { message }
128    }
129}
130impl ErrorToResponse for JsonError {}
131
132#[derive(Serialize)]
133struct JsonModelError {
134    message: String,
135    partial_response: CompletionResponse,
136}
137
138impl JsonModelError {
139    fn new(message: String, partial_response: CompletionResponse) -> Self {
140        Self {
141            message,
142            partial_response,
143        }
144    }
145}
146
147impl ErrorToResponse for JsonModelError {}
148
149impl IntoResponse for CompletionResponder {
150    fn into_response(self) -> axum::response::Response {
151        match self {
152            CompletionResponder::Sse(s) => s.into_response(),
153            CompletionResponder::Json(s) => Json(s).into_response(),
154            CompletionResponder::InternalError(e) => {
155                JsonError::new(e.to_string()).to_response(http::StatusCode::INTERNAL_SERVER_ERROR)
156            }
157            CompletionResponder::ValidationError(e) => {
158                JsonError::new(e.to_string()).to_response(http::StatusCode::UNPROCESSABLE_ENTITY)
159            }
160            CompletionResponder::ModelError(msg, response) => JsonModelError::new(msg, response)
161                .to_response(http::StatusCode::INTERNAL_SERVER_ERROR),
162        }
163    }
164}
165
166fn parse_request(
167    oairequest: CompletionRequest,
168    state: Arc<MistralRs>,
169    tx: Sender<Response>,
170) -> Result<(Request, bool)> {
171    let repr = serde_json::to_string(&oairequest).expect("Serialization of request failed.");
172    MistralRs::maybe_log_request(state.clone(), repr);
173
174    let stop_toks = match oairequest.stop_seqs {
175        Some(StopTokens::Multi(m)) => Some(InternalStopTokens::Seqs(m)),
176        Some(StopTokens::Single(s)) => Some(InternalStopTokens::Seqs(vec![s])),
177        None => None,
178    };
179
180    if oairequest.logprobs.is_some() {
181        warn!("Completion requests do not support logprobs.");
182    }
183
184    let is_streaming = oairequest.stream.unwrap_or(false);
185
186    let dry_params = if let Some(dry_multiplier) = oairequest.dry_multiplier {
187        Some(DrySamplingParams::new_with_defaults(
188            dry_multiplier,
189            oairequest.dry_sequence_breakers,
190            oairequest.dry_base,
191            oairequest.dry_allowed_length,
192        )?)
193    } else {
194        None
195    };
196    Ok((
197        Request::Normal(NormalRequest {
198            id: state.next_request_id(),
199            messages: RequestMessage::Completion {
200                text: oairequest.prompt,
201                echo_prompt: oairequest.echo_prompt,
202                best_of: oairequest.best_of,
203            },
204            sampling_params: SamplingParams {
205                temperature: oairequest.temperature,
206                top_k: oairequest.top_k,
207                top_p: oairequest.top_p,
208                min_p: oairequest.min_p,
209                top_n_logprobs: 1,
210                frequency_penalty: oairequest.frequency_penalty,
211                presence_penalty: oairequest.presence_penalty,
212                max_len: oairequest.max_tokens,
213                stop_toks,
214                logits_bias: oairequest.logit_bias,
215                n_choices: oairequest.n_choices,
216                dry_params,
217            },
218            response: tx,
219            return_logprobs: false,
220            is_streaming,
221            suffix: oairequest.suffix,
222            constraint: match oairequest.grammar {
223                Some(Grammar::Regex(regex)) => Constraint::Regex(regex),
224                Some(Grammar::Lark(lark)) => Constraint::Lark(lark),
225                Some(Grammar::JsonSchema(schema)) => Constraint::JsonSchema(schema),
226                Some(Grammar::Llguidance(llguidance)) => Constraint::Llguidance(llguidance),
227                None => Constraint::None,
228            },
229            adapters: oairequest.adapters,
230            tool_choice: oairequest.tool_choice,
231            tools: oairequest.tools,
232            logits_processors: None,
233            return_raw_logits: false,
234            web_search_options: None,
235        }),
236        is_streaming,
237    ))
238}
239
240#[utoipa::path(
241    post,
242    tag = "Mistral.rs",
243    path = "/v1/completions",
244    request_body = CompletionRequest,
245    responses((status = 200, description = "Completions"))
246)]
247
248pub async fn completions(
249    State(state): State<Arc<MistralRs>>,
250    Json(oairequest): Json<CompletionRequest>,
251) -> CompletionResponder {
252    let (tx, mut rx) = channel(10_000);
253    if oairequest.logprobs.is_some() {
254        return CompletionResponder::ValidationError(
255            "Completion requests do not support logprobs.".into(),
256        );
257    }
258
259    let (request, is_streaming) = match parse_request(oairequest, state.clone(), tx) {
260        Ok(x) => x,
261        Err(e) => {
262            let e = anyhow::Error::msg(e.to_string());
263            MistralRs::maybe_log_error(state, &*e);
264            return CompletionResponder::InternalError(e.into());
265        }
266    };
267    let sender = state.get_sender().unwrap();
268
269    if let Err(e) = sender.send(request).await {
270        let e = anyhow::Error::msg(e.to_string());
271        MistralRs::maybe_log_error(state, &*e);
272        return CompletionResponder::InternalError(e.into());
273    }
274
275    if is_streaming {
276        let streamer = Streamer {
277            rx,
278            done_state: DoneState::Running,
279            state,
280        };
281
282        let keep_alive_interval = env::var("KEEP_ALIVE_INTERVAL")
283            .map(|val| val.parse::<u64>().unwrap_or(10000))
284            .unwrap_or(10000);
285        CompletionResponder::Sse(
286            Sse::new(streamer)
287                .keep_alive(KeepAlive::new().interval(Duration::from_millis(keep_alive_interval))),
288        )
289    } else {
290        let response = match rx.recv().await {
291            Some(response) => response,
292            None => {
293                let e = anyhow::Error::msg("No response received from the model.");
294                MistralRs::maybe_log_error(state, &*e);
295                return CompletionResponder::InternalError(e.into());
296            }
297        };
298
299        match response {
300            Response::InternalError(e) => {
301                MistralRs::maybe_log_error(state, &*e);
302                CompletionResponder::InternalError(e)
303            }
304            Response::CompletionModelError(msg, response) => {
305                MistralRs::maybe_log_error(state.clone(), &ModelErrorMessage(msg.to_string()));
306                MistralRs::maybe_log_response(state, &response);
307                CompletionResponder::ModelError(msg, response)
308            }
309            Response::ValidationError(e) => CompletionResponder::ValidationError(e),
310            Response::CompletionDone(response) => {
311                MistralRs::maybe_log_response(state, &response);
312                CompletionResponder::Json(response)
313            }
314            Response::CompletionChunk(_) => unreachable!(),
315            Response::Chunk(_) => unreachable!(),
316            Response::Done(_) => unreachable!(),
317            Response::ModelError(_, _) => unreachable!(),
318            Response::ImageGeneration(_) => unreachable!(),
319            Response::Raw { .. } => unreachable!(),
320        }
321    }
322}