mistralrs_server/
interactive_mode.rs

1use either::Either;
2use indexmap::IndexMap;
3use mistralrs_core::{
4    ChunkChoice, Constraint, Delta, DiffusionGenerationParams, DrySamplingParams,
5    ImageGenerationResponseFormat, MessageContent, MistralRs, ModelCategory, NormalRequest,
6    Request, RequestMessage, Response, ResponseOk, SamplingParams, WebSearchOptions,
7    TERMINATE_ALL_NEXT_STEP,
8};
9use once_cell::sync::Lazy;
10use regex::Regex;
11use serde_json::Value;
12use std::{
13    io::{self, Write},
14    sync::{atomic::Ordering, Arc, Mutex},
15    time::Instant,
16};
17use tokio::sync::mpsc::channel;
18use tracing::{error, info};
19
20use crate::util;
21
/// Ctrl-C action installed while interactive mode waits at the `> ` prompt:
/// ends the whole process right away with exit status 0.
fn exit_handler() {
    std::process::exit(0)
}
25
26fn terminate_handler() {
27    TERMINATE_ALL_NEXT_STEP.store(true, Ordering::SeqCst);
28}
29
30static CTRLC_HANDLER: Lazy<Mutex<&'static (dyn Fn() + Sync)>> =
31    Lazy::new(|| Mutex::new(&exit_handler));
32
33pub async fn interactive_mode(mistralrs: Arc<MistralRs>, do_search: bool) {
34    match mistralrs.get_model_category() {
35        ModelCategory::Text => text_interactive_mode(mistralrs, do_search).await,
36        ModelCategory::Vision { .. } => vision_interactive_mode(mistralrs, do_search).await,
37        ModelCategory::Diffusion => diffusion_interactive_mode(mistralrs, do_search).await,
38    }
39}
40
/// Help banner for text models, printed when the loop starts and on `\help`.
///
/// The literal intentionally begins and ends with a newline so the text sits
/// on its own lines between the `=` rules printed around it.
const TEXT_INTERACTIVE_HELP: &str = r#"
Welcome to interactive mode! Because this model is a text model, you can enter prompts and chat with the model.

Commands:
- `\help`: Display this message.
- `\exit`: Quit interactive mode.
- `\system <system message here>`:
    Add a system message to the chat without running the model.
    Ex: `\system Always respond as a pirate.`
"#;
51
/// Help banner for vision models, printed when the loop starts and on `\help`.
///
/// The `\image` entry documents double-quoted paths because
/// `parse_image_path_and_message` accepts `"<path with spaces>"`. Without the
/// quotes, everything after the first space is treated as the message.
/// The literal begins and ends with a newline so it sits between the `=`
/// rules printed around it.
const VISION_INTERACTIVE_HELP: &str = r#"
Welcome to interactive mode! Because this model is a vision model, you can enter prompts and chat with the model.

To specify a message with an image, use the `\image` command detailed below.

Commands:
- `\help`: Display this message.
- `\exit`: Quit interactive mode.
- `\system <system message here>`:
    Add a system message to the chat without running the model.
    Ex: `\system Always respond as a pirate.`
- `\image <image URL or local path here> <message here>`:
    Add a message paired with an image. The image will be fed to the model as if it were the first item in this prompt.
    You do not need to modify your prompt for specific models.
    Wrap the path in double quotes if it contains spaces.
    Ex: `\image path/to/image.jpg Describe what is in this image.`
    Ex: `\image "path/to/my image.jpg" Describe what is in this image.`
"#;
68
/// Help banner for diffusion (image-generation) models, printed when the loop
/// starts and on `\help`.
///
/// Diffusion mode only understands `\help` and `\exit`; any other line is
/// used verbatim as an image prompt. The literal begins and ends with a
/// newline so it sits between the `=` rules printed around it.
const DIFFUSION_INTERACTIVE_HELP: &str = r#"
Welcome to interactive mode! Because this model is a diffusion model, you can enter prompts and the model will generate an image.

