1use either::Either;
2use indexmap::IndexMap;
3use mistralrs_core::{
4 ChunkChoice, Constraint, Delta, DiffusionGenerationParams, DrySamplingParams,
5 ImageGenerationResponseFormat, MessageContent, MistralRs, ModelCategory, NormalRequest,
6 Request, RequestMessage, Response, ResponseOk, SamplingParams, WebSearchOptions,
7 TERMINATE_ALL_NEXT_STEP,
8};
9use once_cell::sync::Lazy;
10use regex::Regex;
11use serde_json::Value;
12use std::{
13 io::{self, Write},
14 sync::{atomic::Ordering, Arc, Mutex},
15 time::Instant,
16};
17use tokio::sync::mpsc::channel;
18use tracing::{error, info};
19
20use crate::util;
21
/// Ctrl-C behavior while the REPL is idle: terminate the whole process
/// immediately with a success (zero) exit status.
fn exit_handler() {
    const SUCCESS_STATUS: i32 = 0;
    std::process::exit(SUCCESS_STATUS)
}
25
26fn terminate_handler() {
27 TERMINATE_ALL_NEXT_STEP.store(true, Ordering::SeqCst);
28}
29
30static CTRLC_HANDLER: Lazy<Mutex<&'static (dyn Fn() + Sync)>> =
31 Lazy::new(|| Mutex::new(&exit_handler));
32
33pub async fn interactive_mode(mistralrs: Arc<MistralRs>, throughput: bool, do_search: bool) {
34 match mistralrs.get_model_category() {
35 ModelCategory::Text => text_interactive_mode(mistralrs, throughput, do_search).await,
36 ModelCategory::Vision { .. } => {
37 vision_interactive_mode(mistralrs, throughput, do_search).await
38 }
39 ModelCategory::Diffusion => diffusion_interactive_mode(mistralrs, do_search).await,
40 }
41}
42
/// Help text for text models, printed (framed by `=` rules) when the text REPL
/// starts and whenever the user enters `\help`.
const TEXT_INTERACTIVE_HELP: &str = r#"
Welcome to interactive mode! Because this model is a text model, you can enter prompts and chat with the model.

Commands:
- `\help`: Display this message.
- `\exit`: Quit interactive mode.
- `\system <system message here>`:
 Add a system message to the chat without running the model.
 Ex: `\system Always respond as a pirate.`
"#;
53
/// Help text for vision models, printed (framed by `=` rules) when the vision REPL
/// starts and whenever the user enters `\help`. Unlike the text variant it also
/// documents the `\image` command.
const VISION_INTERACTIVE_HELP: &str = r#"
Welcome to interactive mode! Because this model is a vision model, you can enter prompts and chat with the model.

To specify a message with an image, use the `\image` command detailed below.

Commands:
- `\help`: Display this message.
- `\exit`: Quit interactive mode.
- `\system <system message here>`:
 Add a system message to the chat without running the model.
 Ex: `\system Always respond as a pirate.`
- `\image <image URL or local path here> <message here>`:
 Add a message paired with an image. The image will be fed to the model as if it were the first item in this prompt.
 You do not need to modify your prompt for specific models.
 Ex: `\image path/to/image.jpg Describe what is in this image.`
"#;
70
/// Help text for diffusion (image-generation) models, printed (framed by `=`
/// rules) when the diffusion REPL starts and whenever the user enters `\help`.
/// Only `\help` and `\exit` are recognized in this mode.
const DIFFUSION_INTERACTIVE_HELP: &str = r#"
Welcome to interactive mode! Because this model is a diffusion model, you can enter prompts and the model will generate an image.

Commands:
- `\help`: Display this message.
- `\exit`: Quit interactive mode.
"#;
78
// Commands recognized at the `> ` prompt. `\help` and `\exit` must be the whole
// (trimmed) input line; `\system` and `\image` are recognized as prefixes.

