use std::{
    pin::Pin,
    sync::Arc,
    task::{Context, Poll},
    time::Duration,
};

use anyhow::Result;
use axum::{
    extract::{Json, State},
    http,
    response::{
        sse::{Event, KeepAlive, KeepAliveStream},
        IntoResponse, Sse,
    },
};
use mistralrs_core::{
    CompletionChunkResponse, CompletionResponse, Constraint, MistralRs, NormalRequest, Request,
    RequestMessage, Response, SamplingParams,
};
use tokio::sync::mpsc::{Receiver, Sender};
use tracing::warn;

use crate::{
    completion_core::{
        convert_stop_tokens, get_dry_sampling_params, handle_completion_error,
        BaseCompletionResponder,
    },
    handler_core::{
        base_process_non_streaming_response, create_response_channel, send_request,
        BaseJsonModelError, ErrorToResponse, JsonError, ModelErrorMessage,
    },
    openai::{CompletionRequest, Grammar},
    streaming::{base_create_streamer, get_keep_alive_interval, BaseStreamer, DoneState},
    types::{ExtractedMistralRsState, OnChunkCallback, OnDoneCallback, SharedMistralRsState},
    util::{sanitize_error_message, validate_model_name},
};

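/// Callback invoked for each streamed completion chunk. It receives the chunk
/// and may modify it before the chunk is forwarded to the client.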
pub type CompletionOnChunkCallback = OnChunkCallback<CompletionChunkResponse>;

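/// Callback invoked once the stream finishes; it receives all chunks that were
/// collected while `store_chunks` was enabled.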
pub type CompletionOnDoneCallback = OnDoneCallback<CompletionChunkResponse>;

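/// Streamer for completion requests. It implements [`futures::Stream`],
/// yielding SSE [`Event`]s, with optional per-chunk and on-done callbacks.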
pub type CompletionStreamer =
    BaseStreamer<CompletionChunkResponse, CompletionOnChunkCallback, CompletionOnDoneCallback>;

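// Polls responses from the model engine and converts them into SSE events.
// A small `DoneState` state machine ensures a final `[DONE]` message is sent
// after the last chunk, mirroring the OpenAI streaming protocol.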
impl futures::Stream for CompletionStreamer {
    type Item = Result<Event, axum::Error>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        match self.done_state {
            DoneState::SendingDone => {
                self.done_state = DoneState::Done;
                return Poll::Ready(Some(Ok(Event::default().data("[DONE]"))));
            }
            DoneState::Done => {
                if let Some(on_done) = &self.on_done {
                    on_done(&self.chunks);
                }
                return Poll::Ready(None);
            }
            DoneState::Running => (),
        }

        match self.rx.poll_recv(cx) {
            Poll::Ready(Some(resp)) => match resp {
                Response::CompletionModelError(msg, _) => {
                    MistralRs::maybe_log_error(
                        self.state.clone(),
                        &ModelErrorMessage(msg.to_string()),
                    );
                    self.done_state = DoneState::SendingDone;
                    Poll::Ready(Some(Ok(Event::default().data(msg))))
                }
                Response::ValidationError(e) => Poll::Ready(Some(Ok(
                    Event::default().data(sanitize_error_message(e.as_ref()))
                ))),
                Response::InternalError(e) => {
                    MistralRs::maybe_log_error(self.state.clone(), &*e);
                    Poll::Ready(Some(Ok(
                        Event::default().data(sanitize_error_message(e.as_ref()))
                    )))
                }
                Response::CompletionChunk(mut response) => {
                    if response.choices.iter().all(|x| x.finish_reason.is_some()) {
                        self.done_state = DoneState::SendingDone;
                    }
                    MistralRs::maybe_log_response(self.state.clone(), &response);

                    if let Some(on_chunk) = &self.on_chunk {
                        response = on_chunk(response);
                    }

                    if self.store_chunks {
                        self.chunks.push(response.clone());
                    }

                    Poll::Ready(Some(Event::default().json_data(response)))
                }
                Response::Done(_) => unreachable!(),
                Response::CompletionDone(_) => unreachable!(),
                Response::Chunk(_) => unreachable!(),
                Response::ImageGeneration(_) => unreachable!(),
                Response::ModelError(_, _) => unreachable!(),
                Response::Speech { .. } => unreachable!(),
                Response::Raw { .. } => unreachable!(),
                Response::Embeddings { .. } => unreachable!(),
            },
            Poll::Pending | Poll::Ready(None) => Poll::Pending,
        }
    }
}

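/// Responder for completion requests: either an SSE stream (with keep-alive)
/// for streaming requests, a JSON body for non-streaming ones, or an error.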
pub type CompletionResponder =
    BaseCompletionResponder<CompletionResponse, KeepAliveStream<CompletionStreamer>>;

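/// JSON error payload that also carries the completion response produced
/// alongside the model error.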
type JsonModelError = BaseJsonModelError<CompletionResponse>;

impl ErrorToResponse for JsonModelError {}

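// Maps each responder variant to an HTTP response: SSE and JSON succeed with
// 200, validation errors yield 422, and internal/model errors yield 500.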
impl IntoResponse for CompletionResponder {
    fn into_response(self) -> axum::response::Response {
        match self {
            CompletionResponder::Sse(s) => s.into_response(),
            CompletionResponder::Json(s) => Json(s).into_response(),
            CompletionResponder::InternalError(e) => {
                JsonError::new(sanitize_error_message(e.as_ref()))
                    .to_response(http::StatusCode::INTERNAL_SERVER_ERROR)
            }
            CompletionResponder::ValidationError(e) => {
                JsonError::new(sanitize_error_message(e.as_ref()))
                    .to_response(http::StatusCode::UNPROCESSABLE_ENTITY)
            }
            CompletionResponder::ModelError(msg, response) => JsonModelError::new(msg, response)
                .to_response(http::StatusCode::INTERNAL_SERVER_ERROR),
        }
    }
}

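/// Parses and validates an OpenAI-compatible completion request, converting it
/// into an internal [`Request`] that can be dispatched to the engine.
///
/// Returns the request together with a flag indicating whether the client
/// asked for a streaming (SSE) response.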
pub fn parse_request(
    oairequest: CompletionRequest,
    state: Arc<MistralRs>,
    tx: Sender<Response>,
) -> Result<(Request, bool)> {
    let repr = serde_json::to_string(&oairequest).expect("Serialization of request failed.");
    MistralRs::maybe_log_request(state.clone(), repr);

    validate_model_name(&oairequest.model, state.clone())?;

    let stop_toks = convert_stop_tokens(oairequest.stop_seqs);

    if oairequest.logprobs.is_some() {
        warn!("Completion requests do not support logprobs.");
    }

    let is_streaming = oairequest.stream.unwrap_or(false);

    let dry_params = get_dry_sampling_params(
        oairequest.dry_multiplier,
        oairequest.dry_sequence_breakers,
        oairequest.dry_base,
        oairequest.dry_allowed_length,
    )?;

    Ok((
        Request::Normal(Box::new(NormalRequest {
            id: state.next_request_id(),
            messages: RequestMessage::Completion {
                text: oairequest.prompt,
                echo_prompt: oairequest.echo_prompt,
                best_of: oairequest.best_of,
            },
            sampling_params: SamplingParams {
                temperature: oairequest.temperature,
                top_k: oairequest.top_k,
                top_p: oairequest.top_p,
                min_p: oairequest.min_p,
                top_n_logprobs: 1,
                frequency_penalty: oairequest.frequency_penalty,
                presence_penalty: oairequest.presence_penalty,
                repetition_penalty: oairequest.repetition_penalty,
                max_len: oairequest.max_tokens,
                stop_toks,
                logits_bias: oairequest.logit_bias,
                n_choices: oairequest.n_choices,
                dry_params,
            },
            response: tx,
            return_logprobs: false,
            is_streaming,
            suffix: oairequest.suffix,
            constraint: match oairequest.grammar {
                Some(Grammar::Regex(regex)) => Constraint::Regex(regex),
                Some(Grammar::Lark(lark)) => Constraint::Lark(lark),
                Some(Grammar::JsonSchema(schema)) => Constraint::JsonSchema(schema),
                Some(Grammar::Llguidance(llguidance)) => Constraint::Llguidance(llguidance),
                None => Constraint::None,
            },
            tool_choice: oairequest.tool_choice,
            tools: oairequest.tools,
            logits_processors: None,
            return_raw_logits: false,
            web_search_options: None,
            model_id: if oairequest.model == "default" {
                None
            } else {
                Some(oairequest.model.clone())
            },
            truncate_sequence: oairequest.truncate_sequence.unwrap_or(false),
        })),
        is_streaming,
    ))
}

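/// OpenAI-compatible completions endpoint (`POST /v1/completions`).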
#[utoipa::path(
    post,
    tag = "Mistral.rs",
    path = "/v1/completions",
    request_body = CompletionRequest,
    responses((status = 200, description = "Completions"))
)]
pub async fn completions(
    State(state): ExtractedMistralRsState,
    Json(oairequest): Json<CompletionRequest>,
) -> CompletionResponder {
    let (tx, mut rx) = create_response_channel(None);

    let (request, is_streaming) = match parse_request(oairequest, state.clone(), tx) {
        Ok(x) => x,
        Err(e) => return handle_error(state, e.into()),
    };

    if let Err(e) = send_request(&state, request).await {
        return handle_error(state, e.into());
    }

    if is_streaming {
        CompletionResponder::Sse(create_streamer(rx, state, None, None))
    } else {
        process_non_streaming_response(&mut rx, state).await
    }
}

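/// Converts an error into a [`CompletionResponder`] via the shared completion
/// error handler.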
pub fn handle_error(
    state: SharedMistralRsState,
    e: Box<dyn std::error::Error + Send + Sync + 'static>,
) -> CompletionResponder {
    handle_completion_error(state, e)
}

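/// Creates an SSE streamer for a completion request, with optional callbacks
/// and the configured keep-alive interval.
///
/// A sketch of wiring up callbacks, assuming the callback aliases are boxed
/// closures (see [`CompletionOnChunkCallback`] / [`CompletionOnDoneCallback`]);
/// the logging shown is illustrative, not part of the API:
///
/// ```ignore
/// let (tx, rx) = create_response_channel(None);
/// let on_chunk: CompletionOnChunkCallback = Box::new(|chunk| {
///     // Inspect or rewrite each chunk before it is sent to the client.
///     chunk
/// });
/// let on_done: CompletionOnDoneCallback = Box::new(|chunks| {
///     tracing::info!("stream finished after {} chunks", chunks.len());
/// });
/// let sse = create_streamer(rx, state, Some(on_chunk), Some(on_done));
/// ```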
pub fn create_streamer(
    rx: Receiver<Response>,
    state: SharedMistralRsState,
    on_chunk: Option<CompletionOnChunkCallback>,
    on_done: Option<CompletionOnDoneCallback>,
) -> Sse<KeepAliveStream<CompletionStreamer>> {
    let streamer = base_create_streamer(rx, state, on_chunk, on_done);
    let keep_alive_interval = get_keep_alive_interval();

    Sse::new(streamer)
        .keep_alive(KeepAlive::new().interval(Duration::from_millis(keep_alive_interval)))
}

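/// Awaits a single response on `rx` and converts it into a JSON
/// [`CompletionResponder`] for non-streaming requests.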
pub async fn process_non_streaming_response(
    rx: &mut Receiver<Response>,
    state: SharedMistralRsState,
) -> CompletionResponder {
    base_process_non_streaming_response(rx, state, match_responses, handle_error).await
}

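/// Matches a model [`Response`] against the variants that are valid for a
/// non-streaming completion request and converts it into a responder;
/// streaming and non-completion variants are unreachable here.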
pub fn match_responses(state: SharedMistralRsState, response: Response) -> CompletionResponder {
    match response {
        Response::InternalError(e) => {
            MistralRs::maybe_log_error(state, &*e);
            CompletionResponder::InternalError(e)
        }
        Response::CompletionModelError(msg, response) => {
            MistralRs::maybe_log_error(state.clone(), &ModelErrorMessage(msg.to_string()));
            MistralRs::maybe_log_response(state, &response);
            CompletionResponder::ModelError(msg, response)
        }
        Response::ValidationError(e) => CompletionResponder::ValidationError(e),
        Response::CompletionDone(response) => {
            MistralRs::maybe_log_response(state, &response);
            CompletionResponder::Json(response)
        }
        Response::CompletionChunk(_) => unreachable!(),
        Response::Chunk(_) => unreachable!(),
        Response::Done(_) => unreachable!(),
        Response::ModelError(_, _) => unreachable!(),
        Response::ImageGeneration(_) => unreachable!(),
        Response::Speech { .. } => unreachable!(),
        Response::Raw { .. } => unreachable!(),
        Response::Embeddings { .. } => unreachable!(),
    }
}