Commands:
- `\help`: Display this message.
- `\exit`: Quit interactive mode.
"#;
76
/// Reprints the help banner for the current mode.
const HELP_CMD: &str = r"\help";
/// Leaves the interactive loop.
const EXIT_CMD: &str = r"\exit";
/// Adds a system message without running the model: `\system <message>`.
const SYSTEM_CMD: &str = r"\system";
/// Adds a user message paired with an image (handled only by the vision loop):
/// `\image <path-or-url> <message>`.
const IMAGE_CMD: &str = r"\image";
81
82async fn text_interactive_mode(mistralrs: Arc<MistralRs>, do_search: bool) {
83    let sender = mistralrs.get_sender().unwrap();
84    let mut messages: Vec<IndexMap<String, MessageContent>> = Vec::new();
85
86    let sampling_params = SamplingParams {
87        temperature: Some(0.1),
88        top_k: Some(32),
89        top_p: Some(0.1),
90        min_p: Some(0.05),
91        top_n_logprobs: 0,
92        frequency_penalty: Some(0.1),
93        presence_penalty: Some(0.1),
94        max_len: Some(4096),
95        stop_toks: None,
96        logits_bias: None,
97        n_choices: 1,
98        dry_params: Some(DrySamplingParams::default()),
99    };
100
101    info!("Starting interactive loop with sampling params: {sampling_params:?}");
102    println!(
103        "{}{TEXT_INTERACTIVE_HELP}{}",
104        "=".repeat(20),
105        "=".repeat(20)
106    );
107
108    // Set the handler to process exit
109    *CTRLC_HANDLER.lock().unwrap() = &exit_handler;
110
111    ctrlc::set_handler(move || CTRLC_HANDLER.lock().unwrap()())
112        .expect("Failed to set CTRL-C handler for interactive mode");
113
114    'outer: loop {
115        // Set the handler to process exit
116        *CTRLC_HANDLER.lock().unwrap() = &exit_handler;
117
118        let mut prompt = String::new();
119        print!("> ");
120        io::stdout().flush().unwrap();
121        io::stdin()
122            .read_line(&mut prompt)
123            .expect("Failed to get input");
124
125        match prompt.as_str().trim() {
126            "" => continue,
127            HELP_CMD => {
128                println!(
129                    "{}{TEXT_INTERACTIVE_HELP}{}",
130                    "=".repeat(20),
131                    "=".repeat(20)
132                );
133                continue;
134            }
135            EXIT_CMD => {
136                break;
137            }
138            prompt if prompt.trim().starts_with(SYSTEM_CMD) => {
139                let parsed = match &prompt.split(SYSTEM_CMD).collect::<Vec<_>>()[..] {
140                    &["", a] => a.trim(),
141                    _ => {
142                        println!("Error: Setting the system command should be done with this format: `{SYSTEM_CMD} This is a system message.`");
143                        continue;
144                    }
145                };
146                info!("Set system message to `{parsed}`.");
147                let mut user_message: IndexMap<String, MessageContent> = IndexMap::new();
148                user_message.insert("role".to_string(), Either::Left("system".to_string()));
149                user_message.insert("content".to_string(), Either::Left(parsed.to_string()));
150                messages.push(user_message);
151                continue;
152            }
153            message => {
154                let mut user_message: IndexMap<String, MessageContent> = IndexMap::new();
155                user_message.insert("role".to_string(), Either::Left("user".to_string()));
156                user_message.insert("content".to_string(), Either::Left(message.to_string()));
157                messages.push(user_message);
158            }
159        }
160
161        // Set the handler to terminate all seqs, so allowing cancelling running
162        *CTRLC_HANDLER.lock().unwrap() = &terminate_handler;
163
164        let request_messages = RequestMessage::Chat(messages.clone());
165
166        let (tx, mut rx) = channel(10_000);
167        let req = Request::Normal(NormalRequest {
168            id: mistralrs.next_request_id(),
169            messages: request_messages,
170            sampling_params: sampling_params.clone(),
171            response: tx,
172            return_logprobs: false,
173            is_streaming: true,
174            constraint: Constraint::None,
175            suffix: None,
176            tool_choice: None,
177            tools: None,
178            logits_processors: None,
179            return_raw_logits: false,
180            web_search_options: do_search.then(WebSearchOptions::default),
181        });
182        sender.send(req).await.unwrap();
183
184        let mut assistant_output = String::new();
185
186        let mut last_usage = None;
187        while let Some(resp) = rx.recv().await {
188            match resp {
189                Response::Chunk(chunk) => {
190                    last_usage = chunk.usage.clone();
191                    if let ChunkChoice {
192                        delta:
193                            Delta {
194                                content: Some(content),
195                                ..
196                            },
197                        finish_reason,
198                        ..
199                    } = &chunk.choices[0]
200                    {
201                        assistant_output.push_str(content);
202                        print!("{}", content);
203                        io::stdout().flush().unwrap();
204                        if finish_reason.is_some() {
205                            if matches!(finish_reason.as_ref().unwrap().as_str(), "length") {
206                                print!("...");
207                            }
208                            break;
209                        }
210                    }
211                }
212                Response::InternalError(e) => {
213                    error!("Got an internal error: {e:?}");
214                    break 'outer;
215                }
216                Response::ModelError(e, resp) => {
217                    error!("Got a model error: {e:?}, response: {resp:?}");
218                    break 'outer;
219                }
220                Response::ValidationError(e) => {
221                    error!("Got a validation error: {e:?}");
222                    break 'outer;
223                }
224                Response::Done(_) => unreachable!(),
225                Response::CompletionDone(_) => unreachable!(),
226                Response::CompletionModelError(_, _) => unreachable!(),
227                Response::CompletionChunk(_) => unreachable!(),
228                Response::ImageGeneration(_) => unreachable!(),
229                Response::Raw { .. } => unreachable!(),
230            }
231        }
232
233        if let Some(last_usage) = last_usage {
234            println!();
235            println!();
236            println!("Stats:");
237            println!(
238                "Prompt: {} tokens, {:.2} T/s",
239                last_usage.prompt_tokens, last_usage.avg_prompt_tok_per_sec
240            );
241            println!(
242                "Decode: {} tokens, {:.2} T/s",
243                last_usage.completion_tokens, last_usage.avg_compl_tok_per_sec
244            );
245        }
246        let mut assistant_message: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
247            IndexMap::new();
248        assistant_message.insert("role".to_string(), Either::Left("assistant".to_string()));
249        assistant_message.insert("content".to_string(), Either::Left(assistant_output));
250        messages.push(assistant_message);
251        println!();
252    }
253}
254
255fn parse_image_path_and_message(input: &str) -> Option<(String, String)> {
256    // Regex to capture the image path and the following message
257    let re = Regex::new(r#"\\image\s+"([^"]+)"\s*(.*)|\\image\s+(\S+)\s*(.*)"#).unwrap();
258
259    if let Some(captures) = re.captures(input) {
260        // Capture either the quoted or unquoted path and the message
261        if let Some(path) = captures.get(1) {
262            if let Some(message) = captures.get(2) {
263                return Some((
264                    path.as_str().trim().to_string(),
265                    message.as_str().trim().to_string(),
266                ));
267            }
268        } else if let Some(path) = captures.get(3) {
269            if let Some(message) = captures.get(4) {
270                return Some((
271                    path.as_str().trim().to_string(),
272                    message.as_str().trim().to_string(),
273                ));
274            }
275        }
276    }
277
278    None
279}
280
281async fn vision_interactive_mode(mistralrs: Arc<MistralRs>, do_search: bool) {
282    let sender = mistralrs.get_sender().unwrap();
283    let mut messages: Vec<IndexMap<String, MessageContent>> = Vec::new();
284    let mut images = Vec::new();
285
286    let prefixer = match &mistralrs.config().category {
287        ModelCategory::Text | ModelCategory::Diffusion => {
288            panic!("`add_image_message` expects a vision model.")
289        }
290        ModelCategory::Vision {
291            has_conv2d: _,
292            prefixer,
293        } => prefixer,
294    };
295
296    let sampling_params = SamplingParams {
297        temperature: Some(0.1),
298        top_k: Some(32),
299        top_p: Some(0.1),
300        min_p: Some(0.05),
301        top_n_logprobs: 0,
302        frequency_penalty: Some(0.1),
303        presence_penalty: Some(0.1),
304        max_len: Some(4096),
305        stop_toks: None,
306        logits_bias: None,
307        n_choices: 1,
308        dry_params: Some(DrySamplingParams::default()),
309    };
310
311    info!("Starting interactive loop with sampling params: {sampling_params:?}");
312    println!(
313        "{}{VISION_INTERACTIVE_HELP}{}",
314        "=".repeat(20),
315        "=".repeat(20)
316    );
317
318    // Set the handler to process exit
319    *CTRLC_HANDLER.lock().unwrap() = &exit_handler;
320
321    ctrlc::set_handler(move || CTRLC_HANDLER.lock().unwrap()())
322        .expect("Failed to set CTRL-C handler for interactive mode");
323
324    'outer: loop {
325        // Set the handler to process exit
326        *CTRLC_HANDLER.lock().unwrap() = &exit_handler;
327
328        let mut prompt = String::new();
329        print!("> ");
330        io::stdout().flush().unwrap();
331        io::stdin()
332            .read_line(&mut prompt)
333            .expect("Failed to get input");
334
335        match prompt.as_str().trim() {
336            "" => continue,
337            HELP_CMD => {
338                println!(
339                    "{}{VISION_INTERACTIVE_HELP}{}",
340                    "=".repeat(20),
341                    "=".repeat(20)
342                );
343                continue;
344            }
345            EXIT_CMD => {
346                break;
347            }
348            prompt if prompt.trim().starts_with(SYSTEM_CMD) => {
349                let parsed = match &prompt.split(SYSTEM_CMD).collect::<Vec<_>>()[..] {
350                    &["", a] => a.trim(),
351                    _ => {
352                        println!("Error: Setting the system command should be done with this format: `{SYSTEM_CMD} This is a system message.`");
353                        continue;
354                    }
355                };
356                info!("Set system message to `{parsed}`.");
357                let mut user_message: IndexMap<String, MessageContent> = IndexMap::new();
358                user_message.insert("role".to_string(), Either::Left("system".to_string()));
359                user_message.insert("content".to_string(), Either::Left(parsed.to_string()));
360                messages.push(user_message);
361                continue;
362            }
363            prompt if prompt.trim().starts_with(IMAGE_CMD) => {
364                let Some((url, message)) = parse_image_path_and_message(prompt.trim()) else {
365                    println!("Error: Adding an image message should be done with this format: `{IMAGE_CMD} path/to/image.jpg Describe what is in this image.`");
366                    continue;
367                };
368                let message = prefixer.prefix_image(images.len(), &message);
369
370                let image = util::parse_image_url(&url)
371                    .await
372                    .expect("Failed to read image from URL/path");
373                images.push(image);
374
375                let mut user_message: IndexMap<String, MessageContent> = IndexMap::new();
376                user_message.insert("role".to_string(), Either::Left("user".to_string()));
377                user_message.insert(
378                    "content".to_string(),
379                    Either::Right(vec![
380                        IndexMap::from([("type".to_string(), Value::String("image".to_string()))]),
381                        IndexMap::from([
382                            ("type".to_string(), Value::String("text".to_string())),
383                            ("text".to_string(), Value::String(message)),
384                        ]),
385                    ]),
386                );
387                messages.push(user_message);
388            }
389            message => {
390                let mut user_message: IndexMap<String, MessageContent> = IndexMap::new();
391                user_message.insert("role".to_string(), Either::Left("user".to_string()));
392                user_message.insert("content".to_string(), Either::Left(message.to_string()));
393                messages.push(user_message);
394            }
395        };
396
397        // Set the handler to terminate all seqs, so allowing cancelling running
398        *CTRLC_HANDLER.lock().unwrap() = &terminate_handler;
399
400        let request_messages = RequestMessage::VisionChat {
401            images: images.clone(),
402            messages: messages.clone(),
403        };
404
405        let (tx, mut rx) = channel(10_000);
406        let req = Request::Normal(NormalRequest {
407            id: mistralrs.next_request_id(),
408            messages: request_messages,
409            sampling_params: sampling_params.clone(),
410            response: tx,
411            return_logprobs: false,
412            is_streaming: true,
413            constraint: Constraint::None,
414            suffix: None,
415            tool_choice: None,
416            tools: None,
417            logits_processors: None,
418            return_raw_logits: false,
419            web_search_options: do_search.then(WebSearchOptions::default),
420        });
421        sender.send(req).await.unwrap();
422
423        let mut assistant_output = String::new();
424
425        let mut last_usage = None;
426        while let Some(resp) = rx.recv().await {
427            match resp {
428                Response::Chunk(chunk) => {
429                    last_usage = chunk.usage.clone();
430                    if let ChunkChoice {
431                        delta:
432                            Delta {
433                                content: Some(content),
434                                ..
435                            },
436                        finish_reason,
437                        ..
438                    } = &chunk.choices[0]
439                    {
440                        assistant_output.push_str(content);
441                        print!("{}", content);
442                        io::stdout().flush().unwrap();
443                        if finish_reason.is_some() {
444                            if matches!(finish_reason.as_ref().unwrap().as_str(), "length") {
445                                print!("...");
446                            }
447                            break;
448                        }
449                    }
450                }
451                Response::InternalError(e) => {
452                    error!("Got an internal error: {e:?}");
453                    break 'outer;
454                }
455                Response::ModelError(e, resp) => {
456                    error!("Got a model error: {e:?}, response: {resp:?}");
457                    break 'outer;
458                }
459                Response::ValidationError(e) => {
460                    error!("Got a validation error: {e:?}");
461                    break 'outer;
462                }
463                Response::Done(_) => unreachable!(),
464                Response::CompletionDone(_) => unreachable!(),
465                Response::CompletionModelError(_, _) => unreachable!(),
466                Response::CompletionChunk(_) => unreachable!(),
467                Response::ImageGeneration(_) => unreachable!(),
468                Response::Raw { .. } => unreachable!(),
469            }
470        }
471
472        if let Some(last_usage) = last_usage {
473            println!();
474            println!();
475            println!("Stats:");
476            println!(
477                "Prompt: {} tokens, {:.2} T/s",
478                last_usage.prompt_tokens, last_usage.avg_prompt_tok_per_sec
479            );
480            println!(
481                "Decode: {} tokens, {:.2} T/s",
482                last_usage.completion_tokens, last_usage.avg_compl_tok_per_sec
483            );
484        }
485        let mut assistant_message: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
486            IndexMap::new();
487        assistant_message.insert("role".to_string(), Either::Left("assistant".to_string()));
488        assistant_message.insert("content".to_string(), Either::Left(assistant_output));
489        messages.push(assistant_message);
490        println!();
491    }
492}
493
494async fn diffusion_interactive_mode(mistralrs: Arc<MistralRs>, do_search: bool) {
495    let sender = mistralrs.get_sender().unwrap();
496
497    let diffusion_params = DiffusionGenerationParams::default();
498
499    info!("Starting interactive loop with generation params: {diffusion_params:?}");
500    println!(
501        "{}{DIFFUSION_INTERACTIVE_HELP}{}",
502        "=".repeat(20),
503        "=".repeat(20)
504    );
505
506    // Set the handler to process exit
507    *CTRLC_HANDLER.lock().unwrap() = &exit_handler;
508
509    ctrlc::set_handler(move || CTRLC_HANDLER.lock().unwrap()())
510        .expect("Failed to set CTRL-C handler for interactive mode");
511
512    loop {
513        // Set the handler to process exit
514        *CTRLC_HANDLER.lock().unwrap() = &exit_handler;
515
516        let mut prompt = String::new();
517        print!("> ");
518        io::stdout().flush().unwrap();
519        io::stdin()
520            .read_line(&mut prompt)
521            .expect("Failed to get input");
522
523        let prompt = match prompt.as_str().trim() {
524            "" => continue,
525            HELP_CMD => {
526                println!(
527                    "{}{DIFFUSION_INTERACTIVE_HELP}{}",
528                    "=".repeat(20),
529                    "=".repeat(20)
530                );
531                continue;
532            }
533            EXIT_CMD => {
534                break;
535            }
536            prompt => prompt.to_string(),
537        };
538
539        // Set the handler to terminate all seqs, so allowing cancelling running
540        *CTRLC_HANDLER.lock().unwrap() = &terminate_handler;
541
542        let (tx, mut rx) = channel(10_000);
543        let req = Request::Normal(NormalRequest {
544            id: 0,
545            messages: RequestMessage::ImageGeneration {
546                prompt: prompt.to_string(),
547                format: ImageGenerationResponseFormat::Url,
548                generation_params: diffusion_params.clone(),
549            },
550            sampling_params: SamplingParams::deterministic(),
551            response: tx,
552            return_logprobs: false,
553            is_streaming: false,
554            suffix: None,
555            constraint: Constraint::None,
556            tool_choice: None,
557            tools: None,
558            logits_processors: None,
559            return_raw_logits: false,
560            web_search_options: do_search.then(WebSearchOptions::default),
561        });
562
563        let start = Instant::now();
564        sender.send(req).await.unwrap();
565
566        let ResponseOk::ImageGeneration(response) = rx.recv().await.unwrap().as_result().unwrap()
567        else {
568            panic!("Got unexpected response type.")
569        };
570        let end = Instant::now();
571
572        let duration = end.duration_since(start).as_secs_f32();
573        let pixels_per_s = (diffusion_params.height * diffusion_params.width) as f32 / duration;
574
575        println!(
576            "Image generated can be found at: image is at `{}`. Took {duration:.2}s ({pixels_per_s:.2} pixels/s).",
577            response.data[0].url.as_ref().unwrap(),
578        );
579
580        println!();
581    }
582}
583
#[cfg(test)]
mod tests {
    // Unit tests for `\image` command parsing. They cover quoted and unquoted
    // paths, URLs (advertised by the vision help text), optional messages, and
    // inputs that must be rejected.
    use super::parse_image_path_and_message;

