mistralrs_server/
interactive_mode.rs

1use either::Either;
2use indexmap::IndexMap;
3use mistralrs_core::{
4    ChunkChoice, Constraint, Delta, DiffusionGenerationParams, DrySamplingParams,
5    ImageGenerationResponseFormat, MessageContent, MistralRs, ModelCategory, NormalRequest,
6    Request, RequestMessage, Response, ResponseOk, SamplingParams, WebSearchOptions,
7    TERMINATE_ALL_NEXT_STEP,
8};
9use once_cell::sync::Lazy;
10use regex::Regex;
11use serde_json::Value;
12use std::{
13    io::{self, Write},
14    sync::{atomic::Ordering, Arc, Mutex},
15    time::Instant,
16};
17use tokio::sync::mpsc::channel;
18use tracing::{error, info};
19
20use crate::util;
21
/// CTRL-C action installed while the loop is waiting for user input:
/// ends the process immediately with exit status 0.
fn exit_handler() {
    std::process::exit(0)
}
25
26fn terminate_handler() {
27    TERMINATE_ALL_NEXT_STEP.store(true, Ordering::SeqCst);
28}
29
30static CTRLC_HANDLER: Lazy<Mutex<&'static (dyn Fn() + Sync)>> =
31    Lazy::new(|| Mutex::new(&exit_handler));
32
33pub async fn interactive_mode(mistralrs: Arc<MistralRs>, throughput: bool, do_search: bool) {
34    match mistralrs.get_model_category() {
35        ModelCategory::Text => text_interactive_mode(mistralrs, throughput, do_search).await,
36        ModelCategory::Vision { .. } => {
37            vision_interactive_mode(mistralrs, throughput, do_search).await
38        }
39        ModelCategory::Diffusion => diffusion_interactive_mode(mistralrs, do_search).await,
40    }
41}
42
/// Banner and command reference printed by the text-model loop at startup
/// and whenever the user enters `\help`.
const TEXT_INTERACTIVE_HELP: &str = r"
Welcome to interactive mode! Because this model is a text model, you can enter prompts and chat with the model.

Commands:
- `\help`: Display this message.
- `\exit`: Quit interactive mode.
- `\system <system message here>`:
    Add a system message to the chat without running the model.
    Ex: `\system Always respond as a pirate.`
";
53
/// Banner and command reference printed by the vision-model loop at startup
/// and whenever the user enters `\help`.
const VISION_INTERACTIVE_HELP: &str = r"
Welcome to interactive mode! Because this model is a vision model, you can enter prompts and chat with the model.

To specify a message with an image, use the `\image` command detailed below.

Commands:
- `\help`: Display this message.
- `\exit`: Quit interactive mode.
- `\system <system message here>`:
    Add a system message to the chat without running the model.
    Ex: `\system Always respond as a pirate.`
- `\image <image URL or local path here> <message here>`:
    Add a message paired with an image. The image will be fed to the model as if it were the first item in this prompt.
    You do not need to modify your prompt for specific models.
    Ex: `\image path/to/image.jpg Describe what is in this image.`
";
70
/// Banner and command reference printed by the diffusion-model loop at
/// startup and whenever the user enters `\help`.
const DIFFUSION_INTERACTIVE_HELP: &str = r"
Welcome to interactive mode! Because this model is a diffusion model, you can enter prompts and the model will generate an image.

