#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::{any::Any, sync::Arc};

use candle_core::{Device, IndexOp, Result, Tensor};
use image::{imageops::FilterType, DynamicImage, GenericImageView};
use mistralrs_vision::{ApplyTransforms, Normalize, ToTensor, Transforms};
use regex::Regex;
use tokenizers::Tokenizer;

use crate::{
    device_map::DeviceMapper,
    pipeline::{
        text_models_inputs_processor::{
            self, get_completion_input, get_prompt_input, PagedAttentionMeta,
        },
        InputProcessorOutput, InputsProcessor, InputsProcessorType, MessagesAction, Processor,
    },
    sequence::Sequence,
    vision_models::ModelInputs,
};

use crate::vision_models::{
    image_processor::{ImagePreProcessor, PreprocessedImages},
    preprocessor_config::PreProcessorConfig,
    processor_config::ProcessorConfig,
};

use super::MiniCpmOSpecificArgs;

const DEFAULT_MAX_SLICE_NUMS: usize = 9;
const DEFAULT_SCALE_RESOLUTION: usize = 448;
const DEFAULT_PATCH_SIZE: usize = 14;
const DEFAULT_IMAGE_FEATURE_SIZE: usize = 64;
const DEFAULT_IM_START_TOKEN: &str = "<image>";
const DEFAULT_IM_END_TOKEN: &str = "</image>";
const DEFAULT_IM_ID_START: &str = "<image_id>";
const DEFAULT_IM_ID_END: &str = "</image_id>";
const DEFAULT_SLICE_START_TOKEN: &str = "<slice>";
const DEFAULT_SLICE_END_TOKEN: &str = "</slice>";
const DEFAULT_UNK_TOKEN: &str = "<unk>";
const DEFAULT_USE_IMAGE_ID: bool = false;
const DEFAULT_SLICE_MODE: bool = true;

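/// Image (pre)processor for MiniCPM-O. Expands image tags in the prompt into
/// placeholder tokens and turns each image into normalized, patch-reshaped
/// slice tensors.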
pub struct MiniCpmOImageProcessor {
    config: PreProcessorConfig,
}

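/// Top-level MiniCPM-O processor. A minimal usage sketch (crate-internal API,
/// not compiled as a doctest):
///
/// ```ignore
/// let processor = MiniCpmOProcessor::new(processor_config, preprocessor_config, None);
/// let inputs_processor = processor.inputs_processor();
/// ```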
pub struct MiniCpmOProcessor {
    preprocessor_config: PreProcessorConfig,
}

impl MiniCpmOProcessor {
    pub fn new(
        _config: ProcessorConfig,
        preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Self {
        Self {
            preprocessor_config,
        }
    }
}

impl Processor for MiniCpmOProcessor {
    fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
        Arc::new(MiniCpmOImageProcessor {
            config: self.preprocessor_config.clone(),
        })
    }

    fn get_special_tokens(&self) -> &[&'static str] {
        &[
            DEFAULT_IM_START_TOKEN,
            DEFAULT_IM_END_TOKEN,
            DEFAULT_SLICE_START_TOKEN,
            DEFAULT_SLICE_END_TOKEN,
            DEFAULT_UNK_TOKEN,
        ]
    }

    fn template_action(&self) -> MessagesAction {
        MessagesAction::FlattenOnlyText
    }
}

impl InputsProcessor for MiniCpmOImageProcessor {
    fn get_type(&self) -> InputsProcessorType {
        InputsProcessorType::Vision
    }
    fn process_inputs(
        &self,
        tokenizer: Option<Arc<Tokenizer>>,
        input_seqs: &mut [&mut Sequence],
        is_prompt: bool,
        is_xlora: bool,
        device: &Device,
        no_kv_cache: bool,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        other_config: Option<Arc<dyn Any>>,
        mut paged_attn_metadata: Option<PagedAttentionMeta>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> anyhow::Result<InputProcessorOutput> {
        if is_xlora {
            return Err(anyhow::Error::msg(
                "Cannot make inputs for X-LoRA vision model.",
            ));
        }
        if no_kv_cache {
            return Err(anyhow::Error::msg("Vision model must have kv cache."));
        }
        let Some(tokenizer) = tokenizer else {
            return Err(anyhow::Error::msg(
                "MiniCpmOImageProcessor requires a specified tokenizer.",
            ));
        };

        let config = other_config.expect("Need a PreProcessorConfig config.");
        let config: &PreProcessorConfig = config.downcast_ref().expect("Downcast failed.");

        let has_images = input_seqs.iter().all(|seq| seq.has_images());

        let (pixel_values_all, image_bound, tgt_sizes) = if has_images {
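            // Literal multimodal tags emitted by the chat template. The audio
            // regex is bound but unused here (hence the leading underscore);
            // the split pattern recognizes both tag kinds.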
            const IMAGE_TAG: &str = "(<image>./</image>)";
            const IMAGE_PATTERN: &str = r"\(<image>./</image>\)";
            const AUDIO_PATTERN: &str = r"\(<audio>./</audio>\)";

            let image_pattern = Regex::new(IMAGE_PATTERN).unwrap();
            let _audio_pattern = Regex::new(AUDIO_PATTERN).unwrap();
            let split_pattern = Regex::new(&format!(r"({IMAGE_PATTERN}|{AUDIO_PATTERN})")).unwrap();

            let mut pixel_values_accum = Vec::new();
            let mut tgt_sizes_accum = Vec::new();
            let mut image_bounds_accum = Vec::new();

            for seq in input_seqs.iter_mut() {
                let PreprocessedImages {
                    pixel_values: _,
                    pixel_attention_mask: _,
                    image_sizes: _,
                    num_img_tokens: _,
                    aspect_ratio_ids: _,
                    aspect_ratio_mask: _,
                    num_tiles: _,
                    image_grid_thw: _,
                    video_grid_thw: _,
                    rows: _,
                    cols: _,
                    pixel_values_list,
                    tgt_sizes,
                    image_sizes_all,
                    num_crops: _,
                } = self
                    .preprocess(
                        seq.take_images()
                            .expect("Need to have images by this point."),
                        vec![],
                        config,
                        device,
                        (usize::MAX, usize::MAX),
                    )
                    .expect("Preprocessing failed");
                let pixel_values_list = pixel_values_list.unwrap();
                let tgt_sizes = tgt_sizes.unwrap();
                let image_sizes_all = image_sizes_all.unwrap();

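                // Decode the current tokens back to text so each
                // `(<image>./</image>)` tag can be replaced in place with the
                // per-image placeholder string built below.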
                let text = tokenizer
                    .decode(seq.get_toks(), false)
                    .expect("Detokenization failed!");

                let mut text_chunks = {
                    let mut results = Vec::new();
                    let mut last_end = 0;

                    for m in split_pattern.find_iter(&text) {
                        if m.start() > last_end {
                            results.push((false, &text[last_end..m.start()]));
                        }
                        results.push((true, m.as_str()));
                        last_end = m.end();
                    }
                    if last_end < text.len() {
                        results.push((false, &text[last_end..]));
                    }

                    results
                        .into_iter()
                        .map(|(_, x)| x.to_string())
                        .collect::<Vec<_>>()
                };

                let image_tags = image_pattern.find_iter(&text).collect::<Vec<_>>();

                if !image_tags.is_empty() {
                    assert_eq!(image_tags.len(), image_sizes_all.len());
                }

                let mut image_id = 0;
                for chunk in &mut text_chunks {
                    if chunk == IMAGE_TAG {
                        *chunk =
                            self.get_slice_image_placeholder(image_sizes_all[image_id], image_id);
                        image_id += 1;
                    }
                }

                let final_text = text_chunks.join("");

                let input_ids = tokenizer
                    .encode_fast(final_text.clone(), false)
                    .unwrap()
                    .get_ids()
                    .to_vec();

                if !seq.multimodal.has_changed_prompt {
                    seq.set_initial_prompt(final_text.clone());

                    seq.set_toks_and_reallocate(input_ids.clone(), paged_attn_metadata.as_mut());
                    seq.multimodal.has_changed_prompt = true;
                }

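                // Image bounds are (start, end) token-index pairs delimiting
                // each image/slice placeholder region in `input_ids`; the model
                // uses them to locate where vision embeddings belong.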
                let image_bounds = {
                    let im_start_id = tokenizer
                        .encode_fast(
                            self.config
                                .im_start_token
                                .clone()
                                .unwrap_or(DEFAULT_IM_START_TOKEN.to_string()),
                            false,
                        )
                        .unwrap()
                        .get_ids()[0];
                    let im_end_id = tokenizer
                        .encode_fast(
                            self.config
                                .im_end_token
                                .clone()
                                .unwrap_or(DEFAULT_IM_END_TOKEN.to_string()),
                            false,
                        )
                        .unwrap()
                        .get_ids()[0];
                    let slice_start_id = tokenizer
                        .encode_fast(
                            self.config
                                .slice_start_token
                                .clone()
                                .unwrap_or(DEFAULT_SLICE_START_TOKEN.to_string()),
                            false,
                        )
                        .unwrap()
                        .get_ids()[0];
                    let slice_end_id = tokenizer
                        .encode_fast(
                            self.config
                                .slice_end_token
                                .clone()
                                .unwrap_or(DEFAULT_SLICE_END_TOKEN.to_string()),
                            false,
                        )
                        .unwrap()
                        .get_ids()[0];

                    let image_start_idx = input_ids
                        .iter()
                        .enumerate()
                        .filter_map(|(i, &id)| {
                            if id == im_start_id || id == slice_start_id {
                                Some(i as u32 + 1)
                            } else {
                                None
                            }
                        })
                        .collect::<Vec<_>>();

                    let image_end_idx = input_ids
                        .iter()
                        .enumerate()
                        .filter_map(|(i, &id)| {
                            if id == im_end_id || id == slice_end_id {
                                Some(i as u32)
                            } else {
                                None
                            }
                        })
                        .collect::<Vec<_>>();

                    // Start and end markers should come in pairs; truncate to the
                    // shorter list so a malformed prompt cannot index out of bounds.
                    let valid_image_nums = image_start_idx.len().min(image_end_idx.len());

                    let image_start_idx = Tensor::from_slice(
                        &image_start_idx[..valid_image_nums],
                        (valid_image_nums, 1),
                        device,
                    )
                    .unwrap();
                    let image_end_idx = Tensor::from_slice(
                        &image_end_idx[..valid_image_nums],
                        (valid_image_nums, 1),
                        device,
                    )
                    .unwrap();

                    Tensor::cat(&[image_start_idx, image_end_idx], 1).unwrap()
                };

                pixel_values_accum.push(pixel_values_list);
                tgt_sizes_accum.push(tgt_sizes);
                image_bounds_accum.push(image_bounds);
            }

            (
                Some(pixel_values_accum),
                Some(image_bounds_accum),
                Some(tgt_sizes_accum),
            )
        } else {
            (None, None, None)
        };

        let text_models_inputs_processor::InnerInputProcessorOutput {
            inputs:
                text_models_inputs_processor::InputMetadata {
                    input,
                    positions,
                    context_lens,
                    position_ids,
                    paged_attn_meta,
                    flash_meta,
                },
            seq_indices,
        } = if is_prompt {
            get_prompt_input(
                input_seqs
                    .iter()
                    .map(|seq| seq.get_toks())
                    .collect::<Vec<_>>(),
                input_seqs,
                device,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                mapper,
            )
            .unwrap()
        } else {
            get_completion_input(
                input_seqs
                    .iter()
                    .map(|seq| seq.get_toks())
                    .collect::<Vec<_>>(),
                input_seqs,
                device,
                no_kv_cache,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                mapper,
            )
            .unwrap()
        };

        let args = MiniCpmOSpecificArgs {
            pixel_values_all,
            tgt_sizes,
            image_bound,
        };

        let inputs: Box<dyn Any> = Box::new(ModelInputs {
            input_ids: input,
            seqlen_offsets: positions,
            context_lens,
            position_ids,
            pixel_values: None,
            model_specific_args: Box::new(args),
            paged_attn_meta,
            flash_meta,
        });
        Ok(InputProcessorOutput {
            inputs,
            seq_indices,
        })
    }
}

impl MiniCpmOImageProcessor {
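    /// Choose an `(x, y)` slicing grid for a `w`×`h` image: candidate grids come
    /// from factorizations of the slice count, and the winner minimizes the
    /// distance to the image's aspect ratio in log space. Returns `None` if the
    /// image is small enough to stay unsliced or `never_split` is set.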
    fn get_sliced_grid(
        &self,
        (w, h): (usize, usize),
        max_slice_nums: usize,
        scale_resolution: usize,
        never_split: bool,
    ) -> Option<(usize, usize)> {
        // Cast before dividing; integer division would collapse the aspect
        // ratio to 0 or 1 for most images.
        let log_ratio = (w as f32 / h as f32).ln();
        let ratio = (w * h) as f32 / (scale_resolution * scale_resolution) as f32;
        let multiple = ratio.ceil().min(max_slice_nums as f32);
        if multiple <= 1. || never_split {
            return None;
        }

        let mut candidate_split_grid_nums = Vec::new();
        for i in [multiple - 1., multiple, multiple + 1.] {
            if i == 1. || i > max_slice_nums as f32 {
                continue;
            }
            candidate_split_grid_nums.push(i);
        }

        let mut candidate_grids = Vec::new();
        for split_grid_nums in candidate_split_grid_nums {
            let mut m = 1.;
            while m <= split_grid_nums {
                if split_grid_nums % m == 0. {
                    candidate_grids.push((m as usize, split_grid_nums as usize / m as usize));
                }
                m += 1.;
            }
        }

        let mut best_grid = (1, 1);
        let mut min_error = f32::INFINITY;
        for grid in candidate_grids {
            let error = (log_ratio - (grid.0 as f32 / grid.1 as f32).ln()).abs();
            if error < min_error {
                best_grid = grid;
                min_error = error;
            }
        }

        Some(best_grid)
    }

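    /// Round `length` to the nearest multiple of `patch_size`, with a floor of
    /// one patch.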
    fn ensure_divide(&self, length: usize, patch_size: usize) -> usize {
        ((length as f32 / patch_size as f32).round() * patch_size as f32).max(patch_size as f32)
            as usize
    }

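    /// Resize `w`×`h` so its area is roughly `scale_resolution`², preserving
    /// aspect ratio, then snap both sides to multiples of `patch_size`.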
    fn find_best_resize(
        &self,
        (mut w, mut h): (usize, usize),
        scale_resolution: usize,
        patch_size: usize,
        allow_upscale: bool,
    ) -> (usize, usize) {
        if w * h > scale_resolution * scale_resolution || allow_upscale {
            let r = w as f32 / h as f32;
            h = (scale_resolution as f32 / r.sqrt()) as usize;
            // Derive the width from the new height (w = h * r) so the aspect
            // ratio is preserved.
            w = (h as f32 * r) as usize;
        }
        let best_w = self.ensure_divide(w, patch_size);
        let best_h = self.ensure_divide(h, patch_size);
        (best_w, best_h)
    }

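    /// Size for the refined (high-resolution) image: pick the best per-cell
    /// size via `find_best_resize`, then tile it back over the
    /// `grid_x`×`grid_y` grid.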
    fn get_refine_size(
        &self,
        (w, h): (usize, usize),
        (grid_x, grid_y): (usize, usize),
        scale_resolution: usize,
        patch_size: usize,
        allow_upscale: bool,
    ) -> (usize, usize) {
        let refine_w = self.ensure_divide(w, grid_x);
        let refine_h = self.ensure_divide(h, grid_y);

        // Per-cell dimensions: width divides across the grid_x columns, height
        // across the grid_y rows.
        let grid_w = refine_w / grid_x;
        let grid_h = refine_h / grid_y;

        let best_grid_size = self.find_best_resize(
            (grid_w, grid_h),
            scale_resolution,
            patch_size,
            allow_upscale,
        );

        (best_grid_size.0 * grid_x, best_grid_size.1 * grid_y)
    }

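    /// Crop the image into a `grid.0`×`grid.1` matrix of equally sized
    /// sub-images, row by row.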
    fn split_to_patches(
        &self,
        image: &DynamicImage,
        grid: (usize, usize),
    ) -> Vec<Vec<DynamicImage>> {
        let mut patches = Vec::new();
        let (w, h) = image.dimensions();
        let (w, h) = (w as usize, h as usize);
        let grid_x = w / grid.0;
        let grid_y = h / grid.1;
        for i in (0..h).step_by(grid_y) {
            let mut images = Vec::new();
            for j in (0..w).step_by(grid_x) {
                images.push(image.crop_imm(j as u32, i as u32, grid_x as u32, grid_y as u32));
            }
            patches.push(images);
        }
        patches
    }

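    /// Produce the model's view of one image: the resized overview image first,
    /// followed by the refined slices when slice mode is enabled and a grid was
    /// chosen.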
    fn get_sliced_images(
        &self,
        image: &DynamicImage,
        max_slice_nums: usize,
        scale_resolution: usize,
        patch_size: usize,
    ) -> Vec<DynamicImage> {
        if !self.config.slice_mode.unwrap_or(DEFAULT_SLICE_MODE) {
            return vec![image.clone()];
        }

        let dims = image.dimensions();
        let (w, h) = (dims.0 as usize, dims.1 as usize);

        let best_grid = self.get_sliced_grid((w, h), max_slice_nums, scale_resolution, false);

        let (source_images, patches) = if let Some(best_grid) = best_grid {
            let best_resize = self.find_best_resize((w, h), scale_resolution, patch_size, false);
            let source_image = image.resize_exact(
                best_resize.0 as u32,
                best_resize.1 as u32,
                FilterType::CatmullRom,
            );
            let refine_size =
                self.get_refine_size((w, h), best_grid, scale_resolution, patch_size, true);
            let refine_image = image.resize_exact(
                refine_size.0 as u32,
                refine_size.1 as u32,
                FilterType::CatmullRom,
            );
            let patches = self
                .split_to_patches(&refine_image, best_grid)
                .into_iter()
                .flatten()
                .collect::<Vec<_>>();

            (source_image, patches)
        } else {
            let best_size = self.find_best_resize((w, h), scale_resolution, patch_size, true);
            let source_images = image.resize_exact(
                best_size.0 as u32,
                best_size.1 as u32,
                FilterType::CatmullRom,
            );

            (source_images, vec![])
        };

        [vec![source_images], patches].concat()
    }

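    /// Rearrange a `(c, h, w)` image tensor into `(c, patch_size, h * w / patch_size)`
    /// by extracting non-overlapping `patch_size`×`patch_size` windows (a strided
    /// unfold) and concatenating them along the width axis.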
    fn reshape_by_patch(&self, image: &Tensor, patch_size: usize) -> Result<Tensor> {
        let (_c, h, w) = image.dims3()?;
        // Kernel size
        let (kh, kw) = (patch_size, patch_size);
        // Stride
        let (sh, sw) = (patch_size, patch_size);

        let out_h = (h - kh) / sh + 1;
        let out_w = (w - kw) / sw + 1;

        // Collect each window, flattened to (c * kh * kw,)
        let mut patches = Vec::new();
        for i in 0..out_h {
            for j in 0..out_w {
                let patch = image.i((.., i * sh..i * sh + kh, j * sw..j * sw + kw))?;
                patches.push(patch.flatten_all()?);
            }
        }
        // (c * kh * kw, n_patches)
        let mut patches = Tensor::stack(&patches, 1)?;

        patches = patches.reshape((image.dim(0)?, patch_size, patch_size, ()))?;
        patches
            .permute((0, 1, 3, 2))?
            .reshape((image.dim(0)?, patch_size, ()))
    }

    fn get_image_id_placeholder(&self, image_idx: usize) -> String {
        format!(
            "{}{image_idx}{}",
            self.config
                .im_id_start
                .clone()
                .unwrap_or(DEFAULT_IM_ID_START.to_string()),
            self.config
                .im_id_end
                .clone()
                .unwrap_or(DEFAULT_IM_ID_END.to_string())
        )
    }

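    /// Textual placeholder for a sliced grid: each cell expands to
    /// `<slice><unk>...</slice>`, cells in a row are concatenated, and rows are
    /// joined with `\n` (e.g. a 2×2 grid yields two lines of two cells each).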
    fn get_grid_placeholder(&self, grid: Option<(usize, usize)>) -> String {
        if let Some(grid) = grid {
            let slice_image_placeholder = format!(
                "{}{}{}",
                self.config
                    .slice_start_token
                    .clone()
                    .unwrap_or(DEFAULT_SLICE_START_TOKEN.to_string()),
                self.config
                    .unk_token
                    .clone()
                    .unwrap_or(DEFAULT_UNK_TOKEN.to_string())
                    .repeat(
                        self.config
                            .image_feature_size
                            .unwrap_or(DEFAULT_IMAGE_FEATURE_SIZE)
                    ),
                self.config
                    .slice_end_token
                    .clone()
                    .unwrap_or(DEFAULT_SLICE_END_TOKEN.to_string())
            );

            let (cols, rows) = grid;
            let mut slices = Vec::new();
            for _ in 0..rows {
                let mut lines = Vec::new();
                for _ in 0..cols {
                    lines.push(slice_image_placeholder.clone());
                }
                slices.push(lines.join(""));
            }

            slices.join("\n")
        } else {
            "".to_string()
        }
    }

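    /// Full placeholder for one image: an optional `<image_id>{i}</image_id>`
    /// prefix, the `<image><unk>...</image>` overview placeholder, and, in slice
    /// mode, the grid placeholder for the refined slices.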
    fn get_slice_image_placeholder(&self, image_size: (u32, u32), image_idx: usize) -> String {
        let max_slice_nums = self.config.max_slice_nums.unwrap_or(DEFAULT_MAX_SLICE_NUMS);
        let use_image_id = self.config.use_image_id.unwrap_or(DEFAULT_USE_IMAGE_ID);
        let slice_mode = self.config.slice_mode.unwrap_or(DEFAULT_SLICE_MODE);

        let grid = self.get_sliced_grid(
            (image_size.0 as usize, image_size.1 as usize),
            max_slice_nums,
            DEFAULT_SCALE_RESOLUTION,
            false,
        );

        let image_placeholder = format!(
            "{}{}{}",
            self.config
                .im_start_token
                .clone()
                .unwrap_or(DEFAULT_IM_START_TOKEN.to_string()),
            self.config
                .unk_token
                .clone()
                .unwrap_or(DEFAULT_UNK_TOKEN.to_string())
                .repeat(
                    self.config
                        .image_feature_size
                        .unwrap_or(DEFAULT_IMAGE_FEATURE_SIZE)
                ),
            self.config
                .im_end_token
                .clone()
                .unwrap_or(DEFAULT_IM_END_TOKEN.to_string())
        );

        let final_placeholder = if use_image_id {
            format!(
                "{}{image_placeholder}",
                self.get_image_id_placeholder(image_idx)
            )
        } else {
            image_placeholder
        };

        if slice_mode {
            format!("{final_placeholder}{}", self.get_grid_placeholder(grid))
        } else {
            final_placeholder
        }
    }
}

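// Default normalization is mean/std 0.5 per channel; assuming `ToTensor`
// scales pixels to [0, 1], this maps them to [-1, 1].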
impl ImagePreProcessor for MiniCpmOImageProcessor {
    #[allow(clippy::excessive_precision)]
    const DEFAULT_MEAN: [f64; 3] = [0.5, 0.5, 0.5];
    #[allow(clippy::excessive_precision)]
    const DEFAULT_STD: [f64; 3] = [0.5, 0.5, 0.5];

    fn preprocess(
        &self,
        images: Vec<DynamicImage>,
        _videos: Vec<Vec<DynamicImage>>,
        config: &PreProcessorConfig,
        device: &Device,
        (_bs, _max_num_images): (usize, usize),
    ) -> Result<PreprocessedImages> {
        let mut pixel_values = Vec::new();
        let mut tgt_sizes = Vec::new();
        let image_sizes = images
            .iter()
            .map(|img| img.dimensions())
            .collect::<Vec<_>>();
        for image in images {
            let max_slice_nums = config.max_slice_nums.unwrap_or(DEFAULT_MAX_SLICE_NUMS);
            let scale_resolution = config.scale_resolution.unwrap_or(DEFAULT_SCALE_RESOLUTION);
            let patch_size = config.patch_size.unwrap_or(DEFAULT_PATCH_SIZE);

            let image_patches =
                self.get_sliced_images(&image, max_slice_nums, scale_resolution, patch_size);

            for slice_image in image_patches {
                let (w, h) = slice_image.dimensions();
                let to_tensor_rescale = Transforms {
                    input: &ToTensor,
                    inner_transforms: &[&Normalize {
                        mean: config.image_mean.unwrap_or(Self::DEFAULT_MEAN).to_vec(),
                        std: config.image_std.unwrap_or(Self::DEFAULT_STD).to_vec(),
                    }],
                };
                let mut image = slice_image.apply(to_tensor_rescale, device)?;
                image = self.reshape_by_patch(&image, patch_size)?;
                pixel_values.push(image);
                // Target size of each slice, in patches: (h / patch_size, w / patch_size).
                tgt_sizes.push(Tensor::from_vec(
                    vec![h / patch_size as u32, w / patch_size as u32],
                    (1, 2),
                    &Device::Cpu,
                )?);
            }
        }

        let tgt_sizes = Tensor::cat(&tgt_sizes, 0)?.to_device(device)?;
        Ok(PreprocessedImages {
            // `pixel_values` is unused for this model; the per-slice tensors
            // are returned via `pixel_values_list`.
            pixel_values: Tensor::new(0u32, &Device::Cpu)?,
            pixel_attention_mask: None,
            image_sizes: None,
            num_img_tokens: None,
            aspect_ratio_ids: None,
            aspect_ratio_mask: None,
            num_tiles: None,
            image_grid_thw: None,
            video_grid_thw: None,
            rows: None,
            cols: None,
            pixel_values_list: Some(pixel_values),
            tgt_sizes: Some(tgt_sizes),
            image_sizes_all: Some(image_sizes),
            num_crops: None,
        })
    }
}