    #[test]
    fn test_parse_image_with_unquoted_path_and_message() {
        let input = r#"\image image.jpg What is this"#;
        let result = parse_image_path_and_message(input);
        assert_eq!(
            result,
            Some(("image.jpg".to_string(), "What is this".to_string()))
        );
    }

    #[test]
    fn test_parse_image_with_quoted_path_and_message() {
        let input = r#"\image "image name.jpg" What is this?"#;
        let result = parse_image_path_and_message(input);
        assert_eq!(
            result,
            Some(("image name.jpg".to_string(), "What is this?".to_string()))
        );
    }

    #[test]
    fn test_parse_image_with_only_unquoted_path() {
        let input = r#"\image image.jpg"#;
        let result = parse_image_path_and_message(input);
        assert_eq!(result, Some(("image.jpg".to_string(), "".to_string())));
    }

    #[test]
    fn test_parse_image_with_only_quoted_path() {
        let input = r#"\image "image name.jpg""#;
        let result = parse_image_path_and_message(input);
        assert_eq!(result, Some(("image name.jpg".to_string(), "".to_string())));
    }

    #[test]
    fn test_parse_image_with_extra_spaces() {
        let input = r#"\image    "image with spaces.jpg"    This is a test message with spaces  "#;
        let result = parse_image_path_and_message(input);
        assert_eq!(
            result,
            Some((
                "image with spaces.jpg".to_string(),
                "This is a test message with spaces".to_string()
            ))
        );
    }

