1use either::Either;
2use indexmap::IndexMap;
3use mistralrs_core::{
4 ChunkChoice, Constraint, Delta, DiffusionGenerationParams, DrySamplingParams,
5 ImageGenerationResponseFormat, MessageContent, MistralRs, ModelCategory, NormalRequest,
6 Request, RequestMessage, Response, ResponseOk, SamplingParams, WebSearchOptions,
7 TERMINATE_ALL_NEXT_STEP,
8};
9use once_cell::sync::Lazy;
10use regex::Regex;
11use serde_json::Value;
12use std::{
13 io::{self, Write},
14 sync::{atomic::Ordering, Arc, Mutex},
15 time::Instant,
16};
17use tokio::sync::mpsc::channel;
18use tracing::{error, info};
19
20use crate::util;
21
/// Ctrl-C action that ends the whole process immediately with exit status 0.
///
/// The interactive loops install this while they are waiting for user input.
fn exit_handler() {
    std::process::exit(0)
}
25
26fn terminate_handler() {
27 TERMINATE_ALL_NEXT_STEP.store(true, Ordering::SeqCst);
28}
29
/// The function currently run when the process receives Ctrl-C.
///
/// Each interactive loop registers one `ctrlc` handler that locks this mutex and calls whatever
/// function is stored in it. The loops then swap the stored function between [`exit_handler`]
/// (while waiting for input) and [`terminate_handler`] (while a request is running).
/// It starts out as [`exit_handler`].
static CTRLC_HANDLER: Lazy<Mutex<&'static (dyn Fn() + Sync)>> =
    Lazy::new(|| Mutex::new(&exit_handler));
32
33pub async fn interactive_mode(mistralrs: Arc<MistralRs>, do_search: bool) {
34 match mistralrs.get_model_category() {
35 ModelCategory::Text => text_interactive_mode(mistralrs, do_search).await,
36 ModelCategory::Vision { .. } => vision_interactive_mode(mistralrs, do_search).await,
37 ModelCategory::Diffusion => diffusion_interactive_mode(mistralrs, do_search).await,
38 }
39}
40
/// Banner and command reference printed when the text-model loop starts and on `\help`.
const TEXT_INTERACTIVE_HELP: &str = r#"
Welcome to interactive mode! Because this model is a text model, you can enter prompts and chat with the model.

Commands:
- `\help`: Display this message.
- `\exit`: Quit interactive mode.
- `\system <system message here>`:
 Add a system message to the chat without running the model.
 Ex: `\system Always respond as a pirate.`
"#;
51
/// Banner and command reference printed when the vision-model loop starts and on `\help`.
const VISION_INTERACTIVE_HELP: &str = r#"
Welcome to interactive mode! Because this model is a vision model, you can enter prompts and chat with the model.

To specify a message with an image, use the `\image` command detailed below.

Commands:
- `\help`: Display this message.
- `\exit`: Quit interactive mode.
- `\system <system message here>`:
 Add a system message to the chat without running the model.
 Ex: `\system Always respond as a pirate.`
- `\image <image URL or local path here> <message here>`:
 Add a message paired with an image. The image will be fed to the model as if it were the first item in this prompt.
 You do not need to modify your prompt for specific models.
 Ex: `\image path/to/image.jpg Describe what is in this image.`
"#;
68
/// Banner and command reference printed when the diffusion-model loop starts and on `\help`.
const DIFFUSION_INTERACTIVE_HELP: &str = r#"
Welcome to interactive mode! Because this model is a diffusion model, you can enter prompts and the model will generate an image.

Commands:
- `\help`: Display this message.
- `\exit`: Quit interactive mode.
"#;
76
// Backslash-prefixed commands recognised by the interactive loops.

