use anyhow::Context;
use candle_core::{Device, Result, Tensor};
use either::Either;
use futures::future::join_all;
use mistralrs_core::*;
use std::sync::Arc;
use tokio::sync::mpsc::{channel, Receiver};

use crate::{EmbeddingRequest, EmbeddingRequestBuilder, RequestLike, TextMessages};

pub use mistralrs_core::{AddModelConfig, ModelStatus, Pipeline, SchedulerConfig};

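/// Select the best available device: CPU when `force_cpu` is set, otherwise
/// CUDA device 0 if available (falling back to CPU), or Metal device 0 when
/// the `metal` feature is enabled.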
pub fn best_device(force_cpu: bool) -> Result<Device> {
    if force_cpu {
        return Ok(Device::Cpu);
    }
    #[cfg(not(feature = "metal"))]
    {
        Device::cuda_if_available(0)
    }
    #[cfg(feature = "metal")]
    {
        Device::new_metal(0)
    }
}

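/// A handle to one or more loaded models, backed by a shared [`MistralRs`]
/// runner. All requests below are routed through this handle.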
pub struct Model {
    pub(crate) runner: Arc<MistralRs>,
}

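/// A stream of [`Response`] items produced by a streaming chat request; poll
/// it with [`Stream::next`] until it returns `None`.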
pub struct Stream<'a> {
    _server: &'a Model,
    rx: Receiver<Response>,
}

impl Stream<'_> {
    pub async fn next(&mut self) -> Option<Response> {
        self.rx.recv().await
    }
}

impl Model {
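    /// Wrap an existing [`MistralRs`] runner in a `Model` handle.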
    pub fn new(runner: Arc<MistralRs>) -> Self {
        Self { runner }
    }

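    /// Send a chat request to the default model and stream the responses back
    /// as they are generated.
    ///
    /// A minimal usage sketch (not compiled as a doctest); it assumes `model`
    /// and `messages` were built elsewhere with a chat-capable model loaded:
    ///
    /// ```ignore
    /// let mut stream = model.stream_chat_request(messages).await?;
    /// while let Some(resp) = stream.next().await {
    ///     if let Response::Chunk(chunk) = resp {
    ///         // Inspect `chunk.choices` for the incremental delta text.
    ///     }
    /// }
    /// ```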
    pub async fn stream_chat_request<R: RequestLike>(
        &self,
        request: R,
    ) -> anyhow::Result<Stream<'_>> {
        self.stream_chat_request_with_model(request, None).await
    }

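    /// Like [`Self::stream_chat_request`], but routed to a specific model in a
    /// multi-model setup; `None` targets the default model.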
    pub async fn stream_chat_request_with_model<R: RequestLike>(
        &self,
        mut request: R,
        model_id: Option<&str>,
    ) -> anyhow::Result<Stream<'_>> {
        let (tx, rx) = channel(1);

        let truncate_sequence = request.truncate_sequence();
        let (tools, tool_choice) = if let Some((a, b)) = request.take_tools() {
            (Some(a), Some(b))
        } else {
            (None, None)
        };
        let request = Request::Normal(Box::new(NormalRequest {
            messages: request.take_messages(),
            sampling_params: request.take_sampling_params(),
            response: tx,
            return_logprobs: request.return_logprobs(),
            is_streaming: true,
            id: 0,
            constraint: request.take_constraint(),
            suffix: None,
            tools,
            tool_choice,
            logits_processors: request.take_logits_processors(),
            return_raw_logits: false,
            web_search_options: request.take_web_search_options(),
            model_id: model_id.map(|s| s.to_string()),
            truncate_sequence,
        }));

        self.runner.get_sender(model_id)?.send(request).await?;

        let stream = Stream { _server: self, rx };

        Ok(stream)
    }

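    /// Send a chat request to the default model and wait for the complete
    /// response.
    ///
    /// A minimal usage sketch (not compiled as a doctest); it assumes a
    /// chat-capable model has already been loaded into `model`:
    ///
    /// ```ignore
    /// let messages = TextMessages::new()
    ///     .add_message(TextMessageRole::User, "Why is the sky blue?");
    /// let response = model.send_chat_request(messages).await?;
    /// // Inspect `response.choices` for the generated message.
    /// ```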
    pub async fn send_chat_request<R: RequestLike>(
        &self,
        request: R,
    ) -> anyhow::Result<ChatCompletionResponse> {
        self.send_chat_request_with_model(request, None).await
    }

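    /// Like [`Self::send_chat_request`], but routed to a specific model in a
    /// multi-model setup; `None` targets the default model.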
    pub async fn send_chat_request_with_model<R: RequestLike>(
        &self,
        mut request: R,
        model_id: Option<&str>,
    ) -> anyhow::Result<ChatCompletionResponse> {
        let (tx, mut rx) = channel(1);

        let truncate_sequence = request.truncate_sequence();
        let (tools, tool_choice) = if let Some((a, b)) = request.take_tools() {
            (Some(a), Some(b))
        } else {
            (None, None)
        };
        let request = Request::Normal(Box::new(NormalRequest {
            messages: request.take_messages(),
            sampling_params: request.take_sampling_params(),
            response: tx,
            return_logprobs: request.return_logprobs(),
            is_streaming: false,
            id: 0,
            constraint: request.take_constraint(),
            suffix: None,
            tools,
            tool_choice,
            logits_processors: request.take_logits_processors(),
            return_raw_logits: false,
            web_search_options: request.take_web_search_options(),
            model_id: model_id.map(|s| s.to_string()),
            truncate_sequence,
        }));

        self.runner.get_sender(model_id)?.send(request).await?;

        let ResponseOk::Done(response) = rx
            .recv()
            .await
            .context("Channel was erroneously closed!")?
            .as_result()?
        else {
            anyhow::bail!("Got unexpected response type.")
        };

        Ok(response)
    }

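    /// Send a chat request to the default model, returning the raw logits of
    /// each generation step together with the generated tokens instead of a
    /// decoded completion.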
    pub async fn send_raw_chat_request<R: RequestLike>(
        &self,
        request: R,
    ) -> anyhow::Result<(Vec<Tensor>, Vec<u32>)> {
        self.send_raw_chat_request_with_model(request, None).await
    }

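    /// Like [`Self::send_raw_chat_request`], but routed to a specific model in
    /// a multi-model setup; `None` targets the default model.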
    pub async fn send_raw_chat_request_with_model<R: RequestLike>(
        &self,
        mut request: R,
        model_id: Option<&str>,
    ) -> anyhow::Result<(Vec<Tensor>, Vec<u32>)> {
        let (tx, mut rx) = channel(1);

        let truncate_sequence = request.truncate_sequence();
        let (tools, tool_choice) = if let Some((a, b)) = request.take_tools() {
            (Some(a), Some(b))
        } else {
            (None, None)
        };
        let request = Request::Normal(Box::new(NormalRequest {
            messages: request.take_messages(),
            sampling_params: request.take_sampling_params(),
            response: tx,
            return_logprobs: request.return_logprobs(),
            is_streaming: false,
            id: 0,
            constraint: request.take_constraint(),
            suffix: None,
            tools,
            tool_choice,
            logits_processors: request.take_logits_processors(),
            return_raw_logits: true,
            web_search_options: request.take_web_search_options(),
            model_id: model_id.map(|s| s.to_string()),
            truncate_sequence,
        }));

        self.runner.get_sender(model_id)?.send(request).await?;

        let ResponseOk::Raw {
            logits_chunks,
            tokens,
        } = rx
            .recv()
            .await
            .context("Channel was erroneously closed!")?
            .as_result()?
        else {
            anyhow::bail!("Got unexpected response type.")
        };

        Ok((logits_chunks, tokens))
    }

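    /// Generate an image with a diffusion model using deterministic sampling,
    /// returning it in the requested `response_format`.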
    pub async fn generate_image(
        &self,
        prompt: impl ToString,
        response_format: ImageGenerationResponseFormat,
        generation_params: DiffusionGenerationParams,
    ) -> anyhow::Result<ImageGenerationResponse> {
        self.generate_image_with_model(prompt, response_format, generation_params, None)
            .await
    }

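    /// Like [`Self::generate_image`], but routed to a specific model in a
    /// multi-model setup; `None` targets the default model.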
    pub async fn generate_image_with_model(
        &self,
        prompt: impl ToString,
        response_format: ImageGenerationResponseFormat,
        generation_params: DiffusionGenerationParams,
        model_id: Option<&str>,
    ) -> anyhow::Result<ImageGenerationResponse> {
        let (tx, mut rx) = channel(1);

        let request = Request::Normal(Box::new(NormalRequest {
            id: 0,
            messages: RequestMessage::ImageGeneration {
                prompt: prompt.to_string(),
                format: response_format,
                generation_params,
            },
            sampling_params: SamplingParams::deterministic(),
            response: tx,
            return_logprobs: false,
            is_streaming: false,
            suffix: None,
            constraint: Constraint::None,
            tool_choice: None,
            tools: None,
            logits_processors: None,
            return_raw_logits: false,
            web_search_options: None,
            model_id: model_id.map(|s| s.to_string()),
            truncate_sequence: false,
        }));

        self.runner.get_sender(model_id)?.send(request).await?;

        let ResponseOk::ImageGeneration(response) = rx
            .recv()
            .await
            .context("Channel was erroneously closed!")?
            .as_result()?
        else {
            anyhow::bail!("Got unexpected response type.")
        };

        Ok(response)
    }

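    /// Generate speech from a prompt with a speech model, returning the PCM
    /// samples, the sampling rate, and the number of channels.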
    pub async fn generate_speech(
        &self,
        prompt: impl ToString,
    ) -> anyhow::Result<(Arc<Vec<f32>>, usize, usize)> {
        self.generate_speech_with_model(prompt, None).await
    }

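    /// Like [`Self::generate_speech`], but routed to a specific model in a
    /// multi-model setup; `None` targets the default model.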
    pub async fn generate_speech_with_model(
        &self,
        prompt: impl ToString,
        model_id: Option<&str>,
    ) -> anyhow::Result<(Arc<Vec<f32>>, usize, usize)> {
        let (tx, mut rx) = channel(1);

        let request = Request::Normal(Box::new(NormalRequest {
            id: 0,
            messages: RequestMessage::SpeechGeneration {
                prompt: prompt.to_string(),
            },
            sampling_params: SamplingParams::deterministic(),
            response: tx,
            return_logprobs: false,
            is_streaming: false,
            suffix: None,
            constraint: Constraint::None,
            tool_choice: None,
            tools: None,
            logits_processors: None,
            return_raw_logits: false,
            web_search_options: None,
            model_id: model_id.map(|s| s.to_string()),
            truncate_sequence: false,
        }));

        self.runner.get_sender(model_id)?.send(request).await?;

        let ResponseOk::Speech {
            pcm,
            rate,
            channels,
        } = rx
            .recv()
            .await
            .context("Channel was erroneously closed!")?
            .as_result()?
        else {
            anyhow::bail!("Got unexpected response type.")
        };

        Ok((pcm, rate, channels))
    }

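    /// Generate embeddings for every input in the request. The inputs are
    /// dispatched concurrently and the embeddings are returned in input order.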
    pub async fn generate_embeddings(
        &self,
        request: EmbeddingRequestBuilder,
    ) -> anyhow::Result<Vec<Vec<f32>>> {
        self.generate_embeddings_with_model(request, None).await
    }

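    /// Like [`Self::generate_embeddings`], but routed to a specific model in a
    /// multi-model setup; `None` targets the default model.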
    pub async fn generate_embeddings_with_model(
        &self,
        request: EmbeddingRequestBuilder,
        model_id: Option<&str>,
    ) -> anyhow::Result<Vec<Vec<f32>>> {
        let request = request.build()?;
        let EmbeddingRequest {
            inputs,
            truncate_sequence,
        } = request;

        let runner = self.runner.clone();
        let model_id_owned = model_id.map(|s| s.to_string());
        let futures = inputs.into_iter().map(|input| {
            let runner = runner.clone();
            let model_id_owned = model_id_owned.clone();
            async move {
                let message = input.into_request_message();
                let (tx, mut rx) = channel(1);

                let request = Request::Normal(Box::new(NormalRequest {
                    id: 0,
                    messages: message,
                    sampling_params: SamplingParams::deterministic(),
                    response: tx,
                    return_logprobs: false,
                    is_streaming: false,
                    suffix: None,
                    constraint: Constraint::None,
                    tool_choice: None,
                    tools: None,
                    logits_processors: None,
                    return_raw_logits: false,
                    web_search_options: None,
                    model_id: model_id_owned.clone(),
                    truncate_sequence,
                }));

                runner
                    .get_sender(model_id_owned.as_deref())?
                    .send(request)
                    .await
                    .map_err(|e| anyhow::anyhow!(e.to_string()))?;

                let ResponseOk::Embeddings { embeddings, .. } = rx
                    .recv()
                    .await
                    .context("Channel was erroneously closed!")?
                    .as_result()?
                else {
                    anyhow::bail!("Got unexpected response type.")
                };

                Ok::<Vec<f32>, anyhow::Error>(embeddings)
            }
        });

        let results = join_all(futures).await;
        let mut embeddings = Vec::with_capacity(results.len());
        for result in results {
            embeddings.push(result?);
        }
        Ok(embeddings)
    }

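    /// Convenience wrapper around [`Self::generate_embeddings`] for a single
    /// prompt.
    ///
    /// A minimal usage sketch (not compiled as a doctest); it assumes an
    /// embedding model has already been loaded into `model`:
    ///
    /// ```ignore
    /// let embedding: Vec<f32> = model.generate_embedding("Hello, world!").await?;
    /// println!("dimension: {}", embedding.len());
    /// ```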
    pub async fn generate_embedding(&self, prompt: impl ToString) -> anyhow::Result<Vec<f32>> {
        self.generate_embedding_with_model(prompt, None).await
    }

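    /// Like [`Self::generate_embedding`], but routed to a specific model in a
    /// multi-model setup; `None` targets the default model.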
    pub async fn generate_embedding_with_model(
        &self,
        prompt: impl ToString,
        model_id: Option<&str>,
    ) -> anyhow::Result<Vec<f32>> {
        let mut embeddings = self
            .generate_embeddings_with_model(
                EmbeddingRequest::builder().add_prompt(prompt.to_string()),
                model_id,
            )
            .await?;

        Ok(embeddings
            .pop()
            .expect("EmbeddingRequestBuilder should guarantee at least one input"))
    }

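    /// Re-apply in-situ quantization (ISQ) to the default model at the given
    /// quantization type.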
    pub async fn re_isq_model(&self, isq_type: IsqType) -> anyhow::Result<()> {
        self.re_isq_model_with_model(isq_type, None).await
    }

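    /// Like [`Self::re_isq_model`], but routed to a specific model in a
    /// multi-model setup; `None` targets the default model.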
    pub async fn re_isq_model_with_model(
        &self,
        isq_type: IsqType,
        model_id: Option<&str>,
    ) -> anyhow::Result<()> {
        let request = Request::ReIsq(isq_type);

        Ok(self.runner.get_sender(model_id)?.send(request).await?)
    }

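    /// Tokenize chat messages or raw text. For messages, the chat template is
    /// applied; `add_special_tokens`, `add_generation_prompt`, and any `tools`
    /// control how the prompt is rendered.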
    pub async fn tokenize(
        &self,
        text: Either<TextMessages, String>,
        tools: Option<Vec<Tool>>,
        add_special_tokens: bool,
        add_generation_prompt: bool,
        enable_thinking: Option<bool>,
    ) -> anyhow::Result<Vec<u32>> {
        self.tokenize_with_model(
            text,
            tools,
            add_special_tokens,
            add_generation_prompt,
            enable_thinking,
            None,
        )
        .await
    }

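    /// Like [`Self::tokenize`], but routed to a specific model in a
    /// multi-model setup; `None` targets the default model.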
    pub async fn tokenize_with_model(
        &self,
        text: Either<TextMessages, String>,
        tools: Option<Vec<Tool>>,
        add_special_tokens: bool,
        add_generation_prompt: bool,
        enable_thinking: Option<bool>,
        model_id: Option<&str>,
    ) -> anyhow::Result<Vec<u32>> {
        let (tx, mut rx) = channel(1);
        let request = Request::Tokenize(TokenizationRequest {
            text: text.map_left(Into::into),
            tools,
            add_special_tokens,
            add_generation_prompt,
            response: tx,
            enable_thinking,
            reasoning_effort: None,
        });
        self.runner.get_sender(model_id)?.send(request).await?;

        rx.recv().await.context("Channel was erroneously closed!")?
    }

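    /// Decode tokens back into a string, optionally skipping special tokens.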
    pub async fn detokenize(
        &self,
        tokens: Vec<u32>,
        skip_special_tokens: bool,
    ) -> anyhow::Result<String> {
        self.detokenize_with_model(tokens, skip_special_tokens, None)
            .await
    }

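    /// Like [`Self::detokenize`], but routed to a specific model in a
    /// multi-model setup; `None` targets the default model.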
    pub async fn detokenize_with_model(
        &self,
        tokens: Vec<u32>,
        skip_special_tokens: bool,
        model_id: Option<&str>,
    ) -> anyhow::Result<String> {
        let (tx, mut rx) = channel(1);
        let request = Request::Detokenize(DetokenizationRequest {
            tokens,
            skip_special_tokens,
            response: tx,
        });
        self.runner.get_sender(model_id)?.send(request).await?;

        rx.recv().await.context("Channel was erroneously closed!")?
    }

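    /// Retrieve the configuration of the default model.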
    pub fn config(&self) -> std::result::Result<MistralRsConfig, String> {
        self.config_with_model(None)
    }

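    /// Like [`Self::config`], but for a specific model in a multi-model setup;
    /// `None` targets the default model.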
    pub fn config_with_model(
        &self,
        model_id: Option<&str>,
    ) -> std::result::Result<MistralRsConfig, String> {
        self.runner.config(model_id)
    }

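    /// Get the maximum supported sequence length of the default model, if
    /// known.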
    pub fn max_sequence_length(&self) -> std::result::Result<Option<usize>, MistralRsError> {
        self.max_sequence_length_with_model(None)
    }

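    /// Like [`Self::max_sequence_length`], but for a specific model in a
    /// multi-model setup; `None` targets the default model.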
    pub fn max_sequence_length_with_model(
        &self,
        model_id: Option<&str>,
    ) -> std::result::Result<Option<usize>, MistralRsError> {
        self.runner.max_sequence_length(model_id)
    }

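    /// List the IDs of all models registered with this runner.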
    pub fn list_models(&self) -> std::result::Result<Vec<String>, String> {
        self.runner.list_models()
    }

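    /// Get the ID of the default model, if one is set.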
    pub fn get_default_model_id(&self) -> std::result::Result<Option<String>, String> {
        self.runner.get_default_model_id()
    }

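    /// Set the default model used when no explicit model ID is given.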
    pub fn set_default_model_id(&self, model_id: &str) -> std::result::Result<(), String> {
        self.runner.set_default_model_id(model_id)
    }

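    /// Register an additional model (pipeline plus scheduler configuration)
    /// under the given ID at runtime.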
    pub async fn add_model(
        &self,
        model_id: String,
        pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
        method: SchedulerConfig,
        config: AddModelConfig,
    ) -> std::result::Result<(), String> {
        self.runner
            .add_model(model_id, pipeline, method, config)
            .await
    }

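    /// Remove a model from the runner by ID.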
    pub fn remove_model(&self, model_id: &str) -> std::result::Result<(), String> {
        self.runner.remove_model(model_id)
    }

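    /// Unload a model from memory; it can be brought back later with
    /// [`Self::reload_model`].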
    pub fn unload_model(&self, model_id: &str) -> std::result::Result<(), MistralRsError> {
        self.runner.unload_model(model_id)
    }

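    /// Reload a previously unloaded model.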
    pub async fn reload_model(&self, model_id: &str) -> std::result::Result<(), MistralRsError> {
        self.runner.reload_model(model_id).await
    }

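    /// Check whether the given model is currently loaded.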
    pub fn is_model_loaded(&self, model_id: &str) -> std::result::Result<bool, MistralRsError> {
        self.runner.is_model_loaded(model_id)
    }

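    /// List all registered models together with their load status.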
    pub fn list_models_with_status(
        &self,
    ) -> std::result::Result<Vec<(String, ModelStatus)>, MistralRsError> {
        self.runner.list_models_with_status()
    }

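    /// Access the underlying [`MistralRs`] runner directly.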
    pub fn inner(&self) -> &MistralRs {
        &self.runner
    }
}