    #[test]
    fn test_parse_image_with_no_message() {
        let input = r#"\image "image.jpg""#;
        let result = parse_image_path_and_message(input);
        assert_eq!(result, Some(("image.jpg".to_string(), "".to_string())));
    }

    #[test]
    fn test_parse_image_missing_path() {
        let input = r#"\image"#;
        let result = parse_image_path_and_message(input);
        assert_eq!(result, None);
    }

    #[test]
    fn test_parse_image_invalid_command() {
        let input = r#"\img "image.jpg" This should fail"#;
        let result = parse_image_path_and_message(input);
        assert_eq!(result, None);
    }

    #[test]
    fn test_parse_image_with_non_image_text() {
        let input = r#"Some random text without command"#;
        let result = parse_image_path_and_message(input);
        assert_eq!(result, None);
    }

    #[test]
    fn test_parse_image_with_path_and_message_special_chars() {
        let input = r#"\image "path with special chars @#$%^&*().jpg" This is a message with special chars !@#$%^&*()"#;
        let result = parse_image_path_and_message(input);
        assert_eq!(
            result,
            Some((
                "path with special chars @#$%^&*().jpg".to_string(),
                "This is a message with special chars !@#$%^&*()".to_string()
            ))
        );
    }

    // The help text advertises "image URL or local path"; URLs contain no
    // whitespace, so they take the unquoted-path branch.
    #[test]
    fn test_parse_image_with_url_and_message() {
        let input = r#"\image https://example.com/cat.png What animal is this?"#;
        let result = parse_image_path_and_message(input);
        assert_eq!(
            result,
            Some((
                "https://example.com/cat.png".to_string(),
                "What animal is this?".to_string()
            ))
        );
    }

    // The command must be separated from the path by whitespace.
    #[test]
    fn test_parse_image_requires_whitespace_after_command() {
        let input = r#"\imagecat.png Describe this"#;
        let result = parse_image_path_and_message(input);
        assert_eq!(result, None);
    }

    // Quotes later in the line belong to the message, not the path.
    #[test]
    fn test_parse_image_unquoted_path_with_quoted_message() {
        let input = r#"\image cat.png Say "hi" to the cat"#;
        let result = parse_image_path_and_message(input);
        assert_eq!(
            result,
            Some(("cat.png".to_string(), r#"Say "hi" to the cat"#.to_string()))
        );
    }
}