/// Re-prints the help text for the current mode.
const HELP_CMD: &str = r"\help";
/// Leaves interactive mode.
const EXIT_CMD: &str = r"\exit";
/// Adds a system message without running the model.
const SYSTEM_CMD: &str = r"\system";
/// Adds a user message paired with an image (vision models only).
const IMAGE_CMD: &str = r"\image";
83
84async fn text_interactive_mode(mistralrs: Arc<MistralRs>, throughput: bool, do_search: bool) {
85 let sender = mistralrs.get_sender().unwrap();
86 let mut messages: Vec<IndexMap<String, MessageContent>> = Vec::new();
87
88 let sampling_params = SamplingParams {
89 temperature: Some(0.1),
90 top_k: Some(32),
91 top_p: Some(0.1),
92 min_p: Some(0.05),
93 top_n_logprobs: 0,
94 frequency_penalty: Some(0.1),
95 presence_penalty: Some(0.1),
96 max_len: Some(4096),
97 stop_toks: None,
98 logits_bias: None,
99 n_choices: 1,
100 dry_params: Some(DrySamplingParams::default()),
101 };
102
103 info!("Starting interactive loop with sampling params: {sampling_params:?}");
104 println!(
105 "{}{TEXT_INTERACTIVE_HELP}{}",
106 "=".repeat(20),
107 "=".repeat(20)
108 );
109
110 *CTRLC_HANDLER.lock().unwrap() = &exit_handler;
112
113 ctrlc::set_handler(move || CTRLC_HANDLER.lock().unwrap()())
114 .expect("Failed to set CTRL-C handler for interactive mode");
115
116 'outer: loop {
117 *CTRLC_HANDLER.lock().unwrap() = &exit_handler;
119
120 let mut prompt = String::new();
121 print!("> ");
122 io::stdout().flush().unwrap();
123 io::stdin()
124 .read_line(&mut prompt)
125 .expect("Failed to get input");
126
127 match prompt.as_str().trim() {
128 "" => continue,
129 HELP_CMD => {
130 println!(
131 "{}{TEXT_INTERACTIVE_HELP}{}",
132 "=".repeat(20),
133 "=".repeat(20)
134 );
135 continue;
136 }
137 EXIT_CMD => {
138 break;
139 }
140 prompt if prompt.trim().starts_with(SYSTEM_CMD) => {
141 let parsed = match &prompt.split(SYSTEM_CMD).collect::<Vec<_>>()[..] {
142 &["", a] => a.trim(),
143 _ => {
144 println!("Error: Setting the system command should be done with this format: `{SYSTEM_CMD} This is a system message.`");
145 continue;
146 }
147 };
148 info!("Set system message to `{parsed}`.");
149 let mut user_message: IndexMap<String, MessageContent> = IndexMap::new();
150 user_message.insert("role".to_string(), Either::Left("system".to_string()));
151 user_message.insert("content".to_string(), Either::Left(parsed.to_string()));
152 messages.push(user_message);
153 continue;
154 }
155 message => {
156 let mut user_message: IndexMap<String, MessageContent> = IndexMap::new();
157 user_message.insert("role".to_string(), Either::Left("user".to_string()));
158 user_message.insert("content".to_string(), Either::Left(message.to_string()));
159 messages.push(user_message);
160 }
161 }
162
163 *CTRLC_HANDLER.lock().unwrap() = &terminate_handler;
165
166 let request_messages = RequestMessage::Chat(messages.clone());
167
168 let (tx, mut rx) = channel(10_000);
169 let req = Request::Normal(NormalRequest {
170 id: mistralrs.next_request_id(),
171 messages: request_messages,
172 sampling_params: sampling_params.clone(),
173 response: tx,
174 return_logprobs: false,
175 is_streaming: true,
176 constraint: Constraint::None,
177 suffix: None,
178 adapters: None,
179 tool_choice: None,
180 tools: None,
181 logits_processors: None,
182 return_raw_logits: false,
183 web_search_options: do_search.then(WebSearchOptions::default),
184 });
185 sender.send(req).await.unwrap();
186
187 let mut assistant_output = String::new();
188
189 let start = Instant::now();
190 let mut toks = 0;
191 while let Some(resp) = rx.recv().await {
192 match resp {
193 Response::Chunk(chunk) => {
194 if let ChunkChoice {
195 delta:
196 Delta {
197 content: Some(content),
198 ..
199 },
200 finish_reason,
201 ..
202 } = &chunk.choices[0]
203 {
204 assistant_output.push_str(content);
205 print!("{}", content);
206 toks += 1usize; io::stdout().flush().unwrap();
208 if finish_reason.is_some() {
209 if matches!(finish_reason.as_ref().unwrap().as_str(), "length") {
210 print!("...");
211 }
212 break;
213 }
214 }
215 }
216 Response::InternalError(e) => {
217 error!("Got an internal error: {e:?}");
218 break 'outer;
219 }
220 Response::ModelError(e, resp) => {
221 error!("Got a model error: {e:?}, response: {resp:?}");
222 break 'outer;
223 }
224 Response::ValidationError(e) => {
225 error!("Got a validation error: {e:?}");
226 break 'outer;
227 }
228 Response::Done(_) => unreachable!(),
229 Response::CompletionDone(_) => unreachable!(),
230 Response::CompletionModelError(_, _) => unreachable!(),
231 Response::CompletionChunk(_) => unreachable!(),
232 Response::ImageGeneration(_) => unreachable!(),
233 Response::Raw { .. } => unreachable!(),
234 }
235 }
236 if throughput {
237 let time = Instant::now().duration_since(start).as_secs_f64();
238 println!();
239 info!("Average T/s: {}", toks as f64 / time);
240 }
241 let mut assistant_message: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
242 IndexMap::new();
243 assistant_message.insert("role".to_string(), Either::Left("assistant".to_string()));
244 assistant_message.insert("content".to_string(), Either::Left(assistant_output));
245 messages.push(assistant_message);
246 println!();
247 }
248}
249
250fn parse_image_path_and_message(input: &str) -> Option<(String, String)> {
251 let re = Regex::new(r#"\\image\s+"([^"]+)"\s*(.*)|\\image\s+(\S+)\s*(.*)"#).unwrap();
253
254 if let Some(captures) = re.captures(input) {
255 if let Some(path) = captures.get(1) {
257 if let Some(message) = captures.get(2) {
258 return Some((
259 path.as_str().trim().to_string(),
260 message.as_str().trim().to_string(),
261 ));
262 }
263 } else if let Some(path) = captures.get(3) {
264 if let Some(message) = captures.get(4) {
265 return Some((
266 path.as_str().trim().to_string(),
267 message.as_str().trim().to_string(),
268 ));
269 }
270 }
271 }
272
273 None
274}
275
276async fn vision_interactive_mode(mistralrs: Arc<MistralRs>, throughput: bool, do_search: bool) {
277 let sender = mistralrs.get_sender().unwrap();
278 let mut messages: Vec<IndexMap<String, MessageContent>> = Vec::new();
279 let mut images = Vec::new();
280
281 let prefixer = match &mistralrs.config().category {
282 ModelCategory::Text | ModelCategory::Diffusion => {
283 panic!("`add_image_message` expects a vision model.")
284 }
285 ModelCategory::Vision {
286 has_conv2d: _,
287 prefixer,
288 } => prefixer,
289 };
290
291 let sampling_params = SamplingParams {
292 temperature: Some(0.1),
293 top_k: Some(32),
294 top_p: Some(0.1),
295 min_p: Some(0.05),
296 top_n_logprobs: 0,
297 frequency_penalty: Some(0.1),
298 presence_penalty: Some(0.1),
299 max_len: Some(4096),
300 stop_toks: None,
301 logits_bias: None,
302 n_choices: 1,
303 dry_params: Some(DrySamplingParams::default()),
304 };
305
306 info!("Starting interactive loop with sampling params: {sampling_params:?}");
307 println!(
308 "{}{VISION_INTERACTIVE_HELP}{}",
309 "=".repeat(20),
310 "=".repeat(20)
311 );
312
313 *CTRLC_HANDLER.lock().unwrap() = &exit_handler;
315
316 ctrlc::set_handler(move || CTRLC_HANDLER.lock().unwrap()())
317 .expect("Failed to set CTRL-C handler for interactive mode");
318
319 'outer: loop {
320 *CTRLC_HANDLER.lock().unwrap() = &exit_handler;
322
323 let mut prompt = String::new();
324 print!("> ");
325 io::stdout().flush().unwrap();
326 io::stdin()
327 .read_line(&mut prompt)
328 .expect("Failed to get input");
329
330 match prompt.as_str().trim() {
331 "" => continue,
332 HELP_CMD => {
333 println!(
334 "{}{VISION_INTERACTIVE_HELP}{}",
335 "=".repeat(20),
336 "=".repeat(20)
337 );
338 continue;
339 }
340 EXIT_CMD => {
341 break;
342 }
343 prompt if prompt.trim().starts_with(SYSTEM_CMD) => {
344 let parsed = match &prompt.split(SYSTEM_CMD).collect::<Vec<_>>()[..] {
345 &["", a] => a.trim(),
346 _ => {
347 println!("Error: Setting the system command should be done with this format: `{SYSTEM_CMD} This is a system message.`");
348 continue;
349 }
350 };
351 info!("Set system message to `{parsed}`.");
352 let mut user_message: IndexMap<String, MessageContent> = IndexMap::new();
353 user_message.insert("role".to_string(), Either::Left("system".to_string()));
354 user_message.insert("content".to_string(), Either::Left(parsed.to_string()));
355 messages.push(user_message);
356 continue;
357 }
358 prompt if prompt.trim().starts_with(IMAGE_CMD) => {
359 let Some((url, message)) = parse_image_path_and_message(prompt.trim()) else {
360 println!("Error: Adding an image message should be done with this format: `{IMAGE_CMD} path/to/image.jpg Describe what is in this image.`");
361 continue;
362 };
363 let message = prefixer.prefix_image(images.len(), &message);
364
365 let image = util::parse_image_url(&url)
366 .await
367 .expect("Failed to read image from URL/path");
368 images.push(image);
369
370 let mut user_message: IndexMap<String, MessageContent> = IndexMap::new();
371 user_message.insert("role".to_string(), Either::Left("user".to_string()));
372 user_message.insert(
373 "content".to_string(),
374 Either::Right(vec![
375 IndexMap::from([("type".to_string(), Value::String("image".to_string()))]),
376 IndexMap::from([
377 ("type".to_string(), Value::String("text".to_string())),
378 ("text".to_string(), Value::String(message)),
379 ]),
380 ]),
381 );
382 messages.push(user_message);
383 }
384 message => {
385 let mut user_message: IndexMap<String, MessageContent> = IndexMap::new();
386 user_message.insert("role".to_string(), Either::Left("user".to_string()));
387 user_message.insert("content".to_string(), Either::Left(message.to_string()));
388 messages.push(user_message);
389 }
390 };
391
392 *CTRLC_HANDLER.lock().unwrap() = &terminate_handler;
394
395 let request_messages = RequestMessage::VisionChat {
396 images: images.clone(),
397 messages: messages.clone(),
398 };
399
400 let (tx, mut rx) = channel(10_000);
401 let req = Request::Normal(NormalRequest {
402 id: mistralrs.next_request_id(),
403 messages: request_messages,
404 sampling_params: sampling_params.clone(),
405 response: tx,
406 return_logprobs: false,
407 is_streaming: true,
408 constraint: Constraint::None,
409 suffix: None,
410 adapters: None,
411 tool_choice: None,
412 tools: None,
413 logits_processors: None,
414 return_raw_logits: false,
415 web_search_options: do_search.then(WebSearchOptions::default),
416 });
417 sender.send(req).await.unwrap();
418
419 let mut assistant_output = String::new();
420
421 let start = Instant::now();
422 let mut toks = 0;
423 while let Some(resp) = rx.recv().await {
424 match resp {
425 Response::Chunk(chunk) => {
426 if let ChunkChoice {
427 delta:
428 Delta {
429 content: Some(content),
430 ..
431 },
432 finish_reason,
433 ..
434 } = &chunk.choices[0]
435 {
436 assistant_output.push_str(content);
437 print!("{}", content);
438 toks += 1usize; io::stdout().flush().unwrap();
440 if finish_reason.is_some() {
441 if matches!(finish_reason.as_ref().unwrap().as_str(), "length") {
442 print!("...");
443 }
444 break;
445 }
446 }
447 }
448 Response::InternalError(e) => {
449 error!("Got an internal error: {e:?}");
450 break 'outer;
451 }
452 Response::ModelError(e, resp) => {
453 error!("Got a model error: {e:?}, response: {resp:?}");
454 break 'outer;
455 }
456 Response::ValidationError(e) => {
457 error!("Got a validation error: {e:?}");
458 break 'outer;
459 }
460 Response::Done(_) => unreachable!(),
461 Response::CompletionDone(_) => unreachable!(),
462 Response::CompletionModelError(_, _) => unreachable!(),
463 Response::CompletionChunk(_) => unreachable!(),
464 Response::ImageGeneration(_) => unreachable!(),
465 Response::Raw { .. } => unreachable!(),
466 }
467 }
468 if throughput {
469 let time = Instant::now().duration_since(start).as_secs_f64();
470 println!();
471 info!("Average T/s: {}", toks as f64 / time);
472 }
473 let mut assistant_message: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
474 IndexMap::new();
475 assistant_message.insert("role".to_string(), Either::Left("assistant".to_string()));
476 assistant_message.insert("content".to_string(), Either::Left(assistant_output));
477 messages.push(assistant_message);
478 println!();
479 }
480}
481
482async fn diffusion_interactive_mode(mistralrs: Arc<MistralRs>, do_search: bool) {
483 let sender = mistralrs.get_sender().unwrap();
484
485 let diffusion_params = DiffusionGenerationParams::default();
486
487 info!("Starting interactive loop with generation params: {diffusion_params:?}");
488 println!(
489 "{}{DIFFUSION_INTERACTIVE_HELP}{}",
490 "=".repeat(20),
491 "=".repeat(20)
492 );
493
494 *CTRLC_HANDLER.lock().unwrap() = &exit_handler;
496
497 ctrlc::set_handler(move || CTRLC_HANDLER.lock().unwrap()())
498 .expect("Failed to set CTRL-C handler for interactive mode");
499
500 loop {
501 *CTRLC_HANDLER.lock().unwrap() = &exit_handler;
503
504 let mut prompt = String::new();
505 print!("> ");
506 io::stdout().flush().unwrap();
507 io::stdin()
508 .read_line(&mut prompt)
509 .expect("Failed to get input");
510
511 let prompt = match prompt.as_str().trim() {
512 "" => continue,
513 HELP_CMD => {
514 println!(
515 "{}{DIFFUSION_INTERACTIVE_HELP}{}",
516 "=".repeat(20),
517 "=".repeat(20)
518 );
519 continue;
520 }
521 EXIT_CMD => {
522 break;
523 }
524 prompt => prompt.to_string(),
525 };
526
527 *CTRLC_HANDLER.lock().unwrap() = &terminate_handler;
529
530 let (tx, mut rx) = channel(10_000);
531 let req = Request::Normal(NormalRequest {
532 id: 0,
533 messages: RequestMessage::ImageGeneration {
534 prompt: prompt.to_string(),
535 format: ImageGenerationResponseFormat::Url,
536 generation_params: diffusion_params.clone(),
537 },
538 sampling_params: SamplingParams::deterministic(),
539 response: tx,
540 return_logprobs: false,
541 is_streaming: false,
542 suffix: None,
543 constraint: Constraint::None,
544 adapters: None,
545 tool_choice: None,
546 tools: None,
547 logits_processors: None,
548 return_raw_logits: false,
549 web_search_options: do_search.then(WebSearchOptions::default),
550 });
551
552 let start = Instant::now();
553 sender.send(req).await.unwrap();
554
555 let ResponseOk::ImageGeneration(response) = rx.recv().await.unwrap().as_result().unwrap()
556 else {
557 panic!("Got unexpected response type.")
558 };
559 let end = Instant::now();
560
561 let duration = end.duration_since(start).as_secs_f32();
562 let pixels_per_s = (diffusion_params.height * diffusion_params.width) as f32 / duration;
563
564 println!(
565 "Image generated can be found at: image is at `{}`. Took {duration:.2}s ({pixels_per_s:.2} pixels/s).",
566 response.data[0].url.as_ref().unwrap(),
567 );
568
569 println!();
570 }
571}
572
#[cfg(test)]
mod tests {
    use super::parse_image_path_and_message as parse;