/// Prints the help banner for the current mode.
const HELP_CMD: &str = "\\help";
/// Leaves interactive mode.
const EXIT_CMD: &str = "\\exit";
/// Prefix of the command that adds a system message to the chat (text and vision modes).
const SYSTEM_CMD: &str = "\\system";
/// Prefix of the command that adds a user message paired with an image (vision mode).
const IMAGE_CMD: &str = "\\image";
81
82async fn text_interactive_mode(mistralrs: Arc<MistralRs>, do_search: bool) {
83 let sender = mistralrs.get_sender().unwrap();
84 let mut messages: Vec<IndexMap<String, MessageContent>> = Vec::new();
85
86 let sampling_params = SamplingParams {
87 temperature: Some(0.1),
88 top_k: Some(32),
89 top_p: Some(0.1),
90 min_p: Some(0.05),
91 top_n_logprobs: 0,
92 frequency_penalty: Some(0.1),
93 presence_penalty: Some(0.1),
94 max_len: Some(4096),
95 stop_toks: None,
96 logits_bias: None,
97 n_choices: 1,
98 dry_params: Some(DrySamplingParams::default()),
99 };
100
101 info!("Starting interactive loop with sampling params: {sampling_params:?}");
102 println!(
103 "{}{TEXT_INTERACTIVE_HELP}{}",
104 "=".repeat(20),
105 "=".repeat(20)
106 );
107
108 *CTRLC_HANDLER.lock().unwrap() = &exit_handler;
110
111 ctrlc::set_handler(move || CTRLC_HANDLER.lock().unwrap()())
112 .expect("Failed to set CTRL-C handler for interactive mode");
113
114 'outer: loop {
115 *CTRLC_HANDLER.lock().unwrap() = &exit_handler;
117
118 let mut prompt = String::new();
119 print!("> ");
120 io::stdout().flush().unwrap();
121 io::stdin()
122 .read_line(&mut prompt)
123 .expect("Failed to get input");
124
125 match prompt.as_str().trim() {
126 "" => continue,
127 HELP_CMD => {
128 println!(
129 "{}{TEXT_INTERACTIVE_HELP}{}",
130 "=".repeat(20),
131 "=".repeat(20)
132 );
133 continue;
134 }
135 EXIT_CMD => {
136 break;
137 }
138 prompt if prompt.trim().starts_with(SYSTEM_CMD) => {
139 let parsed = match &prompt.split(SYSTEM_CMD).collect::<Vec<_>>()[..] {
140 &["", a] => a.trim(),
141 _ => {
142 println!("Error: Setting the system command should be done with this format: `{SYSTEM_CMD} This is a system message.`");
143 continue;
144 }
145 };
146 info!("Set system message to `{parsed}`.");
147 let mut user_message: IndexMap<String, MessageContent> = IndexMap::new();
148 user_message.insert("role".to_string(), Either::Left("system".to_string()));
149 user_message.insert("content".to_string(), Either::Left(parsed.to_string()));
150 messages.push(user_message);
151 continue;
152 }
153 message => {
154 let mut user_message: IndexMap<String, MessageContent> = IndexMap::new();
155 user_message.insert("role".to_string(), Either::Left("user".to_string()));
156 user_message.insert("content".to_string(), Either::Left(message.to_string()));
157 messages.push(user_message);
158 }
159 }
160
161 *CTRLC_HANDLER.lock().unwrap() = &terminate_handler;
163
164 let request_messages = RequestMessage::Chat(messages.clone());
165
166 let (tx, mut rx) = channel(10_000);
167 let req = Request::Normal(NormalRequest {
168 id: mistralrs.next_request_id(),
169 messages: request_messages,
170 sampling_params: sampling_params.clone(),
171 response: tx,
172 return_logprobs: false,
173 is_streaming: true,
174 constraint: Constraint::None,
175 suffix: None,
176 tool_choice: None,
177 tools: None,
178 logits_processors: None,
179 return_raw_logits: false,
180 web_search_options: do_search.then(WebSearchOptions::default),
181 });
182 sender.send(req).await.unwrap();
183
184 let mut assistant_output = String::new();
185
186 let mut last_usage = None;
187 while let Some(resp) = rx.recv().await {
188 match resp {
189 Response::Chunk(chunk) => {
190 last_usage = chunk.usage.clone();
191 if let ChunkChoice {
192 delta:
193 Delta {
194 content: Some(content),
195 ..
196 },
197 finish_reason,
198 ..
199 } = &chunk.choices[0]
200 {
201 assistant_output.push_str(content);
202 print!("{}", content);
203 io::stdout().flush().unwrap();
204 if finish_reason.is_some() {
205 if matches!(finish_reason.as_ref().unwrap().as_str(), "length") {
206 print!("...");
207 }
208 break;
209 }
210 }
211 }
212 Response::InternalError(e) => {
213 error!("Got an internal error: {e:?}");
214 break 'outer;
215 }
216 Response::ModelError(e, resp) => {
217 error!("Got a model error: {e:?}, response: {resp:?}");
218 break 'outer;
219 }
220 Response::ValidationError(e) => {
221 error!("Got a validation error: {e:?}");
222 break 'outer;
223 }
224 Response::Done(_) => unreachable!(),
225 Response::CompletionDone(_) => unreachable!(),
226 Response::CompletionModelError(_, _) => unreachable!(),
227 Response::CompletionChunk(_) => unreachable!(),
228 Response::ImageGeneration(_) => unreachable!(),
229 Response::Raw { .. } => unreachable!(),
230 }
231 }
232
233 if let Some(last_usage) = last_usage {
234 println!();
235 println!();
236 println!("Stats:");
237 println!(
238 "Prompt: {} tokens, {:.2} T/s",
239 last_usage.prompt_tokens, last_usage.avg_prompt_tok_per_sec
240 );
241 println!(
242 "Decode: {} tokens, {:.2} T/s",
243 last_usage.completion_tokens, last_usage.avg_compl_tok_per_sec
244 );
245 }
246 let mut assistant_message: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
247 IndexMap::new();
248 assistant_message.insert("role".to_string(), Either::Left("assistant".to_string()));
249 assistant_message.insert("content".to_string(), Either::Left(assistant_output));
250 messages.push(assistant_message);
251 println!();
252 }
253}
254
255fn parse_image_path_and_message(input: &str) -> Option<(String, String)> {
256 let re = Regex::new(r#"\\image\s+"([^"]+)"\s*(.*)|\\image\s+(\S+)\s*(.*)"#).unwrap();
258
259 if let Some(captures) = re.captures(input) {
260 if let Some(path) = captures.get(1) {
262 if let Some(message) = captures.get(2) {
263 return Some((
264 path.as_str().trim().to_string(),
265 message.as_str().trim().to_string(),
266 ));
267 }
268 } else if let Some(path) = captures.get(3) {
269 if let Some(message) = captures.get(4) {
270 return Some((
271 path.as_str().trim().to_string(),
272 message.as_str().trim().to_string(),
273 ));
274 }
275 }
276 }
277
278 None
279}
280
281async fn vision_interactive_mode(mistralrs: Arc<MistralRs>, do_search: bool) {
282 let sender = mistralrs.get_sender().unwrap();
283 let mut messages: Vec<IndexMap<String, MessageContent>> = Vec::new();
284 let mut images = Vec::new();
285
286 let prefixer = match &mistralrs.config().category {
287 ModelCategory::Text | ModelCategory::Diffusion => {
288 panic!("`add_image_message` expects a vision model.")
289 }
290 ModelCategory::Vision {
291 has_conv2d: _,
292 prefixer,
293 } => prefixer,
294 };
295
296 let sampling_params = SamplingParams {
297 temperature: Some(0.1),
298 top_k: Some(32),
299 top_p: Some(0.1),
300 min_p: Some(0.05),
301 top_n_logprobs: 0,
302 frequency_penalty: Some(0.1),
303 presence_penalty: Some(0.1),
304 max_len: Some(4096),
305 stop_toks: None,
306 logits_bias: None,
307 n_choices: 1,
308 dry_params: Some(DrySamplingParams::default()),
309 };
310
311 info!("Starting interactive loop with sampling params: {sampling_params:?}");
312 println!(
313 "{}{VISION_INTERACTIVE_HELP}{}",
314 "=".repeat(20),
315 "=".repeat(20)
316 );
317
318 *CTRLC_HANDLER.lock().unwrap() = &exit_handler;
320
321 ctrlc::set_handler(move || CTRLC_HANDLER.lock().unwrap()())
322 .expect("Failed to set CTRL-C handler for interactive mode");
323
324 'outer: loop {
325 *CTRLC_HANDLER.lock().unwrap() = &exit_handler;
327
328 let mut prompt = String::new();
329 print!("> ");
330 io::stdout().flush().unwrap();
331 io::stdin()
332 .read_line(&mut prompt)
333 .expect("Failed to get input");
334
335 match prompt.as_str().trim() {
336 "" => continue,
337 HELP_CMD => {
338 println!(
339 "{}{VISION_INTERACTIVE_HELP}{}",
340 "=".repeat(20),
341 "=".repeat(20)
342 );
343 continue;
344 }
345 EXIT_CMD => {
346 break;
347 }
348 prompt if prompt.trim().starts_with(SYSTEM_CMD) => {
349 let parsed = match &prompt.split(SYSTEM_CMD).collect::<Vec<_>>()[..] {
350 &["", a] => a.trim(),
351 _ => {
352 println!("Error: Setting the system command should be done with this format: `{SYSTEM_CMD} This is a system message.`");
353 continue;
354 }
355 };
356 info!("Set system message to `{parsed}`.");
357 let mut user_message: IndexMap<String, MessageContent> = IndexMap::new();
358 user_message.insert("role".to_string(), Either::Left("system".to_string()));
359 user_message.insert("content".to_string(), Either::Left(parsed.to_string()));
360 messages.push(user_message);
361 continue;
362 }
363 prompt if prompt.trim().starts_with(IMAGE_CMD) => {
364 let Some((url, message)) = parse_image_path_and_message(prompt.trim()) else {
365 println!("Error: Adding an image message should be done with this format: `{IMAGE_CMD} path/to/image.jpg Describe what is in this image.`");
366 continue;
367 };
368 let message = prefixer.prefix_image(images.len(), &message);
369
370 let image = util::parse_image_url(&url)
371 .await
372 .expect("Failed to read image from URL/path");
373 images.push(image);
374
375 let mut user_message: IndexMap<String, MessageContent> = IndexMap::new();
376 user_message.insert("role".to_string(), Either::Left("user".to_string()));
377 user_message.insert(
378 "content".to_string(),
379 Either::Right(vec![
380 IndexMap::from([("type".to_string(), Value::String("image".to_string()))]),
381 IndexMap::from([
382 ("type".to_string(), Value::String("text".to_string())),
383 ("text".to_string(), Value::String(message)),
384 ]),
385 ]),
386 );
387 messages.push(user_message);
388 }
389 message => {
390 let mut user_message: IndexMap<String, MessageContent> = IndexMap::new();
391 user_message.insert("role".to_string(), Either::Left("user".to_string()));
392 user_message.insert("content".to_string(), Either::Left(message.to_string()));
393 messages.push(user_message);
394 }
395 };
396
397 *CTRLC_HANDLER.lock().unwrap() = &terminate_handler;
399
400 let request_messages = RequestMessage::VisionChat {
401 images: images.clone(),
402 messages: messages.clone(),
403 };
404
405 let (tx, mut rx) = channel(10_000);
406 let req = Request::Normal(NormalRequest {
407 id: mistralrs.next_request_id(),
408 messages: request_messages,
409 sampling_params: sampling_params.clone(),
410 response: tx,
411 return_logprobs: false,
412 is_streaming: true,
413 constraint: Constraint::None,
414 suffix: None,
415 tool_choice: None,
416 tools: None,
417 logits_processors: None,
418 return_raw_logits: false,
419 web_search_options: do_search.then(WebSearchOptions::default),
420 });
421 sender.send(req).await.unwrap();
422
423 let mut assistant_output = String::new();
424
425 let mut last_usage = None;
426 while let Some(resp) = rx.recv().await {
427 match resp {
428 Response::Chunk(chunk) => {
429 last_usage = chunk.usage.clone();
430 if let ChunkChoice {
431 delta:
432 Delta {
433 content: Some(content),
434 ..
435 },
436 finish_reason,
437 ..
438 } = &chunk.choices[0]
439 {
440 assistant_output.push_str(content);
441 print!("{}", content);
442 io::stdout().flush().unwrap();
443 if finish_reason.is_some() {
444 if matches!(finish_reason.as_ref().unwrap().as_str(), "length") {
445 print!("...");
446 }
447 break;
448 }
449 }
450 }
451 Response::InternalError(e) => {
452 error!("Got an internal error: {e:?}");
453 break 'outer;
454 }
455 Response::ModelError(e, resp) => {
456 error!("Got a model error: {e:?}, response: {resp:?}");
457 break 'outer;
458 }
459 Response::ValidationError(e) => {
460 error!("Got a validation error: {e:?}");
461 break 'outer;
462 }
463 Response::Done(_) => unreachable!(),
464 Response::CompletionDone(_) => unreachable!(),
465 Response::CompletionModelError(_, _) => unreachable!(),
466 Response::CompletionChunk(_) => unreachable!(),
467 Response::ImageGeneration(_) => unreachable!(),
468 Response::Raw { .. } => unreachable!(),
469 }
470 }
471
472 if let Some(last_usage) = last_usage {
473 println!();
474 println!();
475 println!("Stats:");
476 println!(
477 "Prompt: {} tokens, {:.2} T/s",
478 last_usage.prompt_tokens, last_usage.avg_prompt_tok_per_sec
479 );
480 println!(
481 "Decode: {} tokens, {:.2} T/s",
482 last_usage.completion_tokens, last_usage.avg_compl_tok_per_sec
483 );
484 }
485 let mut assistant_message: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
486 IndexMap::new();
487 assistant_message.insert("role".to_string(), Either::Left("assistant".to_string()));
488 assistant_message.insert("content".to_string(), Either::Left(assistant_output));
489 messages.push(assistant_message);
490 println!();
491 }
492}
493
494async fn diffusion_interactive_mode(mistralrs: Arc<MistralRs>, do_search: bool) {
495 let sender = mistralrs.get_sender().unwrap();
496
497 let diffusion_params = DiffusionGenerationParams::default();
498
499 info!("Starting interactive loop with generation params: {diffusion_params:?}");
500 println!(
501 "{}{DIFFUSION_INTERACTIVE_HELP}{}",
502 "=".repeat(20),
503 "=".repeat(20)
504 );
505
506 *CTRLC_HANDLER.lock().unwrap() = &exit_handler;
508
509 ctrlc::set_handler(move || CTRLC_HANDLER.lock().unwrap()())
510 .expect("Failed to set CTRL-C handler for interactive mode");
511
512 loop {
513 *CTRLC_HANDLER.lock().unwrap() = &exit_handler;
515
516 let mut prompt = String::new();
517 print!("> ");
518 io::stdout().flush().unwrap();
519 io::stdin()
520 .read_line(&mut prompt)
521 .expect("Failed to get input");
522
523 let prompt = match prompt.as_str().trim() {
524 "" => continue,
525 HELP_CMD => {
526 println!(
527 "{}{DIFFUSION_INTERACTIVE_HELP}{}",
528 "=".repeat(20),
529 "=".repeat(20)
530 );
531 continue;
532 }
533 EXIT_CMD => {
534 break;
535 }
536 prompt => prompt.to_string(),
537 };
538
539 *CTRLC_HANDLER.lock().unwrap() = &terminate_handler;
541
542 let (tx, mut rx) = channel(10_000);
543 let req = Request::Normal(NormalRequest {
544 id: 0,
545 messages: RequestMessage::ImageGeneration {
546 prompt: prompt.to_string(),
547 format: ImageGenerationResponseFormat::Url,
548 generation_params: diffusion_params.clone(),
549 },
550 sampling_params: SamplingParams::deterministic(),
551 response: tx,
552 return_logprobs: false,
553 is_streaming: false,
554 suffix: None,
555 constraint: Constraint::None,
556 tool_choice: None,
557 tools: None,
558 logits_processors: None,
559 return_raw_logits: false,
560 web_search_options: do_search.then(WebSearchOptions::default),
561 });
562
563 let start = Instant::now();
564 sender.send(req).await.unwrap();
565
566 let ResponseOk::ImageGeneration(response) = rx.recv().await.unwrap().as_result().unwrap()
567 else {
568 panic!("Got unexpected response type.")
569 };
570 let end = Instant::now();
571
572 let duration = end.duration_since(start).as_secs_f32();
573 let pixels_per_s = (diffusion_params.height * diffusion_params.width) as f32 / duration;
574
575 println!(
576 "Image generated can be found at: image is at `{}`. Took {duration:.2}s ({pixels_per_s:.2} pixels/s).",
577 response.data[0].url.as_ref().unwrap(),
578 );
579
580 println!();
581 }
582}
583
/// Unit tests for the `\image` command parser.
#[cfg(test)]
mod tests {
    use super::parse_image_path_and_message;

