1use directories::ProjectDirs;
2use either::Either;
3use indexmap::IndexMap;
4use mistralrs_core::{
5 speech_utils, ChunkChoice, Constraint, Delta, DiffusionGenerationParams, DrySamplingParams,
6 ImageGenerationResponseFormat, MessageContent, MistralRs, ModelCategory, NormalRequest,
7 Request, RequestMessage, Response, ResponseOk, SamplingParams, WebSearchOptions,
8 TERMINATE_ALL_NEXT_STEP,
9};
10use once_cell::sync::Lazy;
11use regex::Regex;
12use rustyline::{error::ReadlineError, history::History, DefaultEditor, Editor, Helper};
13use serde_json::Value;
14use std::{
15 fs,
16 io::{self, Write},
17 path::PathBuf,
18 sync::{atomic::Ordering, Arc, Mutex},
19 time::Instant,
20};
21use tokio::sync::mpsc::channel;
22use tracing::{error, info};
23
24use mistralrs_server_core::util;
25
/// CTRL-C action installed while the REPL is waiting at the prompt:
/// terminate the whole process immediately with a success status.
fn exit_handler() {
    std::process::exit(0);
}
29
/// CTRL-C action installed while a request is streaming: request that the
/// engine stop all in-flight generation at its next step instead of
/// killing the process, so the REPL can continue.
fn terminate_handler() {
    TERMINATE_ALL_NEXT_STEP.store(true, Ordering::SeqCst);
}
33
34fn history_file_path() -> PathBuf {
35 let proj_dirs = ProjectDirs::from("com", "", "mistral.rs")
37 .expect("Could not determine project directories");
38 let config_dir = proj_dirs.config_dir();
39
40 fs::create_dir_all(config_dir).expect("Failed to create config directory");
42
43 config_dir.join("history.txt")
45}
46
47fn read_line<H: Helper, I: History>(editor: &mut Editor<H, I>) -> String {
48 let r = editor.readline("> ");
49 match r {
50 Err(ReadlineError::Interrupted) => {
51 editor.save_history(&history_file_path()).unwrap();
52 std::process::exit(0);
54 }
55
56 Err(ReadlineError::Eof) => {
57 editor.save_history(&history_file_path()).unwrap();
58 std::process::exit(0);
60 }
61
62 Err(e) => {
63 editor.save_history(&history_file_path()).unwrap();
64 eprintln!("Error reading input: {e:?}");
65 std::process::exit(1);
66 }
67 Ok(prompt) => {
68 editor.add_history_entry(prompt.clone()).unwrap();
69 prompt
70 }
71 }
72}
73
/// Current CTRL-C action. The interactive loops swap this between
/// `exit_handler` (while waiting at the prompt) and `terminate_handler`
/// (while a request is streaming), since `ctrlc::set_handler` may only be
/// installed once per process.
static CTRLC_HANDLER: Lazy<Mutex<&'static (dyn Fn() + Sync)>> =
    Lazy::new(|| Mutex::new(&exit_handler));
76
77pub async fn interactive_mode(
78 mistralrs: Arc<MistralRs>,
79 do_search: bool,
80 enable_thinking: Option<bool>,
81) {
82 match mistralrs.get_model_category(None) {
83 Ok(ModelCategory::Text) => {
84 text_interactive_mode(mistralrs, do_search, enable_thinking).await
85 }
86 Ok(ModelCategory::Vision { .. }) => {
87 vision_interactive_mode(mistralrs, do_search, enable_thinking).await
88 }
89 Ok(ModelCategory::Diffusion) => diffusion_interactive_mode(mistralrs, do_search).await,
90 Ok(ModelCategory::Audio) => {
91 audio_interactive_mode(mistralrs, do_search, enable_thinking).await
92 }
93 Ok(ModelCategory::Speech) => speech_interactive_mode(mistralrs, do_search).await,
94 Err(e) => eprintln!("Error getting model category: {e}"),
95 }
96}
97
// Command reference appended to the text/vision help banners below.
const COMMAND_COMMANDS: &str = r#"
Commands:
- `\help`: Display this message.
- `\exit`: Quit interactive mode.
- `\system <system message here>`:
    Add a system message to the chat without running the model.
    Ex: `\system Always respond as a pirate.`
