mistralrs_server/interactive_mode.rs

use directories::ProjectDirs;
use either::Either;
use indexmap::IndexMap;
use mistralrs_core::{
    speech_utils, ChunkChoice, Constraint, Delta, DiffusionGenerationParams, DrySamplingParams,
    ImageGenerationResponseFormat, MessageContent, MistralRs, ModelCategory, NormalRequest,
    Request, RequestMessage, Response, ResponseOk, SamplingParams, WebSearchOptions,
    TERMINATE_ALL_NEXT_STEP,
};
use once_cell::sync::Lazy;
use regex::Regex;
use rustyline::{error::ReadlineError, history::History, DefaultEditor, Editor, Helper};
use serde_json::Value;
use std::{
    fs,
    io::{self, Write},
    path::PathBuf,
    sync::{atomic::Ordering, Arc, Mutex},
    time::Instant,
};
use tokio::sync::mpsc::channel;
use tracing::{error, info};

use mistralrs_server_core::util;

fn exit_handler() {
    std::process::exit(0);
}

fn terminate_handler() {
    TERMINATE_ALL_NEXT_STEP.store(true, Ordering::SeqCst);
}

fn history_file_path() -> PathBuf {
    // Resolve the per-user config directory for the `mistral.rs` application.
    let proj_dirs = ProjectDirs::from("com", "", "mistral.rs")
        .expect("Could not determine project directories");
    let config_dir = proj_dirs.config_dir();

    // Ensure the directory exists:
    fs::create_dir_all(config_dir).expect("Failed to create config directory");

    // e.g. ~/.config/mistral.rs/history.txt
    config_dir.join("history.txt")
}

fn read_line<H: Helper, I: History>(editor: &mut Editor<H, I>) -> String {
    let r = editor.readline("> ");
    match r {
        Err(ReadlineError::Interrupted) => {
            editor.save_history(&history_file_path()).unwrap();
            // Ctrl+C
            std::process::exit(0);
        }

        Err(ReadlineError::Eof) => {
            editor.save_history(&history_file_path()).unwrap();
            // Ctrl+D
            std::process::exit(0);
        }

        Err(e) => {
            editor.save_history(&history_file_path()).unwrap();
            eprintln!("Error reading input: {e:?}");
            std::process::exit(1);
        }
        Ok(prompt) => {
            editor.add_history_entry(prompt.clone()).unwrap();
            prompt
        }
    }
}

/// Current CTRL-C behavior. The interactive loops swap this between
/// `exit_handler` (quit the program) and `terminate_handler` (cancel the
/// in-flight generation) depending on whether a request is currently running.
static CTRLC_HANDLER: Lazy<Mutex<&'static (dyn Fn() + Sync)>> =
    Lazy::new(|| Mutex::new(&exit_handler));

pub async fn interactive_mode(
    mistralrs: Arc<MistralRs>,
    do_search: bool,
    enable_thinking: Option<bool>,
) {
    match mistralrs.get_model_category(None) {
        Ok(ModelCategory::Text) => {
            text_interactive_mode(mistralrs, do_search, enable_thinking).await
        }
        Ok(ModelCategory::Vision { .. }) => {
            vision_interactive_mode(mistralrs, do_search, enable_thinking).await
        }
        Ok(ModelCategory::Diffusion) => diffusion_interactive_mode(mistralrs, do_search).await,
        Ok(ModelCategory::Audio) => {
            audio_interactive_mode(mistralrs, do_search, enable_thinking).await
        }
        Ok(ModelCategory::Speech) => speech_interactive_mode(mistralrs, do_search).await,
        Err(e) => eprintln!("Error getting model category: {e}"),
    }
}

const COMMAND_COMMANDS: &str = r#"
Commands:
- `\help`: Display this message.
- `\exit`: Quit interactive mode.
- `\system <system message here>`:
    Add a system message to the chat without running the model.
    Ex: `\system Always respond as a pirate.`
- `\clear`: Clear the chat history.
- `\temperature <float>`: Set sampling temperature, in (0.0, 2.0].
- `\topk <int>`: Set top-k sampling value (> 0).
- `\topp <float>`: Set top-p sampling value, in (0.0, 1.0].
"#;

const TEXT_INTERACTIVE_HELP: &str = r#"
Welcome to interactive mode! Because this model is a text model, you can enter prompts and chat with the model.
"#;

const VISION_INTERACTIVE_HELP: &str = r#"
Welcome to interactive mode! Because this model is a vision model, you can enter prompts and chat with the model.

To include one or more images or audio clips in a message, simply add the image/audio URL or path to the prompt:

- `Describe these images: path/to/image1.jpg path/to/image2.png`
- `Describe the image and transcribe the audio: path/to/image1.jpg path/to/sound.mp3`
"#;

const DIFFUSION_INTERACTIVE_HELP: &str = r#"
Welcome to interactive mode! Because this model is a diffusion model, you can enter prompts and the model will generate an image.

Commands:
- `\help`: Display this message.
- `\exit`: Quit interactive mode.
"#;

const SPEECH_INTERACTIVE_HELP: &str = r#"
Welcome to interactive mode! Because this model is a speech generation model, you can enter prompts and the model will generate audio.

Commands:
- `\help`: Display this message.
- `\exit`: Quit interactive mode.
"#;

const HELP_CMD: &str = "\\help";
const EXIT_CMD: &str = "\\exit";
const SYSTEM_CMD: &str = "\\system";
const CLEAR_CMD: &str = "\\clear";
const TEMPERATURE_CMD: &str = "\\temperature";
const TOPK_CMD: &str = "\\topk";
const TOPP_CMD: &str = "\\topp";

/// Regex used to extract image URLs or file paths from prompts.
const IMAGE_REGEX: &str = r#"((?:https?://|file://)?\S+?\.(?:png|jpe?g|bmp|gif|webp)(?:\?\S+?)?)"#;
/// Regex used to extract audio URLs or file paths from prompts.
const AUDIO_REGEX: &str = r#"((?:https?://|file://)?\S+?\.(?:wav|mp3|flac|ogg)(?:\?\S+?)?)"#;

/// Default sampling parameters used for interactive chat requests.
fn interactive_sample_parameters() -> SamplingParams {
    SamplingParams {
        temperature: Some(0.1),
        top_k: Some(32),
        top_p: Some(0.1),
        min_p: Some(0.05),
        top_n_logprobs: 0,
        frequency_penalty: Some(0.1),
        presence_penalty: Some(0.1),
        max_len: None,
        stop_toks: None,
        logits_bias: None,
        n_choices: 1,
        dry_params: Some(DrySamplingParams::default()),
    }
}

/// Handles sampling commands (`\temperature`, `\topk`, `\topp`) and updates `sampling_params` accordingly.
/// Returns true if the prompt was a handled sampling command, otherwise false.
fn handle_sampling_command(prompt: &str, sampling_params: &mut SamplingParams) -> bool {
    let trimmed = prompt.trim();
    if trimmed.starts_with(TEMPERATURE_CMD) {
        let parts: Vec<&str> = trimmed.splitn(2, ' ').collect();
        if let [_, value] = parts.as_slice() {
            match value.trim().parse::<f64>() {
                Ok(v) if v > 0.0 && v <= 2.0 => {
                    sampling_params.temperature = Some(v);
                    info!("Set temperature to {v}");
                }
                Ok(_) => {
                    println!("Error: temperature must be in (0.0, 2.0]");
                }
                Err(_) => println!("Error: format is `{TEMPERATURE_CMD} <float>`"),
            }
        } else {
            println!("Error: format is `{TEMPERATURE_CMD} <float>`");
        }
        return true;
    }
    if trimmed.starts_with(TOPK_CMD) {
        let parts: Vec<&str> = trimmed.splitn(2, ' ').collect();
        if let [_, value] = parts.as_slice() {
            match value.trim().parse::<usize>() {
                Ok(v) if v > 0 => {
                    sampling_params.top_k = Some(v);
                    info!("Set top-k to {v}");
                }
                Ok(_) => {
                    println!("Error: top-k must be a positive integer");
                }
                Err(_) => println!("Error: format is `{TOPK_CMD} <int>`"),
            }
        } else {
            println!("Error: format is `{TOPK_CMD} <int>`");
        }
        return true;
    }
    if trimmed.starts_with(TOPP_CMD) {
        let parts: Vec<&str> = trimmed.splitn(2, ' ').collect();
        if let [_, value] = parts.as_slice() {
            match value.trim().parse::<f64>() {
                Ok(v) if v > 0.0 && v <= 1.0 => {
                    sampling_params.top_p = Some(v);
                    info!("Set top-p to {v}");
                }
                Ok(_) => {
                    println!("Error: top-p must be in (0.0, 1.0]");
                }
                Err(_) => println!("Error: format is `{TOPP_CMD} <float>`"),
            }
        } else {
            println!("Error: format is `{TOPP_CMD} <float>`");
        }
        return true;
    }
    false
}

async fn text_interactive_mode(
    mistralrs: Arc<MistralRs>,
    do_search: bool,
    enable_thinking: Option<bool>,
) {
    let sender = mistralrs.get_sender(None).unwrap();
    let mut messages: Vec<IndexMap<String, MessageContent>> = Vec::new();

    let mut sampling_params = interactive_sample_parameters();

    info!("Starting interactive loop with sampling params: {sampling_params:?}");
    println!(
        "{}{TEXT_INTERACTIVE_HELP}{COMMAND_COMMANDS}{}",
        "=".repeat(20),
        "=".repeat(20)
    );

    // While waiting for user input, CTRL-C exits the process.
    *CTRLC_HANDLER.lock().unwrap() = &exit_handler;

    ctrlc::set_handler(move || CTRLC_HANDLER.lock().unwrap()())
        .expect("Failed to set CTRL-C handler for interactive mode");

    let mut rl = DefaultEditor::new().expect("Failed to open input");
    let _ = rl.load_history(&history_file_path());
    'outer: loop {
        // While waiting for user input, CTRL-C exits the process.
        *CTRLC_HANDLER.lock().unwrap() = &exit_handler;

        let prompt = read_line(&mut rl);

        let prompt_trimmed = prompt.as_str().trim();
        if prompt_trimmed.is_empty() {
            continue;
        }
        if handle_sampling_command(prompt_trimmed, &mut sampling_params) {
            continue;
        }
        match prompt_trimmed {
            HELP_CMD => {
                println!(
                    "{}{TEXT_INTERACTIVE_HELP}{COMMAND_COMMANDS}{}",
                    "=".repeat(20),
                    "=".repeat(20)
                );
                continue;
            }
            EXIT_CMD => {
                break;
            }
            CLEAR_CMD => {
                messages.clear();
                info!("Cleared chat history.");
                continue;
            }
            _ if prompt_trimmed.starts_with(SYSTEM_CMD) => {
                let parsed = match &prompt_trimmed.split(SYSTEM_CMD).collect::<Vec<_>>()[..] {
                    &["", a] => a.trim(),
                    _ => {
                        println!("Error: The system message must be set with this format: `{SYSTEM_CMD} This is a system message.`");
                        continue;
                    }
                };
                info!("Set system message to `{parsed}`.");
                let mut system_message: IndexMap<String, MessageContent> = IndexMap::new();
                system_message.insert("role".to_string(), Either::Left("system".to_string()));
                system_message.insert("content".to_string(), Either::Left(parsed.to_string()));
                messages.push(system_message);
                continue;
            }
            message => {
                let mut user_message: IndexMap<String, MessageContent> = IndexMap::new();
                user_message.insert("role".to_string(), Either::Left("user".to_string()));
                user_message.insert("content".to_string(), Either::Left(message.to_string()));
                messages.push(user_message);
            }
        }

        // While generating, CTRL-C terminates all running sequences so the request can be cancelled.
        *CTRLC_HANDLER.lock().unwrap() = &terminate_handler;

        let request_messages = RequestMessage::Chat {
            messages: messages.clone(),
            enable_thinking,
        };

        let (tx, mut rx) = channel(10_000);
        let req = Request::Normal(Box::new(NormalRequest {
            id: mistralrs.next_request_id(),
            messages: request_messages,
            sampling_params: sampling_params.clone(),
            response: tx,
            return_logprobs: false,
            is_streaming: true,
            constraint: Constraint::None,
            suffix: None,
            tool_choice: None,
            tools: None,
            logits_processors: None,
            return_raw_logits: false,
            web_search_options: do_search.then(WebSearchOptions::default),
            model_id: None,
        }));
        sender.send(req).await.unwrap();
        let start_ttft = Instant::now();
        let mut first_token_duration: Option<std::time::Duration> = None;

        let mut assistant_output = String::new();

        let mut last_usage = None;
        while let Some(resp) = rx.recv().await {
            match resp {
                Response::Chunk(chunk) => {
                    last_usage = chunk.usage.clone();
                    if let ChunkChoice {
                        delta:
                            Delta {
                                content: Some(content),
                                ..
                            },
                        finish_reason,
                        ..
                    } = &chunk.choices[0]
                    {
                        if first_token_duration.is_none() {
                            let ttft = Instant::now().duration_since(start_ttft);
                            first_token_duration = Some(ttft);
                        }
                        assistant_output.push_str(content);
                        print!("{content}");
                        io::stdout().flush().unwrap();
                        if finish_reason.is_some() {
                            if matches!(finish_reason.as_ref().unwrap().as_str(), "length") {
                                print!("...");
                            }
                            break;
                        }
                    }
                }
                Response::InternalError(e) => {
                    error!("Got an internal error: {e:?}");
                    break 'outer;
                }
                Response::ModelError(e, resp) => {
                    error!("Got a model error: {e:?}, response: {resp:?}");
                    break 'outer;
                }
                Response::ValidationError(e) => {
                    error!("Got a validation error: {e:?}");
                    break 'outer;
                }
                Response::Done(_) => unreachable!(),
                Response::CompletionDone(_) => unreachable!(),
                Response::CompletionModelError(_, _) => unreachable!(),
                Response::CompletionChunk(_) => unreachable!(),
                Response::ImageGeneration(_) => unreachable!(),
                Response::Speech { .. } => unreachable!(),
                Response::Raw { .. } => unreachable!(),
            }
        }

        if let Some(last_usage) = last_usage {
            println!();
            println!();
            println!("Stats:");
            if let Some(ttft) = first_token_duration {
                println!("Time to first token: {:.2}s", ttft.as_secs_f32());
            }
            println!(
                "Prompt: {} tokens, {:.2} T/s",
                last_usage.prompt_tokens, last_usage.avg_prompt_tok_per_sec
            );
            println!(
                "Decode: {} tokens, {:.2} T/s",
                last_usage.completion_tokens, last_usage.avg_compl_tok_per_sec
            );
        }
        let mut assistant_message: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
            IndexMap::new();
        assistant_message.insert("role".to_string(), Either::Left("assistant".to_string()));
        assistant_message.insert("content".to_string(), Either::Left(assistant_output));
        messages.push(assistant_message);
        println!();
    }

    rl.save_history(&history_file_path()).unwrap();
}

fn parse_files_and_message(input: &str, regex: &Regex) -> (Vec<String>, String) {
    // Collect all URLs
    let urls: Vec<String> = regex
        .captures_iter(input)
        .filter_map(|cap| {
            cap.get(1).map(|m| {
                m.as_str()
                    .trim_end_matches(|c: char| {
                        matches!(
                            c,
                            '.' | ',' | ';' | ':' | '!' | '?' | ')' | ']' | '}' | '"' | '\''
                        )
                    })
                    .to_string()
            })
        })
        .collect();
    // Remove the URLs from the input to get the message text
    let text = regex.replace_all(input, "").trim().to_string();
    (urls, text)
}

async fn vision_interactive_mode(
    mistralrs: Arc<MistralRs>,
    do_search: bool,
    enable_thinking: Option<bool>,
) {
    // Capture HTTP/HTTPS URLs and local file paths ending with common image/audio extensions
    let image_regex = Regex::new(IMAGE_REGEX).unwrap();
    let audio_regex = Regex::new(AUDIO_REGEX).unwrap();

    let sender = mistralrs.get_sender(None).unwrap();
    let mut messages: Vec<IndexMap<String, MessageContent>> = Vec::new();
    let mut images = Vec::new();
    let mut audios = Vec::new();

    let config = mistralrs.config(None).unwrap();
    let prefixer = match &config.category {
        ModelCategory::Vision { prefixer } => prefixer,
        ModelCategory::Text
        | ModelCategory::Diffusion
        | ModelCategory::Speech
        | ModelCategory::Audio => {
            panic!("`vision_interactive_mode` expects a vision model.")
        }
    };

    let mut sampling_params = interactive_sample_parameters();

    info!("Starting interactive loop with sampling params: {sampling_params:?}");
    println!(
        "{}{VISION_INTERACTIVE_HELP}{COMMAND_COMMANDS}{}",
        "=".repeat(20),
        "=".repeat(20)
    );

    // While waiting for user input, CTRL-C exits the process.
    *CTRLC_HANDLER.lock().unwrap() = &exit_handler;

    ctrlc::set_handler(move || CTRLC_HANDLER.lock().unwrap()())
        .expect("Failed to set CTRL-C handler for interactive mode");

    let mut rl = DefaultEditor::new().expect("Failed to open input");
    let _ = rl.load_history(&history_file_path());
    'outer: loop {
        // While waiting for user input, CTRL-C exits the process.
        *CTRLC_HANDLER.lock().unwrap() = &exit_handler;

        let prompt = read_line(&mut rl);

        let prompt_trimmed = prompt.as_str().trim();
        if prompt_trimmed.is_empty() {
            continue;
        }
        if handle_sampling_command(prompt_trimmed, &mut sampling_params) {
            continue;
        }
        match prompt_trimmed {
            HELP_CMD => {
                println!(
                    "{}{VISION_INTERACTIVE_HELP}{COMMAND_COMMANDS}{}",
                    "=".repeat(20),
                    "=".repeat(20)
                );
                continue;
            }
            EXIT_CMD => {
                break;
            }
            CLEAR_CMD => {
                messages.clear();
                images.clear();
                audios.clear();
                info!("Cleared chat history.");
                continue;
            }
            _ if prompt_trimmed.starts_with(SYSTEM_CMD) => {
                let parsed = match &prompt_trimmed.split(SYSTEM_CMD).collect::<Vec<_>>()[..] {
                    &["", a] => a.trim(),
                    _ => {
                        println!("Error: The system message must be set with this format: `{SYSTEM_CMD} This is a system message.`");
                        continue;
                    }
                };
                info!("Set system message to `{parsed}`.");
                let mut system_message: IndexMap<String, MessageContent> = IndexMap::new();
                system_message.insert("role".to_string(), Either::Left("system".to_string()));
                system_message.insert("content".to_string(), Either::Left(parsed.to_string()));
                messages.push(system_message);
                continue;
            }
            _ => {
                let (urls_image, text_without_images) =
                    parse_files_and_message(prompt_trimmed, &image_regex);
                let (urls_audio, text) =
                    parse_files_and_message(&text_without_images, &audio_regex);
                if !urls_image.is_empty() || !urls_audio.is_empty() {
                    // Load images
                    let mut image_indexes = Vec::new();
                    for url in &urls_image {
                        match util::parse_image_url(url).await {
                            Ok(image) => {
                                info!("Added image at `{url}`");
                                image_indexes.push(images.len());
                                images.push(image);
                            }
                            Err(e) => {
                                error!("Failed to read image from URL/path {}: {}", url, e);
                                continue 'outer;
                            }
                        }
                    }
                    // Load audios
                    let mut audio_indexes = Vec::new();
                    for url in &urls_audio {
                        match util::parse_audio_url(url).await {
                            Ok(audio) => {
                                info!("Added audio at `{url}`");
                                audio_indexes.push(audios.len());
                                audios.push(audio);
                            }
                            Err(e) => {
                                error!("Failed to read audio from URL/path {}: {}", url, e);
                                continue 'outer;
                            }
                        }
                    }
                    // Build mixed content parts
                    let mut content_vec: Vec<IndexMap<String, Value>> = Vec::new();
                    for _ in &urls_image {
                        content_vec.push(IndexMap::from([(
                            "type".to_string(),
                            Value::String("image".to_string()),
                        )]));
                    }
                    for _ in &urls_audio {
                        content_vec.push(IndexMap::from([(
                            "type".to_string(),
                            Value::String("audio".to_string()),
                        )]));
                    }
                    // Prefix the text with any media context
                    let mut prefixed_text = text.clone();
                    if !image_indexes.is_empty() {
                        prefixed_text =
                            prefixer.prefix_image(image_indexes.clone(), &prefixed_text);
                    }
                    if !audio_indexes.is_empty() {
                        prefixed_text =
                            prefixer.prefix_audio(audio_indexes.clone(), &prefixed_text);
                    }
                    // Add the final text part
                    content_vec.push(IndexMap::from([
                        ("type".to_string(), Value::String("text".to_string())),
                        ("text".to_string(), Value::String(prefixed_text)),
                    ]));
                    // Push the combined user message
                    let mut user_message: IndexMap<String, MessageContent> = IndexMap::new();
                    user_message.insert("role".to_string(), Either::Left("user".to_string()));
                    user_message.insert("content".to_string(), Either::Right(content_vec));
                    messages.push(user_message);
                } else {
                    // Default: handle as text-only prompt
                    let mut user_message: IndexMap<String, MessageContent> = IndexMap::new();
                    user_message.insert("role".to_string(), Either::Left("user".to_string()));
                    user_message.insert(
                        "content".to_string(),
                        Either::Left(prompt_trimmed.to_string()),
                    );
                    messages.push(user_message);
                }
            }
        }

        // While generating, CTRL-C terminates all running sequences so the request can be cancelled.
        *CTRLC_HANDLER.lock().unwrap() = &terminate_handler;

        let request_messages = RequestMessage::VisionChat {
            images: images.clone(),
            audios: audios.clone(),
            messages: messages.clone(),
            enable_thinking,
        };

        let (tx, mut rx) = channel(10_000);
        let req = Request::Normal(Box::new(NormalRequest {
            id: mistralrs.next_request_id(),
            messages: request_messages,
            sampling_params: sampling_params.clone(),
            response: tx,
            return_logprobs: false,
            is_streaming: true,
            constraint: Constraint::None,
            suffix: None,
            tool_choice: None,
            tools: None,
            logits_processors: None,
            return_raw_logits: false,
            web_search_options: do_search.then(WebSearchOptions::default),
            model_id: None,
        }));
        sender.send(req).await.unwrap();
        let start_ttft = Instant::now();
        let mut first_token_duration: Option<std::time::Duration> = None;

        let mut assistant_output = String::new();

        let mut last_usage = None;
        while let Some(resp) = rx.recv().await {
            match resp {
                Response::Chunk(chunk) => {
                    last_usage = chunk.usage.clone();
                    if let ChunkChoice {
                        delta:
                            Delta {
                                content: Some(content),
                                ..
                            },
                        finish_reason,
                        ..
                    } = &chunk.choices[0]
                    {
                        if first_token_duration.is_none() {
                            let ttft = Instant::now().duration_since(start_ttft);
                            first_token_duration = Some(ttft);
                        }
                        assistant_output.push_str(content);
                        print!("{content}");
                        io::stdout().flush().unwrap();
                        if finish_reason.is_some() {
                            if matches!(finish_reason.as_ref().unwrap().as_str(), "length") {
                                print!("...");
                            }
                            break;
                        }
                    }
                }
                Response::InternalError(e) => {
                    error!("Got an internal error: {e:?}");
                    break 'outer;
                }
                Response::ModelError(e, resp) => {
                    error!("Got a model error: {e:?}, response: {resp:?}");
                    break 'outer;
                }
                Response::ValidationError(e) => {
                    error!("Got a validation error: {e:?}");
                    break 'outer;
                }
                Response::Done(_) => unreachable!(),
                Response::CompletionDone(_) => unreachable!(),
                Response::CompletionModelError(_, _) => unreachable!(),
                Response::CompletionChunk(_) => unreachable!(),
                Response::ImageGeneration(_) => unreachable!(),
                Response::Speech { .. } => unreachable!(),
                Response::Raw { .. } => unreachable!(),
            }
        }

        if let Some(last_usage) = last_usage {
            println!();
            println!();
            println!("Stats:");
            if let Some(ttft) = first_token_duration {
                println!("Time to first token: {:.2}s", ttft.as_secs_f32());
            }
            println!(
                "Prompt: {} tokens, {:.2} T/s",
                last_usage.prompt_tokens, last_usage.avg_prompt_tok_per_sec
            );
            println!(
                "Decode: {} tokens, {:.2} T/s",
                last_usage.completion_tokens, last_usage.avg_compl_tok_per_sec
            );
        }
        let mut assistant_message: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
            IndexMap::new();
        assistant_message.insert("role".to_string(), Either::Left("assistant".to_string()));
        assistant_message.insert("content".to_string(), Either::Left(assistant_output));
        messages.push(assistant_message);
        println!();
    }

    rl.save_history(&history_file_path()).unwrap();
}

async fn audio_interactive_mode(
    _mistralrs: Arc<MistralRs>,
    _do_search: bool,
    _enable_thinking: Option<bool>,
) {
    unimplemented!("Using audio models isn't supported yet")
}

async fn diffusion_interactive_mode(mistralrs: Arc<MistralRs>, do_search: bool) {
    let sender = mistralrs.get_sender(None).unwrap();

    let diffusion_params = DiffusionGenerationParams::default();

    info!("Starting interactive loop with generation params: {diffusion_params:?}");
    println!(
        "{}{DIFFUSION_INTERACTIVE_HELP}{}",
        "=".repeat(20),
        "=".repeat(20)
    );

    // While waiting for user input, CTRL-C exits the process.
    *CTRLC_HANDLER.lock().unwrap() = &exit_handler;

    ctrlc::set_handler(move || CTRLC_HANDLER.lock().unwrap()())
        .expect("Failed to set CTRL-C handler for interactive mode");

    let mut rl = DefaultEditor::new().expect("Failed to open input");
    let _ = rl.load_history(&history_file_path());
    loop {
        // While waiting for user input, CTRL-C exits the process.
        *CTRLC_HANDLER.lock().unwrap() = &exit_handler;

        let prompt = read_line(&mut rl);

        let prompt = match prompt.as_str().trim() {
            "" => continue,
            HELP_CMD => {
                println!(
                    "{}{DIFFUSION_INTERACTIVE_HELP}{}",
                    "=".repeat(20),
                    "=".repeat(20)
                );
                continue;
            }
            EXIT_CMD => {
                break;
            }
            prompt => prompt.to_string(),
        };

        // While generating, CTRL-C terminates all running sequences so the request can be cancelled.
        *CTRLC_HANDLER.lock().unwrap() = &terminate_handler;

        let (tx, mut rx) = channel(10_000);
        let req = Request::Normal(Box::new(NormalRequest {
            id: 0,
            messages: RequestMessage::ImageGeneration {
                prompt: prompt.to_string(),
                format: ImageGenerationResponseFormat::Url,
                generation_params: diffusion_params.clone(),
            },
            sampling_params: SamplingParams::deterministic(),
            response: tx,
            return_logprobs: false,
            is_streaming: false,
            suffix: None,
            constraint: Constraint::None,
            tool_choice: None,
            tools: None,
            logits_processors: None,
            return_raw_logits: false,
            web_search_options: do_search.then(WebSearchOptions::default),
            model_id: None,
        }));

        let start = Instant::now();
        sender.send(req).await.unwrap();

        let ResponseOk::ImageGeneration(response) = rx.recv().await.unwrap().as_result().unwrap()
        else {
            panic!("Got unexpected response type.")
        };
        let end = Instant::now();

        let duration = end.duration_since(start).as_secs_f32();
        let pixels_per_s = (diffusion_params.height * diffusion_params.width) as f32 / duration;

        println!(
            "Generated image is at `{}`. Took {duration:.2}s ({pixels_per_s:.2} pixels/s).",
            response.data[0].url.as_ref().unwrap(),
        );

        println!();
    }

    rl.save_history(&history_file_path()).unwrap();
}

async fn speech_interactive_mode(mistralrs: Arc<MistralRs>, do_search: bool) {
    let sender = mistralrs.get_sender(None).unwrap();

    info!("Starting interactive loop for speech");
    println!(
        "{}{SPEECH_INTERACTIVE_HELP}{}",
        "=".repeat(20),
        "=".repeat(20)
    );

    // While waiting for user input, CTRL-C exits the process.
    *CTRLC_HANDLER.lock().unwrap() = &exit_handler;

    ctrlc::set_handler(move || CTRLC_HANDLER.lock().unwrap()())
        .expect("Failed to set CTRL-C handler for interactive mode");

    let mut rl = DefaultEditor::new().expect("Failed to open input");
    let _ = rl.load_history(&history_file_path());

    let mut n = 0;
    loop {
        // While waiting for user input, CTRL-C exits the process.
        *CTRLC_HANDLER.lock().unwrap() = &exit_handler;

        let prompt = read_line(&mut rl);

        let prompt = match prompt.as_str().trim() {
            "" => continue,
            HELP_CMD => {
                println!(
                    "{}{SPEECH_INTERACTIVE_HELP}{}",
                    "=".repeat(20),
                    "=".repeat(20)
                );
                continue;
            }
            EXIT_CMD => {
                break;
            }
            prompt => prompt.to_string(),
        };

        // While generating, CTRL-C terminates all running sequences so the request can be cancelled.
        *CTRLC_HANDLER.lock().unwrap() = &terminate_handler;

        let (tx, mut rx) = channel(10_000);
        let req = Request::Normal(Box::new(NormalRequest {
            id: 0,
            messages: RequestMessage::SpeechGeneration {
                prompt: prompt.to_string(),
            },
            sampling_params: SamplingParams::deterministic(),
            response: tx,
            return_logprobs: false,
            is_streaming: false,
            suffix: None,
            constraint: Constraint::None,
            tool_choice: None,
            tools: None,
            logits_processors: None,
            return_raw_logits: false,
            web_search_options: do_search.then(WebSearchOptions::default),
            model_id: None,
        }));

        let start = Instant::now();
        sender.send(req).await.unwrap();

        let ResponseOk::Speech {
            pcm,
            rate,
            channels,
        } = rx.recv().await.unwrap().as_result().unwrap()
        else {
            panic!("Got unexpected response type.")
        };
        let end = Instant::now();

        let out_file = format!("speech-{n}.wav");
        let mut output = std::fs::File::create(&out_file).unwrap();
        speech_utils::write_pcm_as_wav(&mut output, &pcm, rate as u32, channels as u16).unwrap();

        let duration = end.duration_since(start).as_secs_f32();
        println!("Generated speech can be found at `{out_file}`. Took {duration:.2}s.");

        n += 1;

        println!();
    }

    rl.save_history(&history_file_path()).unwrap();
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_files_and_message_trims_trailing_punctuation() {
        let regex = Regex::new(IMAGE_REGEX).unwrap();
        let input = "Look at this https://example.com/test.png.";
        let (urls, text) = parse_files_and_message(input, &regex);
        assert_eq!(urls, vec!["https://example.com/test.png"]);
        assert_eq!(text, "Look at this .");
    }
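
    // A minimal illustrative check (not part of the original suite) of the
    // sampling-command handling documented on `handle_sampling_command`:
    // valid `\temperature`, `\topk`, and `\topp` values update the params in
    // place, and anything else is left for the chat loop to handle.
    #[test]
    fn handle_sampling_command_updates_params() {
        let mut params = interactive_sample_parameters();
        assert!(handle_sampling_command("\\temperature 0.7", &mut params));
        assert_eq!(params.temperature, Some(0.7));
        assert!(handle_sampling_command("\\topk 64", &mut params));
        assert_eq!(params.top_k, Some(64));
        assert!(handle_sampling_command("\\topp 0.9", &mut params));
        assert_eq!(params.top_p, Some(0.9));
        // A plain chat prompt is not a sampling command.
        assert!(!handle_sampling_command("tell me a story", &mut params));
    }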
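
    // An added example mirroring how the vision loop chains
    // `parse_files_and_message`: image paths are stripped first, then audio
    // paths, and the remaining text becomes the chat message (whitespace left
    // behind by the removed paths is not collapsed).
    #[test]
    fn parse_files_and_message_separates_images_and_audio() {
        let image_regex = Regex::new(IMAGE_REGEX).unwrap();
        let audio_regex = Regex::new(AUDIO_REGEX).unwrap();
        let input = "Describe image.png and transcribe sound.wav";
        let (images, rest) = parse_files_and_message(input, &image_regex);
        let (audios, text) = parse_files_and_message(&rest, &audio_regex);
        assert_eq!(images, vec!["image.png"]);
        assert_eq!(audios, vec!["sound.wav"]);
        // The double space is left over from removing `image.png`.
        assert_eq!(text, "Describe  and transcribe");
    }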
}