    /// Builds the `Some((path, message))` value a successful parse is expected to return.
    fn parsed(path: &str, message: &str) -> Option<(String, String)> {
        Some((path.to_owned(), message.to_owned()))
    }

    #[test]
    fn test_parse_image_with_unquoted_path_and_message() {
        assert_eq!(
            parse_image_path_and_message(r#"\image image.jpg What is this"#),
            parsed("image.jpg", "What is this")
        );
    }

    #[test]
    fn test_parse_image_with_quoted_path_and_message() {
        assert_eq!(
            parse_image_path_and_message(r#"\image "image name.jpg" What is this?"#),
            parsed("image name.jpg", "What is this?")
        );
    }

    #[test]
    fn test_parse_image_with_only_unquoted_path() {
        assert_eq!(
            parse_image_path_and_message(r#"\image image.jpg"#),
            parsed("image.jpg", "")
        );
    }

    #[test]
    fn test_parse_image_with_only_quoted_path() {
        assert_eq!(
            parse_image_path_and_message(r#"\image "image name.jpg""#),
            parsed("image name.jpg", "")
        );
    }

    #[test]
    fn test_parse_image_with_extra_spaces() {
        // Surrounding whitespace in the message must be trimmed away.
        assert_eq!(
            parse_image_path_and_message(
                r#"\image "image with spaces.jpg" This is a test message with spaces "#
            ),
            parsed(
                "image with spaces.jpg",
                "This is a test message with spaces"
            )
        );
    }

    #[test]
    fn test_parse_image_with_no_message() {
        assert_eq!(
            parse_image_path_and_message(r#"\image "image.jpg""#),
            parsed("image.jpg", "")
        );
    }

    #[test]
    fn test_parse_image_missing_path() {
        assert_eq!(parse_image_path_and_message(r#"\image"#), None);
    }

    #[test]
    fn test_parse_image_invalid_command() {
        assert_eq!(
            parse_image_path_and_message(r#"\img "image.jpg" This should fail"#),
            None
        );
    }

    #[test]
    fn test_parse_image_with_non_image_text() {
        assert_eq!(
            parse_image_path_and_message(r#"Some random text without command"#),
            None
        );
    }

    #[test]
    fn test_parse_image_with_path_and_message_special_chars() {
        assert_eq!(
            parse_image_path_and_message(
                r#"\image "path with special chars @#$%^&*().jpg" This is a message with special chars !@#$%^&*()"#
            ),
            parsed(
                "path with special chars @#$%^&*().jpg",
                "This is a message with special chars !@#$%^&*()"
            )
        );
    }
}