    /// Shorthand for the successful `(path, message)` value the parser returns.
    fn parsed(path: &str, message: &str) -> Option<(String, String)> {
        Some((String::from(path), String::from(message)))
    }

    // ---- inputs that must be accepted ----

    #[test]
    fn test_parse_image_with_unquoted_path_and_message() {
        assert_eq!(
            parse(r#"\image image.jpg What is this"#),
            parsed("image.jpg", "What is this")
        );
    }

    #[test]
    fn test_parse_image_with_quoted_path_and_message() {
        assert_eq!(
            parse(r#"\image "image name.jpg" What is this?"#),
            parsed("image name.jpg", "What is this?")
        );
    }

    #[test]
    fn test_parse_image_with_only_unquoted_path() {
        assert_eq!(parse(r#"\image image.jpg"#), parsed("image.jpg", ""));
    }

    #[test]
    fn test_parse_image_with_only_quoted_path() {
        assert_eq!(parse(r#"\image "image name.jpg""#), parsed("image name.jpg", ""));
    }

    #[test]
    fn test_parse_image_with_extra_spaces() {
        let line = r#"\image "image with spaces.jpg" This is a test message with spaces "#;
        assert_eq!(
            parse(line),
            parsed("image with spaces.jpg", "This is a test message with spaces")
        );
    }

    #[test]
    fn test_parse_image_with_no_message() {
        assert_eq!(parse(r#"\image "image.jpg""#), parsed("image.jpg", ""));
    }

    // ---- inputs that must be rejected ----

    #[test]
    fn test_parse_image_missing_path() {
        assert!(parse(r#"\image"#).is_none());
    }

    #[test]
    fn test_parse_image_invalid_command() {
        assert!(parse(r#"\img "image.jpg" This should fail"#).is_none());
    }

    #[test]
    fn test_parse_image_with_non_image_text() {
        assert!(parse("Some random text without command").is_none());
    }

    // ---- punctuation-heavy paths and messages ----

    #[test]
    fn test_parse_image_with_path_and_message_special_chars() {
        let line = r#"\image "path with special chars @#$%^&*().jpg" This is a message with special chars !@#$%^&*()"#;
        assert_eq!(
            parse(line),
            parsed(
                "path with special chars @#$%^&*().jpg",
                "This is a message with special chars !@#$%^&*()"
            )
        );
    }
}