// mistralrs_server/completions.rs

1use anyhow::Result;
2use std::{
3    env,
4    error::Error,
5    pin::Pin,
6    sync::Arc,
7    task::{Context, Poll},
8    time::Duration,
9};
10use tokio::sync::mpsc::{channel, Receiver, Sender};
11
12use crate::openai::{CompletionRequest, Grammar, StopTokens};
13use axum::{
14    extract::{Json, State},
15    http::{self, StatusCode},
16    response::{
17        sse::{Event, KeepAlive},
18        IntoResponse, Sse,
19    },
20};
21use mistralrs_core::{
22    CompletionResponse, Constraint, DrySamplingParams, MistralRs, NormalRequest, Request,
23    RequestMessage, Response, SamplingParams, StopTokens as InternalStopTokens,
24};
25use serde::Serialize;
26use tracing::warn;
27
/// Newtype wrapping a model-generated error string so it can be passed
/// through the `std::error::Error` logging machinery.
#[derive(Debug)]
struct ModelErrorMessage(String);

impl std::fmt::Display for ModelErrorMessage {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Display is just the raw message text, with no decoration.
        formatter.write_str(&self.0)
    }
}

impl std::error::Error for ModelErrorMessage {}
36
/// Lifecycle of the SSE stream, so `poll_next` knows whether to forward
/// chunks, emit the terminal `[DONE]` marker, or end the stream.
enum DoneState {
    // Chunks are still being received and forwarded to the client.
    Running,
    // Generation finished; the next poll emits the `[DONE]` event.
    SendingDone,
    // `[DONE]` has been sent; subsequent polls yield `None`.
    Done,
}
42
/// Adapter that turns engine `Response` messages arriving over an mpsc
/// channel into server-sent events, terminated by a `[DONE]` message.
pub struct Streamer {
    // Receiving end of the per-request response channel.
    rx: Receiver<Response>,
    // Where the stream is in its `[DONE]`/termination sequence.
    done_state: DoneState,
    // Shared engine handle; used here only for optional logging.
    state: Arc<MistralRs>,
}
48
49impl futures::Stream for Streamer {
50    type Item = Result<Event, axum::Error>;
51
52    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
53        match self.done_state {
54            DoneState::SendingDone => {
55                // https://platform.openai.com/docs/api-reference/completions/create
56                // If true, returns a stream of events that happen during the Run as server-sent events, terminating when the Run enters a terminal state with a data: [DONE] message.
57                self.done_state = DoneState::Done;
58                return Poll::Ready(Some(Ok(Event::default().data("[DONE]"))));
59            }
60            DoneState::Done => {
61                return Poll::Ready(None);
62            }
63            DoneState::Running => (),
64        }
65
66        match self.rx.poll_recv(cx) {
67            Poll::Ready(Some(resp)) => match resp {
68                Response::CompletionModelError(msg, _) => {
69                    MistralRs::maybe_log_error(
70                        self.state.clone(),
71                        &ModelErrorMessage(msg.to_string()),
72                    );
73                    // Done now, just need to send the [DONE]
74                    self.done_state = DoneState::SendingDone;
75                    Poll::Ready(Some(Ok(Event::default().data(msg))))
76                }
77                Response::ValidationError(e) => {
78                    Poll::Ready(Some(Ok(Event::default().data(e.to_string()))))
79                }
80                Response::InternalError(e) => {
81                    MistralRs::maybe_log_error(self.state.clone(), &*e);
82                    Poll::Ready(Some(Ok(Event::default().data(e.to_string()))))
83                }
84                Response::CompletionChunk(response) => {
85                    if response.choices.iter().all(|x| x.finish_reason.is_some()) {
86                        // Done now, just need to send the [DONE]
87                        self.done_state = DoneState::SendingDone;
88                    }
89                    MistralRs::maybe_log_response(self.state.clone(), &response);
90                    Poll::Ready(Some(Event::default().json_data(response)))
91                }
92                Response::Done(_) => unreachable!(),
93                Response::CompletionDone(_) => unreachable!(),
94                Response::Chunk(_) => unreachable!(),
95                Response::ImageGeneration(_) => unreachable!(),
96                Response::ModelError(_, _) => unreachable!(),
97                Response::Raw { .. } => unreachable!(),
98            },
99            Poll::Pending | Poll::Ready(None) => Poll::Pending,
100        }
101    }
102}
103
/// Every possible HTTP response shape for the completions endpoint.
pub enum CompletionResponder {
    // Streaming response, terminated by a `[DONE]` SSE event.
    Sse(Sse<Streamer>),
    // Successful non-streaming completion (HTTP 200).
    Json(CompletionResponse),
    // Model failed mid-generation; carries the message plus partial output.
    ModelError(String, CompletionResponse),
    // Server-side failure (rendered as HTTP 500).
    InternalError(Box<dyn Error>),
    // Invalid request (rendered as HTTP 422).
    ValidationError(Box<dyn Error>),
}
111
112trait ErrorToResponse: Serialize {
113    fn to_response(&self, code: StatusCode) -> axum::response::Response {
114        let mut r = Json(self).into_response();
115        *r.status_mut() = code;
116        r
117    }
118}
119
120#[derive(Serialize)]
121struct JsonError {
122    message: String,
123}
124
125impl JsonError {
126    fn new(message: String) -> Self {
127        Self { message }
128    }
129}
130impl ErrorToResponse for JsonError {}
131
132#[derive(Serialize)]
133struct JsonModelError {
134    message: String,
135    partial_response: CompletionResponse,
136}
137
138impl JsonModelError {
139    fn new(message: String, partial_response: CompletionResponse) -> Self {
140        Self {
141            message,
142            partial_response,
143        }
144    }
145}
146
147impl ErrorToResponse for JsonModelError {}
148
149impl IntoResponse for CompletionResponder {
150    fn into_response(self) -> axum::response::Response {
151        match self {
152            CompletionResponder::Sse(s) => s.into_response(),
153            CompletionResponder::Json(s) => Json(s).into_response(),
154            CompletionResponder::InternalError(e) => {
155                JsonError::new(e.to_string()).to_response(http::StatusCode::INTERNAL_SERVER_ERROR)
156            }
157            CompletionResponder::ValidationError(e) => {
158                JsonError::new(e.to_string()).to_response(http::StatusCode::UNPROCESSABLE_ENTITY)
159            }
160            CompletionResponder::ModelError(msg, response) => JsonModelError::new(msg, response)
161                .to_response(http::StatusCode::INTERNAL_SERVER_ERROR),
162        }
163    }
164}
165
166fn parse_request(
167    oairequest: CompletionRequest,
168    state: Arc<MistralRs>,
169    tx: Sender<Response>,
170) -> Result<(Request, bool)> {
171    let repr = serde_json::to_string(&oairequest).expect("Serialization of request failed.");
172    MistralRs::maybe_log_request(state.clone(), repr);
173
174    let stop_toks = match oairequest.stop_seqs {
175        Some(StopTokens::Multi(m)) => Some(InternalStopTokens::Seqs(m)),
176        Some(StopTokens::Single(s)) => Some(InternalStopTokens::Seqs(vec![s])),
177        None => None,
178    };
179
180    if oairequest.logprobs.is_some() {
181        warn!("Completion requests do not support logprobs.");
182    }
183
184    let is_streaming = oairequest.stream.unwrap_or(false);
185
186    let dry_params = if let Some(dry_multiplier) = oairequest.dry_multiplier {
187        Some(DrySamplingParams::new_with_defaults(
188            dry_multiplier,
189            oairequest.dry_sequence_breakers,
190            oairequest.dry_base,
191            oairequest.dry_allowed_length,
192        )?)
193    } else {
194        None
195    };
196    Ok((
197        Request::Normal(NormalRequest {
198            id: state.next_request_id(),
199            messages: RequestMessage::Completion {
200                text: oairequest.prompt,
201                echo_prompt: oairequest.echo_prompt,
202                best_of: oairequest.best_of,
203            },
204            sampling_params: SamplingParams {
205                temperature: oairequest.temperature,
206                top_k: oairequest.top_k,
207                top_p: oairequest.top_p,
208                min_p: oairequest.min_p,
209                top_n_logprobs: 1,
210                frequency_penalty: oairequest.frequency_penalty,
211                presence_penalty: oairequest.presence_penalty,
212                max_len: oairequest.max_tokens,
213                stop_toks,
214                logits_bias: oairequest.logit_bias,
215                n_choices: oairequest.n_choices,
216                dry_params,
217            },
218            response: tx,
219            return_logprobs: false,
220            is_streaming,
221            suffix: oairequest.suffix,
222            constraint: match oairequest.grammar {
223                Some(Grammar::Regex(regex)) => Constraint::Regex(regex),
224                Some(Grammar::Lark(lark)) => Constraint::Lark(lark),
225                Some(Grammar::JsonSchema(schema)) => Constraint::JsonSchema(schema),
226                Some(Grammar::Llguidance(llguidance)) => Constraint::Llguidance(llguidance),
227                None => Constraint::None,
228            },
229            tool_choice: oairequest.tool_choice,
230            tools: oairequest.tools,
231            logits_processors: None,
232            return_raw_logits: false,
233            web_search_options: None,
234        }),
235        is_streaming,
236    ))
237}
238
239#[utoipa::path(
240    post,
241    tag = "Mistral.rs",
242    path = "/v1/completions",
243    request_body = CompletionRequest,
244    responses((status = 200, description = "Completions"))
245)]
246
247pub async fn completions(
248    State(state): State<Arc<MistralRs>>,
249    Json(oairequest): Json<CompletionRequest>,
250) -> CompletionResponder {
251    let (tx, mut rx) = channel(10_000);
252    if oairequest.logprobs.is_some() {
253        return CompletionResponder::ValidationError(
254            "Completion requests do not support logprobs.".into(),
255        );
256    }
257
258    let (request, is_streaming) = match parse_request(oairequest, state.clone(), tx) {
259        Ok(x) => x,
260        Err(e) => {
261            let e = anyhow::Error::msg(e.to_string());
262            MistralRs::maybe_log_error(state, &*e);
263            return CompletionResponder::InternalError(e.into());
264        }
265    };
266    let sender = state.get_sender().unwrap();
267
268    if let Err(e) = sender.send(request).await {
269        let e = anyhow::Error::msg(e.to_string());
270        MistralRs::maybe_log_error(state, &*e);
271        return CompletionResponder::InternalError(e.into());
272    }
273
274    if is_streaming {
275        let streamer = Streamer {
276            rx,
277            done_state: DoneState::Running,
278            state,
279        };
280
281        let keep_alive_interval = env::var("KEEP_ALIVE_INTERVAL")
282            .map(|val| val.parse::<u64>().unwrap_or(10000))
283            .unwrap_or(10000);
284        CompletionResponder::Sse(
285            Sse::new(streamer)
286                .keep_alive(KeepAlive::new().interval(Duration::from_millis(keep_alive_interval))),
287        )
288    } else {
289        let response = match rx.recv().await {
290            Some(response) => response,
291            None => {
292                let e = anyhow::Error::msg("No response received from the model.");
293                MistralRs::maybe_log_error(state, &*e);
294                return CompletionResponder::InternalError(e.into());
295            }
296        };
297
298        match response {
299            Response::InternalError(e) => {
300                MistralRs::maybe_log_error(state, &*e);
301                CompletionResponder::InternalError(e)
302            }
303            Response::CompletionModelError(msg, response) => {
304                MistralRs::maybe_log_error(state.clone(), &ModelErrorMessage(msg.to_string()));
305                MistralRs::maybe_log_response(state, &response);
306                CompletionResponder::ModelError(msg, response)
307            }
308            Response::ValidationError(e) => CompletionResponder::ValidationError(e),
309            Response::CompletionDone(response) => {
310                MistralRs::maybe_log_response(state, &response);
311                CompletionResponder::Json(response)
312            }
313            Response::CompletionChunk(_) => unreachable!(),
314            Response::Chunk(_) => unreachable!(),
315            Response::Done(_) => unreachable!(),
316            Response::ModelError(_, _) => unreachable!(),
317            Response::ImageGeneration(_) => unreachable!(),
318            Response::Raw { .. } => unreachable!(),
319        }
320    }
321}