- `\clear`: Clear the chat history.
- `\temperature <float>`: Set sampling temperature (0.0 to 2.0).
- `\topk <int>`: Set top-k sampling value (>0).
- `\topp <float>`: Set top-p sampling value in (0.0 to 1.0).
"#;

// Banner printed on entry to text-model interactive mode.
const TEXT_INTERACTIVE_HELP: &str = r#"
Welcome to interactive mode! Because this model is a text model, you can enter prompts and chat with the model.
"#;

// Banner printed on entry to vision-model interactive mode.
const VISION_INTERACTIVE_HELP: &str = r#"
Welcome to interactive mode! Because this model is a vision model, you can enter prompts and chat with the model.

To specify a message with one or more images or audios, simply include the image/audio URL or path:

- `Describe these images: path/to/image1.jpg path/to/image2.png`
- `Describe the image and transcribe the audio: path/to/image1.jpg path/to/sound.mp3`
"#;

// Banner printed on entry to diffusion-model interactive mode (only
// `\help` and `\exit` are supported there).
const DIFFUSION_INTERACTIVE_HELP: &str = r#"
Welcome to interactive mode! Because this model is a diffusion model, you can enter prompts and the model will generate an image.

Commands:
- `\help`: Display this message.
- `\exit`: Quit interactive mode.
"#;

// Banner printed on entry to speech-model interactive mode (only
// `\help` and `\exit` are supported there).
const SPEECH_INTERACTIVE_HELP: &str = r#"
Welcome to interactive mode! Because this model is a speech generation model, you can enter prompts and the model will generate audio.

Commands:
- `\help`: Display this message.
- `\exit`: Quit interactive mode.
"#;

// REPL command tokens; each is entered with a leading backslash.
const HELP_CMD: &str = "\\help";
const EXIT_CMD: &str = "\\exit";
const SYSTEM_CMD: &str = "\\system";
const CLEAR_CMD: &str = "\\clear";
const TEMPERATURE_CMD: &str = "\\temperature";
const TOPK_CMD: &str = "\\topk";
const TOPP_CMD: &str = "\\topp";