Commands:
- `\help`: Display this message.
- `\exit`: Quit interactive mode.
";
78
// Command words recognised at the start of a prompt line. Each one begins
// with a single literal backslash.
const HELP_CMD: &str = r"\help";
const EXIT_CMD: &str = r"\exit";
const SYSTEM_CMD: &str = r"\system";
const IMAGE_CMD: &str = r"\image";
83
84async fn text_interactive_mode(mistralrs: Arc<MistralRs>, throughput: bool, do_search: bool) {
85    let sender = mistralrs.get_sender().unwrap();
86    let mut messages: Vec<IndexMap<String, MessageContent>> = Vec::new();
87
88    let sampling_params = SamplingParams {
89        temperature: Some(0.1),
90        top_k: Some(32),
91        top_p: Some(0.1),
92        min_p: Some(0.05),
93        top_n_logprobs: 0,
94        frequency_penalty: Some(0.1),
95        presence_penalty: Some(0.1),
96        max_len: Some(4096),
97        stop_toks: None,
98        logits_bias: None,
99        n_choices: 1,
100        dry_params: Some(DrySamplingParams::default()),
101    };
102
103    info!("Starting interactive loop with sampling params: {sampling_params:?}");
104    println!(
105        "{}{TEXT_INTERACTIVE_HELP}{}",
106        "=".repeat(20),
107        "=".repeat(20)
108    );
109
110    // Set the handler to process exit
111    *CTRLC_HANDLER.lock().unwrap() = &exit_handler;
112
113    ctrlc::set_handler(move || CTRLC_HANDLER.lock().unwrap()())
114        .expect("Failed to set CTRL-C handler for interactive mode");
115
116    'outer: loop {
117        // Set the handler to process exit
118        *CTRLC_HANDLER.lock().unwrap() = &exit_handler;
119
120        let mut prompt = String::new();
121        print!("> ");
122        io::stdout().flush().unwrap();
123        io::stdin()
124            .read_line(&mut prompt)
125            .expect("Failed to get input");
126
127        match prompt.as_str().trim() {
128            "" => continue,
129            HELP_CMD => {
130                println!(
131                    "{}{TEXT_INTERACTIVE_HELP}{}",
132                    "=".repeat(20),
133                    "=".repeat(20)
134                );
135                continue;
136            }
137            EXIT_CMD => {
138                break;
139            }
140            prompt if prompt.trim().starts_with(SYSTEM_CMD) => {
141                let parsed = match &prompt.split(SYSTEM_CMD).collect::<Vec<_>>()[..] {
142                    &["", a] => a.trim(),
143                    _ => {
144                        println!("Error: Setting the system command should be done with this format: `{SYSTEM_CMD} This is a system message.`");
145                        continue;
146                    }
147                };
148                info!("Set system message to `{parsed}`.");
149                let mut user_message: IndexMap<String, MessageContent> = IndexMap::new();
150                user_message.insert("role".to_string(), Either::Left("system".to_string()));
151                user_message.insert("content".to_string(), Either::Left(parsed.to_string()));
152                messages.push(user_message);
153                continue;
154            }
155            message => {
156                let mut user_message: IndexMap<String, MessageContent> = IndexMap::new();
157                user_message.insert("role".to_string(), Either::Left("user".to_string()));
158                user_message.insert("content".to_string(), Either::Left(message.to_string()));
159                messages.push(user_message);
160            }
161        }
162
163        // Set the handler to terminate all seqs, so allowing cancelling running
164        *CTRLC_HANDLER.lock().unwrap() = &terminate_handler;
165
166        let request_messages = RequestMessage::Chat(messages.clone());
167
168        let (tx, mut rx) = channel(10_000);
169        let req = Request::Normal(NormalRequest {
170            id: mistralrs.next_request_id(),
171            messages: request_messages,
172            sampling_params: sampling_params.clone(),
173            response: tx,
174            return_logprobs: false,
175            is_streaming: true,
176            constraint: Constraint::None,
177            suffix: None,
178            adapters: None,
179            tool_choice: None,
180            tools: None,
181            logits_processors: None,
182            return_raw_logits: false,
183            web_search_options: do_search.then(WebSearchOptions::default),
184        });
185        sender.send(req).await.unwrap();
186
187        let mut assistant_output = String::new();
188
189        let start = Instant::now();
190        let mut toks = 0;
191        while let Some(resp) = rx.recv().await {
192            match resp {
193                Response::Chunk(chunk) => {
194                    if let ChunkChoice {
195                        delta:
196                            Delta {
197                                content: Some(content),
198                                ..
199                            },
200                        finish_reason,
201                        ..
202                    } = &chunk.choices[0]
203                    {
204                        assistant_output.push_str(content);
205                        print!("{}", content);
206                        toks += 1usize; // NOTE: we send toks every 3.
207                        io::stdout().flush().unwrap();
208                        if finish_reason.is_some() {
209                            if matches!(finish_reason.as_ref().unwrap().as_str(), "length") {
210                                print!("...");
211                            }
212                            break;
213                        }
214                    }
215                }
216                Response::InternalError(e) => {
217                    error!("Got an internal error: {e:?}");
218                    break 'outer;
219                }
220                Response::ModelError(e, resp) => {
221                    error!("Got a model error: {e:?}, response: {resp:?}");
222                    break 'outer;
223                }
224                Response::ValidationError(e) => {
225                    error!("Got a validation error: {e:?}");
226                    break 'outer;
227                }
228                Response::Done(_) => unreachable!(),
229                Response::CompletionDone(_) => unreachable!(),
230                Response::CompletionModelError(_, _) => unreachable!(),
231                Response::CompletionChunk(_) => unreachable!(),
232                Response::ImageGeneration(_) => unreachable!(),
233                Response::Raw { .. } => unreachable!(),
234            }
235        }
236        if throughput {
237            let time = Instant::now().duration_since(start).as_secs_f64();
238            println!();
239            info!("Average T/s: {}", toks as f64 / time);
240        }
241        let mut assistant_message: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
242            IndexMap::new();
243        assistant_message.insert("role".to_string(), Either::Left("assistant".to_string()));
244        assistant_message.insert("content".to_string(), Either::Left(assistant_output));
245        messages.push(assistant_message);
246        println!();
247    }
248}
249
250fn parse_image_path_and_message(input: &str) -> Option<(String, String)> {
251    // Regex to capture the image path and the following message
252    let re = Regex::new(r#"\\image\s+"([^"]+)"\s*(.*)|\\image\s+(\S+)\s*(.*)"#).unwrap();
253
254    if let Some(captures) = re.captures(input) {
255        // Capture either the quoted or unquoted path and the message
256        if let Some(path) = captures.get(1) {
257            if let Some(message) = captures.get(2) {
258                return Some((
259                    path.as_str().trim().to_string(),
260                    message.as_str().trim().to_string(),
261                ));
262            }
263        } else if let Some(path) = captures.get(3) {
264            if let Some(message) = captures.get(4) {
265                return Some((
266                    path.as_str().trim().to_string(),
267                    message.as_str().trim().to_string(),
268                ));
269            }
270        }
271    }
272
273    None
274}
275
276async fn vision_interactive_mode(mistralrs: Arc<MistralRs>, throughput: bool, do_search: bool) {
277    let sender = mistralrs.get_sender().unwrap();
278    let mut messages: Vec<IndexMap<String, MessageContent>> = Vec::new();
279    let mut images = Vec::new();
280
281    let prefixer = match &mistralrs.config().category {
282        ModelCategory::Text | ModelCategory::Diffusion => {
283            panic!("`add_image_message` expects a vision model.")
284        }
285        ModelCategory::Vision {
286            has_conv2d: _,
287            prefixer,
288        } => prefixer,
289    };
290
291    let sampling_params = SamplingParams {
292        temperature: Some(0.1),
293        top_k: Some(32),
294        top_p: Some(0.1),
295        min_p: Some(0.05),
296        top_n_logprobs: 0,
297        frequency_penalty: Some(0.1),
298        presence_penalty: Some(0.1),
299        max_len: Some(4096),
300        stop_toks: None,
301        logits_bias: None,
302        n_choices: 1,
303        dry_params: Some(DrySamplingParams::default()),
304    };
305
306    info!("Starting interactive loop with sampling params: {sampling_params:?}");
307    println!(
308        "{}{VISION_INTERACTIVE_HELP}{}",
309        "=".repeat(20),
310        "=".repeat(20)
311    );
312
313    // Set the handler to process exit
314    *CTRLC_HANDLER.lock().unwrap() = &exit_handler;
315
316    ctrlc::set_handler(move || CTRLC_HANDLER.lock().unwrap()())
317        .expect("Failed to set CTRL-C handler for interactive mode");
318
319    'outer: loop {
320        // Set the handler to process exit
321        *CTRLC_HANDLER.lock().unwrap() = &exit_handler;
322
323        let mut prompt = String::new();
324        print!("> ");
325        io::stdout().flush().unwrap();
326        io::stdin()
327            .read_line(&mut prompt)
328            .expect("Failed to get input");
329
330        match prompt.as_str().trim() {
331            "" => continue,
332            HELP_CMD => {
333                println!(
334                    "{}{VISION_INTERACTIVE_HELP}{}",
335                    "=".repeat(20),
336                    "=".repeat(20)
337                );
338                continue;
339            }
340            EXIT_CMD => {
341                break;
342            }
343            prompt if prompt.trim().starts_with(SYSTEM_CMD) => {
344                let parsed = match &prompt.split(SYSTEM_CMD).collect::<Vec<_>>()[..] {
345                    &["", a] => a.trim(),
346                    _ => {
347                        println!("Error: Setting the system command should be done with this format: `{SYSTEM_CMD} This is a system message.`");
348                        continue;
349                    }
350                };
351                info!("Set system message to `{parsed}`.");
352                let mut user_message: IndexMap<String, MessageContent> = IndexMap::new();
353                user_message.insert("role".to_string(), Either::Left("system".to_string()));
354                user_message.insert("content".to_string(), Either::Left(parsed.to_string()));
355                messages.push(user_message);
356                continue;
357            }
358            prompt if prompt.trim().starts_with(IMAGE_CMD) => {
359                let Some((url, message)) = parse_image_path_and_message(prompt.trim()) else {
360                    println!("Error: Adding an image message should be done with this format: `{IMAGE_CMD} path/to/image.jpg Describe what is in this image.`");
361                    continue;
362                };
363                let message = prefixer.prefix_image(images.len(), &message);
364
365                let image = util::parse_image_url(&url)
366                    .await
367                    .expect("Failed to read image from URL/path");
368                images.push(image);
369
370                let mut user_message: IndexMap<String, MessageContent> = IndexMap::new();
371                user_message.insert("role".to_string(), Either::Left("user".to_string()));
372                user_message.insert(
373                    "content".to_string(),
374                    Either::Right(vec![
375                        IndexMap::from([("type".to_string(), Value::String("image".to_string()))]),
376                        IndexMap::from([
377                            ("type".to_string(), Value::String("text".to_string())),
378                            ("text".to_string(), Value::String(message)),
379                        ]),
380                    ]),
381                );
382                messages.push(user_message);
383            }
384            message => {
385                let mut user_message: IndexMap<String, MessageContent> = IndexMap::new();
386                user_message.insert("role".to_string(), Either::Left("user".to_string()));
387                user_message.insert("content".to_string(), Either::Left(message.to_string()));
388                messages.push(user_message);
389            }
390        };
391
392        // Set the handler to terminate all seqs, so allowing cancelling running
393        *CTRLC_HANDLER.lock().unwrap() = &terminate_handler;
394
395        let request_messages = RequestMessage::VisionChat {
396            images: images.clone(),
397            messages: messages.clone(),
398        };
399
400        let (tx, mut rx) = channel(10_000);
401        let req = Request::Normal(NormalRequest {
402            id: mistralrs.next_request_id(),
403            messages: request_messages,
404            sampling_params: sampling_params.clone(),
405            response: tx,
406            return_logprobs: false,
407            is_streaming: true,
408            constraint: Constraint::None,
409            suffix: None,
410            adapters: None,
411            tool_choice: None,
412            tools: None,
413            logits_processors: None,
414            return_raw_logits: false,
415            web_search_options: do_search.then(WebSearchOptions::default),
416        });
417        sender.send(req).await.unwrap();
418
419        let mut assistant_output = String::new();
420
421        let start = Instant::now();
422        let mut toks = 0;
423        while let Some(resp) = rx.recv().await {
424            match resp {
425                Response::Chunk(chunk) => {
426                    if let ChunkChoice {
427                        delta:
428                            Delta {
429                                content: Some(content),
430                                ..
431                            },
432                        finish_reason,
433                        ..
434                    } = &chunk.choices[0]
435                    {
436                        assistant_output.push_str(content);
437                        print!("{}", content);
438                        toks += 1usize; // NOTE: we send toks every 3.
439                        io::stdout().flush().unwrap();
440                        if finish_reason.is_some() {
441                            if matches!(finish_reason.as_ref().unwrap().as_str(), "length") {
442                                print!("...");
443                            }
444                            break;
445                        }
446                    }
447                }
448                Response::InternalError(e) => {
449                    error!("Got an internal error: {e:?}");
450                    break 'outer;
451                }
452                Response::ModelError(e, resp) => {
453                    error!("Got a model error: {e:?}, response: {resp:?}");
454                    break 'outer;
455                }
456                Response::ValidationError(e) => {
457                    error!("Got a validation error: {e:?}");
458                    break 'outer;
459                }
460                Response::Done(_) => unreachable!(),
461                Response::CompletionDone(_) => unreachable!(),
462                Response::CompletionModelError(_, _) => unreachable!(),
463                Response::CompletionChunk(_) => unreachable!(),
464                Response::ImageGeneration(_) => unreachable!(),
465                Response::Raw { .. } => unreachable!(),
466            }
467        }
468        if throughput {
469            let time = Instant::now().duration_since(start).as_secs_f64();
470            println!();
471            info!("Average T/s: {}", toks as f64 / time);
472        }
473        let mut assistant_message: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
474            IndexMap::new();
475        assistant_message.insert("role".to_string(), Either::Left("assistant".to_string()));
476        assistant_message.insert("content".to_string(), Either::Left(assistant_output));
477        messages.push(assistant_message);
478        println!();
479    }
480}
481
482async fn diffusion_interactive_mode(mistralrs: Arc<MistralRs>, do_search: bool) {
483    let sender = mistralrs.get_sender().unwrap();
484
485    let diffusion_params = DiffusionGenerationParams::default();
486
487    info!("Starting interactive loop with generation params: {diffusion_params:?}");
488    println!(
489        "{}{DIFFUSION_INTERACTIVE_HELP}{}",
490        "=".repeat(20),
491        "=".repeat(20)
492    );
493
494    // Set the handler to process exit
495    *CTRLC_HANDLER.lock().unwrap() = &exit_handler;
496
497    ctrlc::set_handler(move || CTRLC_HANDLER.lock().unwrap()())
498        .expect("Failed to set CTRL-C handler for interactive mode");
499
500    loop {
501        // Set the handler to process exit
502        *CTRLC_HANDLER.lock().unwrap() = &exit_handler;
503
504        let mut prompt = String::new();
505        print!("> ");
506        io::stdout().flush().unwrap();
507        io::stdin()
508            .read_line(&mut prompt)
509            .expect("Failed to get input");
510
511        let prompt = match prompt.as_str().trim() {
512            "" => continue,
513            HELP_CMD => {
514                println!(
515                    "{}{DIFFUSION_INTERACTIVE_HELP}{}",
516                    "=".repeat(20),
517                    "=".repeat(20)
518                );
519                continue;
520            }
521            EXIT_CMD => {
522                break;
523            }
524            prompt => prompt.to_string(),
525        };
526
527        // Set the handler to terminate all seqs, so allowing cancelling running
528        *CTRLC_HANDLER.lock().unwrap() = &terminate_handler;
529
530        let (tx, mut rx) = channel(10_000);
531        let req = Request::Normal(NormalRequest {
532            id: 0,
533            messages: RequestMessage::ImageGeneration {
534                prompt: prompt.to_string(),
535                format: ImageGenerationResponseFormat::Url,
536                generation_params: diffusion_params.clone(),
537            },
538            sampling_params: SamplingParams::deterministic(),
539            response: tx,
540            return_logprobs: false,
541            is_streaming: false,
542            suffix: None,
543            constraint: Constraint::None,
544            adapters: None,
545            tool_choice: None,
546            tools: None,
547            logits_processors: None,
548            return_raw_logits: false,
549            web_search_options: do_search.then(WebSearchOptions::default),
550        });
551
552        let start = Instant::now();
553        sender.send(req).await.unwrap();
554
555        let ResponseOk::ImageGeneration(response) = rx.recv().await.unwrap().as_result().unwrap()
556        else {
557            panic!("Got unexpected response type.")
558        };
559        let end = Instant::now();
560
561        let duration = end.duration_since(start).as_secs_f32();
562        let pixels_per_s = (diffusion_params.height * diffusion_params.width) as f32 / duration;
563
564        println!(
565            "Image generated can be found at: image is at `{}`. Took {duration:.2}s ({pixels_per_s:.2} pixels/s).",
566            response.data[0].url.as_ref().unwrap(),
567        );
568
569        println!();
570    }
571}
572
#[cfg(test)]
mod tests {
    use super::parse_image_path_and_message;

