1use anyhow::Result;
2use std::{
3 env,
4 error::Error,
5 pin::Pin,
6 sync::Arc,
7 task::{Context, Poll},
8 time::Duration,
9};
10use tokio::sync::mpsc::{channel, Receiver, Sender};
11
12use crate::openai::{CompletionRequest, Grammar, StopTokens};
13use axum::{
14 extract::{Json, State},
15 http::{self, StatusCode},
16 response::{
17 sse::{Event, KeepAlive},
18 IntoResponse, Sse,
19 },
20};
21use mistralrs_core::{
22 CompletionResponse, Constraint, DrySamplingParams, MistralRs, NormalRequest, Request,
23 RequestMessage, Response, SamplingParams, StopTokens as InternalStopTokens,
24};
25use serde::Serialize;
26use tracing::warn;
27
/// Plain-text error carrier so model failure messages can be routed through
/// the `std::error::Error`-based logging machinery.
#[derive(Debug)]
struct ModelErrorMessage(String);

impl std::fmt::Display for ModelErrorMessage {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}

impl std::error::Error for ModelErrorMessage {}
36
/// Tracks where the SSE stream is in its shutdown handshake.
enum DoneState {
    /// Still forwarding responses received from the model.
    Running,
    /// All content has been sent; emit the final `[DONE]` event on the next poll.
    SendingDone,
    /// `[DONE]` has been emitted; the stream is finished.
    Done,
}
42
/// SSE adapter that forwards completion `Response`s received over an mpsc
/// channel to the HTTP client as server-sent `Event`s.
pub struct Streamer {
    // Receiving end of the channel the core engine sends responses on.
    rx: Receiver<Response>,
    // Shutdown handshake state (see `DoneState`).
    done_state: DoneState,
    // Shared engine handle; used here for logging responses and errors.
    state: Arc<MistralRs>,
}
48
49impl futures::Stream for Streamer {
50 type Item = Result<Event, axum::Error>;
51
52 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
53 match self.done_state {
54 DoneState::SendingDone => {
55 self.done_state = DoneState::Done;
58 return Poll::Ready(Some(Ok(Event::default().data("[DONE]"))));
59 }
60 DoneState::Done => {
61 return Poll::Ready(None);
62 }
63 DoneState::Running => (),
64 }
65
66 match self.rx.poll_recv(cx) {
67 Poll::Ready(Some(resp)) => match resp {
68 Response::CompletionModelError(msg, _) => {
69 MistralRs::maybe_log_error(
70 self.state.clone(),
71 &ModelErrorMessage(msg.to_string()),
72 );
73 self.done_state = DoneState::SendingDone;
75 Poll::Ready(Some(Ok(Event::default().data(msg))))
76 }
77 Response::ValidationError(e) => {
78 Poll::Ready(Some(Ok(Event::default().data(e.to_string()))))
79 }
80 Response::InternalError(e) => {
81 MistralRs::maybe_log_error(self.state.clone(), &*e);
82 Poll::Ready(Some(Ok(Event::default().data(e.to_string()))))
83 }
84 Response::CompletionChunk(response) => {
85 if response.choices.iter().all(|x| x.finish_reason.is_some()) {
86 self.done_state = DoneState::SendingDone;
88 }
89 MistralRs::maybe_log_response(self.state.clone(), &response);
90 Poll::Ready(Some(Event::default().json_data(response)))
91 }
92 Response::Done(_) => unreachable!(),
93 Response::CompletionDone(_) => unreachable!(),
94 Response::Chunk(_) => unreachable!(),
95 Response::ImageGeneration(_) => unreachable!(),
96 Response::ModelError(_, _) => unreachable!(),
97 Response::Raw { .. } => unreachable!(),
98 },
99 Poll::Pending | Poll::Ready(None) => Poll::Pending,
100 }
101 }
102}
103
/// Every possible HTTP reply for the `/v1/completions` endpoint; rendered to
/// an actual HTTP response by its `IntoResponse` impl.
pub enum CompletionResponder {
    /// Streaming reply (`stream: true` was requested).
    Sse(Sse<Streamer>),
    /// Successful non-streaming completion.
    Json(CompletionResponse),
    /// The model failed mid-generation; carries the message plus partial output.
    ModelError(String, CompletionResponse),
    /// Server-side failure (rendered as HTTP 500).
    InternalError(Box<dyn Error>),
    /// Invalid request (rendered as HTTP 422).
    ValidationError(Box<dyn Error>),
}
111
112trait ErrorToResponse: Serialize {
113 fn to_response(&self, code: StatusCode) -> axum::response::Response {
114 let mut r = Json(self).into_response();
115 *r.status_mut() = code;
116 r
117 }
118}
119
120#[derive(Serialize)]
121struct JsonError {
122 message: String,
123}
124
125impl JsonError {
126 fn new(message: String) -> Self {
127 Self { message }
128 }
129}
130impl ErrorToResponse for JsonError {}
131
132#[derive(Serialize)]
133struct JsonModelError {
134 message: String,
135 partial_response: CompletionResponse,
136}
137
138impl JsonModelError {
139 fn new(message: String, partial_response: CompletionResponse) -> Self {
140 Self {
141 message,
142 partial_response,
143 }
144 }
145}
146
147impl ErrorToResponse for JsonModelError {}
148
149impl IntoResponse for CompletionResponder {
150 fn into_response(self) -> axum::response::Response {
151 match self {
152 CompletionResponder::Sse(s) => s.into_response(),
153 CompletionResponder::Json(s) => Json(s).into_response(),
154 CompletionResponder::InternalError(e) => {
155 JsonError::new(e.to_string()).to_response(http::StatusCode::INTERNAL_SERVER_ERROR)
156 }
157 CompletionResponder::ValidationError(e) => {
158 JsonError::new(e.to_string()).to_response(http::StatusCode::UNPROCESSABLE_ENTITY)
159 }
160 CompletionResponder::ModelError(msg, response) => JsonModelError::new(msg, response)
161 .to_response(http::StatusCode::INTERNAL_SERVER_ERROR),
162 }
163 }
164}
165
166fn parse_request(
167 oairequest: CompletionRequest,
168 state: Arc<MistralRs>,
169 tx: Sender<Response>,
170) -> Result<(Request, bool)> {
171 let repr = serde_json::to_string(&oairequest).expect("Serialization of request failed.");
172 MistralRs::maybe_log_request(state.clone(), repr);
173
174 let stop_toks = match oairequest.stop_seqs {
175 Some(StopTokens::Multi(m)) => Some(InternalStopTokens::Seqs(m)),
176 Some(StopTokens::Single(s)) => Some(InternalStopTokens::Seqs(vec![s])),
177 None => None,
178 };
179
180 if oairequest.logprobs.is_some() {
181 warn!("Completion requests do not support logprobs.");
182 }
183
184 let is_streaming = oairequest.stream.unwrap_or(false);
185
186 let dry_params = if let Some(dry_multiplier) = oairequest.dry_multiplier {
187 Some(DrySamplingParams::new_with_defaults(
188 dry_multiplier,
189 oairequest.dry_sequence_breakers,
190 oairequest.dry_base,
191 oairequest.dry_allowed_length,
192 )?)
193 } else {
194 None
195 };
196 Ok((
197 Request::Normal(NormalRequest {
198 id: state.next_request_id(),
199 messages: RequestMessage::Completion {
200 text: oairequest.prompt,
201 echo_prompt: oairequest.echo_prompt,
202 best_of: oairequest.best_of,
203 },
204 sampling_params: SamplingParams {
205 temperature: oairequest.temperature,
206 top_k: oairequest.top_k,
207 top_p: oairequest.top_p,
208 min_p: oairequest.min_p,
209 top_n_logprobs: 1,
210 frequency_penalty: oairequest.frequency_penalty,
211 presence_penalty: oairequest.presence_penalty,
212 max_len: oairequest.max_tokens,
213 stop_toks,
214 logits_bias: oairequest.logit_bias,
215 n_choices: oairequest.n_choices,
216 dry_params,
217 },
218 response: tx,
219 return_logprobs: false,
220 is_streaming,
221 suffix: oairequest.suffix,
222 constraint: match oairequest.grammar {
223 Some(Grammar::Regex(regex)) => Constraint::Regex(regex),
224 Some(Grammar::Lark(lark)) => Constraint::Lark(lark),
225 Some(Grammar::JsonSchema(schema)) => Constraint::JsonSchema(schema),
226 Some(Grammar::Llguidance(llguidance)) => Constraint::Llguidance(llguidance),
227 None => Constraint::None,
228 },
229 adapters: oairequest.adapters,
230 tool_choice: oairequest.tool_choice,
231 tools: oairequest.tools,
232 logits_processors: None,
233 return_raw_logits: false,
234 web_search_options: None,
235 }),
236 is_streaming,
237 ))
238}
239
240#[utoipa::path(
241 post,
242 tag = "Mistral.rs",
243 path = "/v1/completions",
244 request_body = CompletionRequest,
245 responses((status = 200, description = "Completions"))
246)]
247
248pub async fn completions(
249 State(state): State<Arc<MistralRs>>,
250 Json(oairequest): Json<CompletionRequest>,
251) -> CompletionResponder {
252 let (tx, mut rx) = channel(10_000);
253 if oairequest.logprobs.is_some() {
254 return CompletionResponder::ValidationError(
255 "Completion requests do not support logprobs.".into(),
256 );
257 }
258
259 let (request, is_streaming) = match parse_request(oairequest, state.clone(), tx) {
260 Ok(x) => x,
261 Err(e) => {
262 let e = anyhow::Error::msg(e.to_string());
263 MistralRs::maybe_log_error(state, &*e);
264 return CompletionResponder::InternalError(e.into());
265 }
266 };
267 let sender = state.get_sender().unwrap();
268
269 if let Err(e) = sender.send(request).await {
270 let e = anyhow::Error::msg(e.to_string());
271 MistralRs::maybe_log_error(state, &*e);
272 return CompletionResponder::InternalError(e.into());
273 }
274
275 if is_streaming {
276 let streamer = Streamer {
277 rx,
278 done_state: DoneState::Running,
279 state,
280 };
281
282 let keep_alive_interval = env::var("KEEP_ALIVE_INTERVAL")
283 .map(|val| val.parse::<u64>().unwrap_or(10000))
284 .unwrap_or(10000);
285 CompletionResponder::Sse(
286 Sse::new(streamer)
287 .keep_alive(KeepAlive::new().interval(Duration::from_millis(keep_alive_interval))),
288 )
289 } else {
290 let response = match rx.recv().await {
291 Some(response) => response,
292 None => {
293 let e = anyhow::Error::msg("No response received from the model.");
294 MistralRs::maybe_log_error(state, &*e);
295 return CompletionResponder::InternalError(e.into());
296 }
297 };
298
299 match response {
300 Response::InternalError(e) => {
301 MistralRs::maybe_log_error(state, &*e);
302 CompletionResponder::InternalError(e)
303 }
304 Response::CompletionModelError(msg, response) => {
305 MistralRs::maybe_log_error(state.clone(), &ModelErrorMessage(msg.to_string()));
306 MistralRs::maybe_log_response(state, &response);
307 CompletionResponder::ModelError(msg, response)
308 }
309 Response::ValidationError(e) => CompletionResponder::ValidationError(e),
310 Response::CompletionDone(response) => {
311 MistralRs::maybe_log_response(state, &response);
312 CompletionResponder::Json(response)
313 }
314 Response::CompletionChunk(_) => unreachable!(),
315 Response::Chunk(_) => unreachable!(),
316 Response::Done(_) => unreachable!(),
317 Response::ModelError(_, _) => unreachable!(),
318 Response::ImageGeneration(_) => unreachable!(),
319 Response::Raw { .. } => unreachable!(),
320 }
321 }
322}