// Matches an image reference in a prompt: optional http(s)/file scheme,
// a non-greedy path ending in a raster-image extension, and an optional
// query string. Capture group 1 is the full reference.
const IMAGE_REGEX: &str = r#"((?:https?://|file://)?\S+?\.(?:png|jpe?g|bmp|gif|webp)(?:\?\S+?)?)"#;
// Same shape as IMAGE_REGEX but for common audio file extensions.
const AUDIO_REGEX: &str = r#"((?:https?://|file://)?\S+?\.(?:wav|mp3|flac|ogg)(?:\?\S+?)?)"#;
151
152fn interactive_sample_parameters() -> SamplingParams {
153 SamplingParams {
154 temperature: Some(0.1),
155 top_k: Some(32),
156 top_p: Some(0.1),
157 min_p: Some(0.05),
158 top_n_logprobs: 0,
159 frequency_penalty: Some(0.1),
160 presence_penalty: Some(0.1),
161 max_len: None,
162 stop_toks: None,
163 logits_bias: None,
164 n_choices: 1,
165 dry_params: Some(DrySamplingParams::default()),
166 }
167}
168
169fn handle_sampling_command(prompt: &str, sampling_params: &mut SamplingParams) -> bool {
172 let trimmed = prompt.trim();
173 if trimmed.starts_with(TEMPERATURE_CMD) {
174 let parts: Vec<&str> = trimmed.splitn(2, ' ').collect();
175 if let [_, value] = parts.as_slice() {
176 match value.trim().parse::<f64>() {
177 Ok(v) if v > 0.0 && v <= 2.0 => {
178 sampling_params.temperature = Some(v);
179 info!("Set temperature to {v}");
180 }
181 Ok(_) => {
182 println!("Error: temperature must be in (0.0, 2.0]");
183 }
184 Err(_) => println!("Error: format is `{TEMPERATURE_CMD} <float>`"),
185 }
186 } else {
187 println!("Error: format is `{TEMPERATURE_CMD} <float>`");
188 }
189 return true;
190 }
191 if trimmed.starts_with(TOPK_CMD) {
192 let parts: Vec<&str> = trimmed.splitn(2, ' ').collect();
193 if let [_, value] = parts.as_slice() {
194 match value.trim().parse::<usize>() {
195 Ok(v) if v > 0 => {
196 sampling_params.top_k = Some(v);
197 info!("Set top-k to {v}");
198 }
199 Ok(_) => {
200 println!("Error: top-k must be a positive integer");
201 }
202 Err(_) => println!("Error: format is `{TOPK_CMD} <int>`"),
203 }
204 } else {
205 println!("Error: format is `{TOPK_CMD} <int>`");
206 }
207 return true;
208 }
209 if trimmed.starts_with(TOPP_CMD) {
210 let parts: Vec<&str> = trimmed.splitn(2, ' ').collect();
211 if let [_, value] = parts.as_slice() {
212 match value.trim().parse::<f64>() {
213 Ok(v) if v > 0.0 && v <= 1.0 => {
214 sampling_params.top_p = Some(v);
215 info!("Set top-p to {v}");
216 }
217 Ok(_) => {
218 println!("Error: top-p must be in (0.0, 1.0]");
219 }
220 Err(_) => println!("Error: format is `{TOPP_CMD} <float>`"),
221 }
222 } else {
223 println!("Error: format is `{TOPP_CMD} <float>`");
224 }
225 return true;
226 }
227 false
228}
229
/// Interactive REPL for pure text models.
///
/// Keeps the whole chat transcript in `messages` (role/content maps) and
/// resends it with every request. Each model response is streamed to
/// stdout as it arrives, then appended to the transcript; token/latency
/// stats are printed after each turn. Loops until `\exit`, EOF, or a
/// response error, persisting the readline history on exit.
async fn text_interactive_mode(
    mistralrs: Arc<MistralRs>,
    do_search: bool,
    enable_thinking: Option<bool>,
) {
    let sender = mistralrs.get_sender(None).unwrap();
    // Running transcript: one IndexMap per message with "role"/"content".
    let mut messages: Vec<IndexMap<String, MessageContent>> = Vec::new();

    let mut sampling_params = interactive_sample_parameters();

    info!("Starting interactive loop with sampling params: {sampling_params:?}");
    println!(
        "{}{TEXT_INTERACTIVE_HELP}{COMMAND_COMMANDS}{}",
        "=".repeat(20),
        "=".repeat(20)
    );

    // While waiting at the prompt, CTRL-C exits the program.
    *CTRLC_HANDLER.lock().unwrap() = &exit_handler;

    ctrlc::set_handler(move || CTRLC_HANDLER.lock().unwrap()())
        .expect("Failed to set CTRL-C handler for interactive mode");

    let mut rl = DefaultEditor::new().expect("Failed to open input");
    // Missing history file is fine on first run; ignore load errors.
    let _ = rl.load_history(&history_file_path());
    'outer: loop {
        // Back at the prompt: CTRL-C should exit again.
        *CTRLC_HANDLER.lock().unwrap() = &exit_handler;

        let prompt = read_line(&mut rl);

        let prompt_trimmed = prompt.as_str().trim();
        if prompt_trimmed.is_empty() {
            continue;
        }
        // Sampling commands (\temperature, \topk, \topp) are consumed here.
        if handle_sampling_command(prompt_trimmed, &mut sampling_params) {
            continue;
        }
        match prompt_trimmed {
            HELP_CMD => {
                println!(
                    "{}{TEXT_INTERACTIVE_HELP}{COMMAND_COMMANDS}{}",
                    "=".repeat(20),
                    "=".repeat(20)
                );
                continue;
            }
            EXIT_CMD => {
                break;
            }
            CLEAR_CMD => {
                messages.clear();
                info!("Cleared chat history.");
                continue;
            }
            // `\system <msg>` appends a system message without running the model.
            _ if prompt_trimmed.starts_with(SYSTEM_CMD) => {
                let parsed = match &prompt_trimmed.split(SYSTEM_CMD).collect::<Vec<_>>()[..] {
                    &["", a] => a.trim(),
                    _ => {
                        println!("Error: Setting the system command should be done with this format: `{SYSTEM_CMD} This is a system message.`");
                        continue;
                    }
                };
                info!("Set system message to `{parsed}`.");
                let mut user_message: IndexMap<String, MessageContent> = IndexMap::new();
                user_message.insert("role".to_string(), Either::Left("system".to_string()));
                user_message.insert("content".to_string(), Either::Left(parsed.to_string()));
                messages.push(user_message);
                continue;
            }
            // Anything else is a regular user message.
            message => {
                let mut user_message: IndexMap<String, MessageContent> = IndexMap::new();
                user_message.insert("role".to_string(), Either::Left("user".to_string()));
                user_message.insert("content".to_string(), Either::Left(message.to_string()));
                messages.push(user_message);
            }
        }

        // While streaming, CTRL-C cancels generation instead of exiting.
        *CTRLC_HANDLER.lock().unwrap() = &terminate_handler;

        let request_messages = RequestMessage::Chat {
            messages: messages.clone(),
            enable_thinking,
        };

        let (tx, mut rx) = channel(10_000);
        let req = Request::Normal(Box::new(NormalRequest {
            id: mistralrs.next_request_id(),
            messages: request_messages,
            sampling_params: sampling_params.clone(),
            response: tx,
            return_logprobs: false,
            is_streaming: true,
            constraint: Constraint::None,
            suffix: None,
            tool_choice: None,
            tools: None,
            logits_processors: None,
            return_raw_logits: false,
            web_search_options: do_search.then(WebSearchOptions::default),
            model_id: None,
        }));
        sender.send(req).await.unwrap();
        let start_ttft = Instant::now();
        let mut first_token_duration: Option<std::time::Duration> = None;

        let mut assistant_output = String::new();

        let mut last_usage = None;
        // Stream chunks until a finish reason arrives or an error aborts
        // the whole REPL ('outer).
        while let Some(resp) = rx.recv().await {
            match resp {
                Response::Chunk(chunk) => {
                    last_usage = chunk.usage.clone();
                    if let ChunkChoice {
                        delta:
                            Delta {
                                content: Some(content),
                                ..
                            },
                        finish_reason,
                        ..
                    } = &chunk.choices[0]
                    {
                        if first_token_duration.is_none() {
                            let ttft = Instant::now().duration_since(start_ttft);
                            first_token_duration = Some(ttft);
                        }
                        assistant_output.push_str(content);
                        print!("{content}");
                        io::stdout().flush().unwrap();
                        if finish_reason.is_some() {
                            // "length" means the reply was truncated.
                            if matches!(finish_reason.as_ref().unwrap().as_str(), "length") {
                                print!("...");
                            }
                            break;
                        }
                    }
                }
                Response::InternalError(e) => {
                    error!("Got an internal error: {e:?}");
                    break 'outer;
                }
                Response::ModelError(e, resp) => {
                    error!("Got a model error: {e:?}, response: {resp:?}");
                    break 'outer;
                }
                Response::ValidationError(e) => {
                    error!("Got a validation error: {e:?}");
                    break 'outer;
                }
                // Non-streaming / non-chat variants cannot occur for this request.
                Response::Done(_) => unreachable!(),
                Response::CompletionDone(_) => unreachable!(),
                Response::CompletionModelError(_, _) => unreachable!(),
                Response::CompletionChunk(_) => unreachable!(),
                Response::ImageGeneration(_) => unreachable!(),
                Response::Speech { .. } => unreachable!(),
                Response::Raw { .. } => unreachable!(),
            }
        }

        // Per-turn statistics from the final chunk's usage, if reported.
        if let Some(last_usage) = last_usage {
            println!();
            println!();
            println!("Stats:");
            if let Some(ttft) = first_token_duration {
                println!("Time to first token: {:.2?}s", ttft.as_secs_f32());
            }
            println!(
                "Prompt: {} tokens, {:.2} T/s",
                last_usage.prompt_tokens, last_usage.avg_prompt_tok_per_sec
            );
            println!(
                "Decode: {} tokens, {:.2} T/s",
                last_usage.completion_tokens, last_usage.avg_compl_tok_per_sec
            );
        }
        // Record the assistant's reply so the next turn has full context.
        let mut assistant_message: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
            IndexMap::new();
        assistant_message.insert("role".to_string(), Either::Left("assistant".to_string()));
        assistant_message.insert("content".to_string(), Either::Left(assistant_output));
        messages.push(assistant_message);
        println!();
    }

    rl.save_history(&history_file_path()).unwrap();
}
417
418fn parse_files_and_message(input: &str, regex: &Regex) -> (Vec<String>, String) {
419 let urls: Vec<String> = regex
421 .captures_iter(input)
422 .filter_map(|cap| {
423 cap.get(1).map(|m| {
424 m.as_str()
425 .trim_end_matches(|c: char| {
426 matches!(
427 c,
428 '.' | ',' | ';' | ':' | '!' | '?' | ')' | ']' | '}' | '"' | '\''
429 )
430 })
431 .to_string()
432 })
433 })
434 .collect();
435 let text = regex.replace_all(input, "").trim().to_string();
437 (urls, text)
438}
439
440async fn vision_interactive_mode(
441 mistralrs: Arc<MistralRs>,
442 do_search: bool,
443 enable_thinking: Option<bool>,
444) {
445 let image_regex = Regex::new(IMAGE_REGEX).unwrap();
447 let audio_regex = Regex::new(AUDIO_REGEX).unwrap();
448
449 let sender = mistralrs.get_sender(None).unwrap();
450 let mut messages: Vec<IndexMap<String, MessageContent>> = Vec::new();
451 let mut images = Vec::new();
452 let mut audios = Vec::new();
453
454 let config = mistralrs.config(None).unwrap();
455 let prefixer = match &config.category {
456 ModelCategory::Vision { prefixer } => prefixer,
457 ModelCategory::Text
458 | ModelCategory::Diffusion
459 | ModelCategory::Speech
460 | ModelCategory::Audio => {
461 panic!("`add_image_message` expects a vision model.")
462 }
463 };
464
465 let mut sampling_params = interactive_sample_parameters();
466
467 info!("Starting interactive loop with sampling params: {sampling_params:?}");
468 println!(
469 "{}{VISION_INTERACTIVE_HELP}{COMMAND_COMMANDS}{}",
470 "=".repeat(20),
471 "=".repeat(20)
472 );
473
474 *CTRLC_HANDLER.lock().unwrap() = &exit_handler;
476
477 ctrlc::set_handler(move || CTRLC_HANDLER.lock().unwrap()())
478 .expect("Failed to set CTRL-C handler for interactive mode");
479
480 let mut rl = DefaultEditor::new().expect("Failed to open input");
481 let _ = rl.load_history(&history_file_path());
482 'outer: loop {
483 *CTRLC_HANDLER.lock().unwrap() = &exit_handler;
485
486 let prompt = read_line(&mut rl);
487
488 let prompt_trimmed = prompt.as_str().trim();
489 if prompt_trimmed.is_empty() {
490 continue;
491 }
492 if handle_sampling_command(prompt_trimmed, &mut sampling_params) {
493 continue;
494 }
495 match prompt_trimmed {
496 HELP_CMD => {
497 println!(
498 "{}{VISION_INTERACTIVE_HELP}{COMMAND_COMMANDS}{}",
499 "=".repeat(20),
500 "=".repeat(20)
501 );
502 continue;
503 }
504 EXIT_CMD => {
505 break;
506 }
507 CLEAR_CMD => {
508 messages.clear();
509 images.clear();
510 info!("Cleared chat history.");
511 continue;
512 }
513 _ if prompt_trimmed.starts_with(SYSTEM_CMD) => {
514 let parsed = match &prompt_trimmed.split(SYSTEM_CMD).collect::<Vec<_>>()[..] {
515 &["", a] => a.trim(),
516 _ => {
517 println!("Error: Setting the system command should be done with this format: `{SYSTEM_CMD} This is a system message.`");
518 continue;
519 }
520 };
521 info!("Set system message to `{parsed}`.");
522 let mut user_message: IndexMap<String, MessageContent> = IndexMap::new();
523 user_message.insert("role".to_string(), Either::Left("system".to_string()));
524 user_message.insert("content".to_string(), Either::Left(parsed.to_string()));
525 messages.push(user_message);
526 continue;
527 }
528 _ => {
529 let (urls_image, text_without_images) =
530 parse_files_and_message(prompt_trimmed, &image_regex);
531 let (urls_audio, text) =
532 parse_files_and_message(&text_without_images, &audio_regex);
533 if !urls_image.is_empty() || !urls_audio.is_empty() {
534 let mut image_indexes = Vec::new();
536 for url in &urls_image {
537 match util::parse_image_url(url).await {
538 Ok(image) => {
539 info!("Added image at `{url}`");
540 image_indexes.push(images.len());
541 images.push(image);
542 }
543 Err(e) => {
544 error!("Failed to read image from URL/path {}: {}", url, e);
545 continue 'outer;
546 }
547 }
548 }
549 let mut audio_indexes = Vec::new();
551 for url in &urls_audio {
552 match util::parse_audio_url(url).await {
553 Ok(audio) => {
554 info!("Added audio at `{url}`");
555 audio_indexes.push(audios.len());
556 audios.push(audio);
557 }
558 Err(e) => {
559 error!("Failed to read audio from URL/path {}: {}", url, e);
560 continue 'outer;
561 }
562 }
563 }
564 let mut content_vec: Vec<IndexMap<String, Value>> = Vec::new();
566 for _ in &urls_image {
567 content_vec.push(IndexMap::from([(
568 "type".to_string(),
569 Value::String("image".to_string()),
570 )]));
571 }
572 for _ in &urls_audio {
573 content_vec.push(IndexMap::from([(
574 "type".to_string(),
575 Value::String("audio".to_string()),
576 )]));
577 }
578 let mut prefixed_text = text.clone();
580 if !image_indexes.is_empty() {
581 prefixed_text =
582 prefixer.prefix_image(image_indexes.clone(), &prefixed_text);
583 }
584 if !audio_indexes.is_empty() {
585 prefixed_text =
586 prefixer.prefix_audio(audio_indexes.clone(), &prefixed_text);
587 }
588 content_vec.push(IndexMap::from([
590 ("type".to_string(), Value::String("text".to_string())),
591 ("text".to_string(), Value::String(prefixed_text)),
592 ]));
593 let mut user_message: IndexMap<String, MessageContent> = IndexMap::new();
595 user_message.insert("role".to_string(), Either::Left("user".to_string()));
596 user_message.insert("content".to_string(), Either::Right(content_vec));
597 messages.push(user_message);
598 } else {
599 let mut user_message: IndexMap<String, MessageContent> = IndexMap::new();
601 user_message.insert("role".to_string(), Either::Left("user".to_string()));
602 user_message.insert(
603 "content".to_string(),
604 Either::Left(prompt_trimmed.to_string()),
605 );
606 messages.push(user_message);
607 }
608 }
609 }
610
611 *CTRLC_HANDLER.lock().unwrap() = &terminate_handler;
613
614 let request_messages = RequestMessage::VisionChat {
615 images: images.clone(),
616 audios: audios.clone(),
617 messages: messages.clone(),
618 enable_thinking,
619 };
620
621 let (tx, mut rx) = channel(10_000);
622 let req = Request::Normal(Box::new(NormalRequest {
623 id: mistralrs.next_request_id(),
624 messages: request_messages,
625 sampling_params: sampling_params.clone(),
626 response: tx,
627 return_logprobs: false,
628 is_streaming: true,
629 constraint: Constraint::None,
630 suffix: None,
631 tool_choice: None,
632 tools: None,
633 logits_processors: None,
634 return_raw_logits: false,
635 web_search_options: do_search.then(WebSearchOptions::default),
636 model_id: None,
637 }));
638 sender.send(req).await.unwrap();
639 let start_ttft = Instant::now();
640 let mut first_token_duration: Option<std::time::Duration> = None;
641
642 let mut assistant_output = String::new();
643
644 let mut last_usage = None;
645 while let Some(resp) = rx.recv().await {
646 match resp {
647 Response::Chunk(chunk) => {
648 last_usage = chunk.usage.clone();
649 if let ChunkChoice {
650 delta:
651 Delta {
652 content: Some(content),
653 ..
654 },
655 finish_reason,
656 ..
657 } = &chunk.choices[0]
658 {
659 if first_token_duration.is_none() {
660 let ttft = Instant::now().duration_since(start_ttft);
661 first_token_duration = Some(ttft);
662 }
663 assistant_output.push_str(content);
664 print!("{content}");
665 io::stdout().flush().unwrap();
666 if finish_reason.is_some() {
667 if matches!(finish_reason.as_ref().unwrap().as_str(), "length") {
668 print!("...");
669 }
670 break;
671 }
672 }
673 }
674 Response::InternalError(e) => {
675 error!("Got an internal error: {e:?}");
676 break 'outer;
677 }
678 Response::ModelError(e, resp) => {
679 error!("Got a model error: {e:?}, response: {resp:?}");
680 break 'outer;
681 }
682 Response::ValidationError(e) => {
683 error!("Got a validation error: {e:?}");
684 break 'outer;
685 }
686 Response::Done(_) => unreachable!(),
687 Response::CompletionDone(_) => unreachable!(),
688 Response::CompletionModelError(_, _) => unreachable!(),
689 Response::CompletionChunk(_) => unreachable!(),
690 Response::ImageGeneration(_) => unreachable!(),
691 Response::Speech { .. } => unreachable!(),
692 Response::Raw { .. } => unreachable!(),
693 }
694 }
695
696 if let Some(last_usage) = last_usage {
697 println!();
698 println!();
699 println!("Stats:");
700 if let Some(ttft) = first_token_duration {
701 println!("Time to first token: {:.2?}s", ttft.as_secs_f32());
702 }
703 println!(
704 "Prompt: {} tokens, {:.2} T/s",
705 last_usage.prompt_tokens, last_usage.avg_prompt_tok_per_sec
706 );
707 println!(
708 "Decode: {} tokens, {:.2} T/s",
709 last_usage.completion_tokens, last_usage.avg_compl_tok_per_sec
710 );
711 }
712 let mut assistant_message: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
713 IndexMap::new();
714 assistant_message.insert("role".to_string(), Either::Left("assistant".to_string()));
715 assistant_message.insert("content".to_string(), Either::Left(assistant_output));
716 messages.push(assistant_message);
717 println!();
718 }
719
720 rl.save_history(&history_file_path()).unwrap();
721}
722
/// Interactive REPL for audio-only models.
///
/// Not yet supported: always panics via `unimplemented!`. Parameters are
/// kept to match the other mode entry points so the dispatcher's call
/// shape stays uniform.
async fn audio_interactive_mode(
    _mistralrs: Arc<MistralRs>,
    _do_search: bool,
    _enable_thinking: Option<bool>,
) {
    unimplemented!("Using audio models isn't supported yet")
}
730
731async fn diffusion_interactive_mode(mistralrs: Arc<MistralRs>, do_search: bool) {
732 let sender = mistralrs.get_sender(None).unwrap();
733
734 let diffusion_params = DiffusionGenerationParams::default();
735
736 info!("Starting interactive loop with generation params: {diffusion_params:?}");
737 println!(
738 "{}{DIFFUSION_INTERACTIVE_HELP}{}",
739 "=".repeat(20),
740 "=".repeat(20)
741 );
742
743 *CTRLC_HANDLER.lock().unwrap() = &exit_handler;
745
746 ctrlc::set_handler(move || CTRLC_HANDLER.lock().unwrap()())
747 .expect("Failed to set CTRL-C handler for interactive mode");
748
749 let mut rl = DefaultEditor::new().expect("Failed to open input");
750 let _ = rl.load_history(&history_file_path());
751 loop {
752 *CTRLC_HANDLER.lock().unwrap() = &exit_handler;
754
755 let prompt = read_line(&mut rl);
756
757 let prompt = match prompt.as_str().trim() {
758 "" => continue,
759 HELP_CMD => {
760 println!(
761 "{}{DIFFUSION_INTERACTIVE_HELP}{}",
762 "=".repeat(20),
763 "=".repeat(20)
764 );
765 continue;
766 }
767 EXIT_CMD => {
768 break;
769 }
770 prompt => prompt.to_string(),
771 };
772
773 *CTRLC_HANDLER.lock().unwrap() = &terminate_handler;
775
776 let (tx, mut rx) = channel(10_000);
777 let req = Request::Normal(Box::new(NormalRequest {
778 id: 0,
779 messages: RequestMessage::ImageGeneration {
780 prompt: prompt.to_string(),
781 format: ImageGenerationResponseFormat::Url,
782 generation_params: diffusion_params.clone(),
783 },
784 sampling_params: SamplingParams::deterministic(),
785 response: tx,
786 return_logprobs: false,
787 is_streaming: false,
788 suffix: None,
789 constraint: Constraint::None,
790 tool_choice: None,
791 tools: None,
792 logits_processors: None,
793 return_raw_logits: false,
794 web_search_options: do_search.then(WebSearchOptions::default),
795 model_id: None,
796 }));
797
798 let start = Instant::now();
799 sender.send(req).await.unwrap();
800
801 let ResponseOk::ImageGeneration(response) = rx.recv().await.unwrap().as_result().unwrap()
802 else {
803 panic!("Got unexpected response type.")
804 };
805 let end = Instant::now();
806
807 let duration = end.duration_since(start).as_secs_f32();
808 let pixels_per_s = (diffusion_params.height * diffusion_params.width) as f32 / duration;
809
810 println!(
811 "Image generated can be found at: image is at `{}`. Took {duration:.2}s ({pixels_per_s:.2} pixels/s).",
812 response.data[0].url.as_ref().unwrap(),
813 );
814
815 println!();
816 }
817
818 rl.save_history(&history_file_path()).unwrap();
819}
820
821async fn speech_interactive_mode(mistralrs: Arc<MistralRs>, do_search: bool) {
822 let sender = mistralrs.get_sender(None).unwrap();
823
824 info!("Starting interactive loop for speech");
825 println!(
826 "{}{SPEECH_INTERACTIVE_HELP}{}",
827 "=".repeat(20),
828 "=".repeat(20)
829 );
830
831 *CTRLC_HANDLER.lock().unwrap() = &exit_handler;
833
834 ctrlc::set_handler(move || CTRLC_HANDLER.lock().unwrap()())
835 .expect("Failed to set CTRL-C handler for interactive mode");
836
837 let mut rl = DefaultEditor::new().expect("Failed to open input");
838 let _ = rl.load_history(&history_file_path());
839
840 let mut n = 0;
841 loop {
842 *CTRLC_HANDLER.lock().unwrap() = &exit_handler;
844
845 let prompt = read_line(&mut rl);
846
847 let prompt = match prompt.as_str().trim() {
848 "" => continue,
849 HELP_CMD => {
850 println!(
851 "{}{SPEECH_INTERACTIVE_HELP}{}",
852 "=".repeat(20),
853 "=".repeat(20)
854 );
855 continue;
856 }
857 EXIT_CMD => {
858 break;
859 }
860 prompt => prompt.to_string(),
861 };
862
863 *CTRLC_HANDLER.lock().unwrap() = &terminate_handler;
865
866 let (tx, mut rx) = channel(10_000);
867 let req = Request::Normal(Box::new(NormalRequest {
868 id: 0,
869 messages: RequestMessage::SpeechGeneration {
870 prompt: prompt.to_string(),
871 },
872 sampling_params: SamplingParams::deterministic(),
873 response: tx,
874 return_logprobs: false,
875 is_streaming: false,
876 suffix: None,
877 constraint: Constraint::None,
878 tool_choice: None,
879 tools: None,
880 logits_processors: None,
881 return_raw_logits: false,
882 web_search_options: do_search.then(WebSearchOptions::default),
883 model_id: None,
884 }));
885
886 let start = Instant::now();
887 sender.send(req).await.unwrap();
888
889 let ResponseOk::Speech {
890 pcm,
891 rate,
892 channels,
893 } = rx.recv().await.unwrap().as_result().unwrap()
894 else {
895 panic!("Got unexpected response type.")
896 };
897 let end = Instant::now();
898
899 let out_file = format!("speech-{n}.wav");
900 let mut output = std::fs::File::create(&out_file).unwrap();
901 speech_utils::write_pcm_as_wav(&mut output, &pcm, rate as u32, channels as u16).unwrap();
902
903 let duration = end.duration_since(start).as_secs_f32();
904 println!("Speech generated can be found at `{out_file}`. Took {duration:.2}s.");
905
906 n += 1;
907
908 println!();
909 }
910
911 rl.save_history(&history_file_path()).unwrap();
912}
913
#[cfg(test)]
mod tests {
    use super::*;

    /// URLs captured by the image regex have trailing punctuation stripped,
    /// while the leftover message text keeps the punctuation that followed
    /// the removed match.
    #[test]
    fn parse_files_and_message_trims_trailing_punctuation() {
        let regex = Regex::new(IMAGE_REGEX).unwrap();
        let input = "Look at this https://example.com/test.png.";
        // Fixed: the argument had been mangled to `®ex` (an HTML-entity
        // corruption of `&regex`), which does not compile.
        let (urls, text) = parse_files_and_message(input, &regex);
        assert_eq!(urls, vec!["https://example.com/test.png"]);
        assert_eq!(text, "Look at this .");
    }
}