    /// Builds the expected successful parse result from borrowed strings.
    fn parsed(path: &str, message: &str) -> Option<(String, String)> {
        Some((path.to_owned(), message.to_owned()))
    }

    // --- Inputs that must parse -------------------------------------------

    #[test]
    fn test_parse_image_with_unquoted_path_and_message() {
        assert_eq!(
            parse_image_path_and_message(r#"\image image.jpg What is this"#),
            parsed("image.jpg", "What is this")
        );
    }

    #[test]
    fn test_parse_image_with_quoted_path_and_message() {
        assert_eq!(
            parse_image_path_and_message(r#"\image "image name.jpg" What is this?"#),
            parsed("image name.jpg", "What is this?")
        );
    }

    #[test]
    fn test_parse_image_with_only_unquoted_path() {
        assert_eq!(
            parse_image_path_and_message(r#"\image image.jpg"#),
            parsed("image.jpg", "")
        );
    }

    #[test]
    fn test_parse_image_with_only_quoted_path() {
        assert_eq!(
            parse_image_path_and_message(r#"\image "image name.jpg""#),
            parsed("image name.jpg", "")
        );
    }

    #[test]
    fn test_parse_image_with_extra_spaces() {
        // Extra whitespace around the path and message must be trimmed away.
        assert_eq!(
            parse_image_path_and_message(
                r#"\image    "image with spaces.jpg"    This is a test message with spaces  "#
            ),
            parsed(
                "image with spaces.jpg",
                "This is a test message with spaces"
            )
        );
    }

    #[test]
    fn test_parse_image_with_no_message() {
        assert_eq!(
            parse_image_path_and_message(r#"\image "image.jpg""#),
            parsed("image.jpg", "")
        );
    }

    #[test]
    fn test_parse_image_with_path_and_message_special_chars() {
        assert_eq!(
            parse_image_path_and_message(
                r#"\image "path with special chars @#$%^&*().jpg" This is a message with special chars !@#$%^&*()"#
            ),
            parsed(
                "path with special chars @#$%^&*().jpg",
                "This is a message with special chars !@#$%^&*()"
            )
        );
    }

    // --- Inputs that must be rejected -------------------------------------

    #[test]
    fn test_parse_image_missing_path() {
        assert_eq!(parse_image_path_and_message(r#"\image"#), None);
    }

    #[test]
    fn test_parse_image_invalid_command() {
        assert_eq!(
            parse_image_path_and_message(r#"\img "image.jpg" This should fail"#),
            None
        );
    }

    #[test]
    fn test_parse_image_with_non_image_text() {
        assert_eq!(
            parse_image_path_and_message(r#"Some random text without command"#),
            None
        );
    }
}