use std::{any::Any, sync::Arc};

use anyhow::Result;
use candle_core::{Context, Device, IndexOp, Tensor};
use image::{imageops::FilterType, DynamicImage, GenericImageView};
use mistralrs_vision::{
    ApplyTensorTransforms, ApplyTransforms, Normalize, TensorTransforms, ToTensor, Transforms,
};
use tokenizers::Tokenizer;

use crate::{
    device_map::DeviceMapper,
    pipeline::{
        text_models_inputs_processor::{
            self, get_completion_input, get_prompt_input, PagedAttentionMeta,
        },
        InputProcessorOutput, InputsProcessor, InputsProcessorType, MessagesAction, Processor,
    },
    sequence::Sequence,
    vision_models::{
        image_processor::{ImagePreProcessor, PreprocessedImages},
        preprocessor_config::{PreProcessorConfig, ToFilter},
        ModelInputs,
    },
};

use super::Qwen2VLVisionSpecificArgs;

struct Qwen2VLImageProcessor {
    max_edge: Option<u32>,
}

pub struct Qwen2VLProcessor {
    max_edge: Option<u32>,
}

impl Qwen2VLProcessor {
    pub const VISION_START: &str = "<|vision_start|>";
    pub const VISION_END: &str = "<|vision_end|>";
    pub const IMAGE_PAD: &str = "<|image_pad|>";
    pub const VIDEO_PAD: &str = "<|video_pad|>";
    pub const PLACEHOLDER: &str = "<|placeholder|>";

    pub fn new(max_edge: Option<u32>) -> Self {
        Self { max_edge }
    }
}

impl Processor for Qwen2VLProcessor {
    fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
        Arc::new(Qwen2VLImageProcessor {
            max_edge: self.max_edge,
        })
    }

    fn get_special_tokens(&self) -> &[&'static str] {
        &[Self::IMAGE_PAD, Self::VIDEO_PAD, Self::PLACEHOLDER]
    }

    fn template_action(&self) -> MessagesAction {
        MessagesAction::FlattenOnlyText
    }
}

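/// Replace the first occurrence of `to_replace` in `text` with `replacement`,
/// returning the input unchanged if no match is found.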
fn replace_first_occurrence(text: &str, to_replace: &str, replacement: &str) -> String {
    if let Some(pos) = text.find(to_replace) {
        let mut result = text.to_string();
        result.replace_range(pos..pos + to_replace.len(), replacement);
        result
    } else {
        text.to_string()
    }
}

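/// Find all maximal runs of `needle` in `nums`, returned as half-open
/// `(start, end)` index ranges.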
fn find_sequences(nums: &[u32], needle: u32) -> Vec<(usize, usize)> {
    let mut sequences = Vec::new();
    let mut start = None;

    for (i, &num) in nums.iter().enumerate() {
        if num == needle {
            if start.is_none() {
                start = Some(i);
            }
        } else if let Some(s) = start {
            sequences.push((s, i));
            start = None;
        }
    }

    if let Some(s) = start {
        sequences.push((s, nums.len()));
    }

    sequences
}

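/// Return the byte index just past each non-overlapping occurrence of `needle`
/// in `haystack`.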
fn find_substring_indices(haystack: &str, needle: &str) -> Vec<usize> {
    let mut indices = Vec::new();
    let mut start = 0;

    while let Some(pos) = haystack[start..].find(needle) {
        let index = start + pos;
        indices.push(index + needle.len());
        start = index + needle.len();
    }

    indices
}

impl InputsProcessor for Qwen2VLImageProcessor {
    fn get_type(&self) -> InputsProcessorType {
        InputsProcessorType::Vision
    }
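    /// Build model inputs for a batch of sequences: run the image/video preprocessor,
    /// expand each `<|image_pad|>`/`<|video_pad|>` token to the number of visual tokens
    /// its grid produces, re-tokenize the expanded prompts, and package everything into
    /// `Qwen2VLVisionSpecificArgs`.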
    fn process_inputs(
        &self,
        tokenizer: Option<Arc<Tokenizer>>,
        input_seqs: &mut [&mut Sequence],
        is_prompt: bool,
        is_xlora: bool,
        device: &Device,
        no_kv_cache: bool,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        other_config: Option<Arc<dyn Any>>,
        mut paged_attn_metadata: Option<PagedAttentionMeta>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Result<InputProcessorOutput> {
        if is_xlora {
            return Err(anyhow::Error::msg(
                "Cannot make inputs for X-LoRA vision model.",
            ));
        }
        if no_kv_cache {
            return Err(anyhow::Error::msg("Vision model must have kv cache."));
        }
        let Some(tokenizer) = tokenizer else {
            return Err(anyhow::Error::msg(
                "Qwen2VLImageProcessor requires a specified tokenizer.",
            ));
        };

        let text_models_inputs_processor::InnerInputProcessorOutput {
            inputs:
                text_models_inputs_processor::InputMetadata {
                    input,
                    positions,
                    context_lens,
                    position_ids,
                    paged_attn_meta,
                    flash_meta,
                },
            seq_indices,
        } = if is_prompt {
            get_prompt_input(
                input_seqs
                    .iter()
                    .map(|seq| seq.get_toks())
                    .collect::<Vec<_>>(),
                input_seqs,
                device,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                mapper,
            )
            .unwrap()
        } else {
            get_completion_input(
                input_seqs
                    .iter()
                    .map(|seq| seq.get_toks())
                    .collect::<Vec<_>>(),
                input_seqs,
                device,
                no_kv_cache,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                mapper,
            )
            .unwrap()
        };
        let config = other_config.expect("Need a PreProcessorConfig config.");
        let config: &PreProcessorConfig = config.downcast_ref().expect("Downcast failed.");

        let has_images = input_seqs.iter().all(|seq| seq.has_images());

        let (
            new_input,
            pixel_values,
            image_grid_thw,
            video_grid_thw,
            continuous_img_pad,
            continuous_vid_pad,
            input_ids_searching,
            image_nums,
            video_nums,
        ) = if has_images {
            let mut pixel_values_accum = Vec::new();
            let mut image_grid_thw_accum = Vec::new();
            let mut video_grid_thw_accum = Vec::new();

            let mut detok_seqs = tokenizer
                .decode_batch(
                    &input_seqs
                        .iter()
                        .map(|seq| seq.get_toks())
                        .collect::<Vec<_>>(),
                    false,
                )
                .expect("Detokenization failed!");

            for seq in input_seqs.iter_mut() {
                let (pixel_values, image_grid_thw, video_grid_thw) =
                    if let Some(cached_pixel_values) = &seq.multimodal.cached_pixel_values {
                        (
                            cached_pixel_values.clone(),
                            seq.multimodal.cached_img_thw.clone(),
                            seq.multimodal.cached_vid_thw.clone(),
                        )
                    } else {
                        let PreprocessedImages {
                            pixel_values,
                            pixel_attention_mask: _,
                            image_sizes: _,
                            num_img_tokens: _,
                            aspect_ratio_ids: _,
                            aspect_ratio_mask: _,
                            num_tiles: _,
                            image_grid_thw,
                            video_grid_thw,
                            rows: _,
                            cols: _,
                            pixel_values_list: _,
                            tgt_sizes: _,
                            image_sizes_all: _,
                            num_crops: _,
                        } = self
                            .preprocess(
                                seq.clone_images()
                                    .expect("Need to have images by this point."),
                                vec![],
                                config,
                                device,
                                (usize::MAX, usize::MAX),
                            )
                            .expect("Preprocessing failed");

                        seq.multimodal.cached_pixel_values = Some(pixel_values.clone());
                        seq.multimodal.cached_img_thw = image_grid_thw.clone();
                        seq.multimodal.cached_vid_thw = video_grid_thw.clone();
                        (pixel_values, image_grid_thw, video_grid_thw)
                    };

                pixel_values_accum.push(pixel_values.unsqueeze(0).unwrap());
                image_grid_thw_accum.push(image_grid_thw);
                video_grid_thw_accum.push(video_grid_thw);
            }

            let image_grid_thw_accum = if image_grid_thw_accum.iter().any(|img| img.is_none()) {
                None
            } else {
                Some(
                    image_grid_thw_accum
                        .into_iter()
                        .map(|img| img.unwrap())
                        .collect::<Vec<_>>(),
                )
            };

            let video_grid_thw_accum = if video_grid_thw_accum.iter().any(|vid| vid.is_none()) {
                None
            } else {
                Some(
                    video_grid_thw_accum
                        .into_iter()
                        .map(|vid| vid.unwrap())
                        .collect::<Vec<_>>(),
                )
            };

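            // During the prompt pass, expand each pad token in the detokenized text:
            // every image/video contributes `t * h * w / merge_size^2` visual tokens,
            // so each `<|image_pad|>`/`<|video_pad|>` is replaced by that many copies
            // of itself (via the temporary `<|placeholder|>` marker to avoid re-matching).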
            if is_prompt {
                if let Some(ref image_grid_thw_accum) = image_grid_thw_accum {
                    let merge_length = config.merge_size.expect("Require `merge_size`").pow(2);
                    for ((batch, text), seq) in
                        detok_seqs.iter_mut().enumerate().zip(input_seqs.iter_mut())
                    {
                        if seq.multimodal.has_changed_prompt {
                            continue;
                        }
                        let mut index = 0;
                        while text.contains(Qwen2VLProcessor::IMAGE_PAD) {
                            *text = replace_first_occurrence(
                                text,
                                Qwen2VLProcessor::IMAGE_PAD,
                                &Qwen2VLProcessor::PLACEHOLDER.repeat(
                                    image_grid_thw_accum[batch]
                                        .i(index)
                                        .unwrap()
                                        .to_vec1::<u32>()
                                        .unwrap()
                                        .iter()
                                        .product::<u32>()
                                        as usize
                                        / merge_length,
                                ),
                            );
                            index += 1;
                        }
                        *text = text
                            .replace(Qwen2VLProcessor::PLACEHOLDER, Qwen2VLProcessor::IMAGE_PAD);
                    }
                }

                if let Some(ref video_grid_thw_accum) = video_grid_thw_accum {
                    let merge_length = config.merge_size.expect("Require `merge_size`").pow(2);
                    for ((batch, text), seq) in
                        detok_seqs.iter_mut().enumerate().zip(input_seqs.iter_mut())
                    {
                        if seq.multimodal.has_changed_prompt {
                            continue;
                        }
                        let mut index = 0;
                        while text.contains(Qwen2VLProcessor::VIDEO_PAD) {
                            *text = replace_first_occurrence(
                                text,
                                Qwen2VLProcessor::VIDEO_PAD,
                                &Qwen2VLProcessor::PLACEHOLDER.repeat(
                                    video_grid_thw_accum[batch]
                                        .i(index)
                                        .unwrap()
                                        .to_vec1::<u32>()
                                        .unwrap()
                                        .iter()
                                        .product::<u32>()
                                        as usize
                                        / merge_length,
                                ),
                            );
                            index += 1;
                        }
                        *text = text
                            .replace(Qwen2VLProcessor::PLACEHOLDER, Qwen2VLProcessor::VIDEO_PAD);
                    }
                }
            }

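            // Re-tokenize the expanded prompts, update each sequence once
            // (`has_changed_prompt` guards against doing this twice), and record the
            // contiguous runs of image/video pad token ids for the vision model.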
            let mut all_ids = Vec::new();
            let mut all_continuous_img_pad = Vec::new();
            let mut all_continuous_vid_pad = Vec::new();
            for (detok, seq) in detok_seqs.into_iter().zip(input_seqs.iter_mut()) {
                let toks = tokenizer
                    .encode_fast(detok.clone(), false)
                    .expect("Tokenization failed!");
                let ids = toks.get_ids().to_vec();

                if !seq.multimodal.has_changed_prompt {
                    seq.set_initial_prompt(detok.clone());

                    seq.set_toks_and_reallocate(ids.clone(), paged_attn_metadata.as_mut());
                    seq.multimodal.has_changed_prompt = true;
                }
                all_ids.push(ids.clone());

                let img_pad = tokenizer
                    .encode_fast(Qwen2VLProcessor::IMAGE_PAD, false)
                    .expect("Tokenization failed!")
                    .get_ids()
                    .to_vec();
                let continuous_img_pad = find_sequences(&ids, img_pad[0]);
                all_continuous_img_pad.push(continuous_img_pad);

                let vid_pad = tokenizer
                    .encode_fast(Qwen2VLProcessor::VIDEO_PAD, false)
                    .expect("Tokenization failed!")
                    .get_ids()
                    .to_vec();
                let continuous_vid_pad = find_sequences(&ids, vid_pad[0]);
                all_continuous_vid_pad.push(continuous_vid_pad);
            }

            let mut input_ids_searching = Vec::new();
            let mut image_nums = Vec::new();
            let mut video_nums = Vec::new();
            for (seq, ids) in input_seqs.iter().zip(&all_ids) {
                let prompt = seq.get_initial_prompt();
                let match_indices = find_substring_indices(prompt, Qwen2VLProcessor::VISION_START);
                image_nums.push(
                    match_indices
                        .iter()
                        .filter(|&&idx| {
                            prompt[idx..idx + Qwen2VLProcessor::IMAGE_PAD.len()]
                                == *Qwen2VLProcessor::IMAGE_PAD
                        })
                        .count(),
                );
                video_nums.push(
                    match_indices
                        .iter()
                        .filter(|&&idx| {
                            prompt[idx..idx + Qwen2VLProcessor::VIDEO_PAD.len()]
                                == *Qwen2VLProcessor::VIDEO_PAD
                        })
                        .count(),
                );

                input_ids_searching.push(ids.to_vec());
            }

            let mut all_ids_new = Vec::new();
            let max_len = all_ids.iter().map(|ids| ids.len()).max().unwrap();
            for ids in all_ids {
                let pad = max_len - ids.len();
                all_ids_new
                    .push(Tensor::new([ids, vec![0; pad]].concat(), input.device()).unwrap());
            }

            (
                Some(Tensor::stack(&all_ids_new, 0).unwrap()),
                Some(Tensor::cat(&pixel_values_accum, 0).unwrap()),
                image_grid_thw_accum.map(|img| Tensor::cat(&img, 0).unwrap()),
                video_grid_thw_accum.map(|vid| Tensor::cat(&vid, 0).unwrap()),
                all_continuous_img_pad,
                all_continuous_vid_pad,
                input_ids_searching,
                image_nums,
                video_nums,
            )
        } else {
            (
                None,
                None,
                None,
                None,
                vec![],
                vec![],
                vec![vec![]; input_seqs.len()],
                vec![0; input_seqs.len()],
                vec![0; input_seqs.len()],
            )
        };

        let (input, input_ids_full) = match (new_input, is_prompt) {
            (Some(new_input), true) => (new_input.clone(), new_input),
            (Some(new_input), false) => (input, new_input),
            (None, _) => (input.clone(), input.clone()),
        };

        let pixel_values = if is_prompt { pixel_values } else { None };

        let seqlens = input_seqs.iter().map(|seq| seq.len()).collect::<Vec<_>>();

        let inputs: Box<dyn Any> = Box::new(ModelInputs {
            input_ids: input,
            seqlen_offsets: positions,
            context_lens,
            position_ids,
            pixel_values,
            model_specific_args: Box::new(Qwen2VLVisionSpecificArgs {
                input_ids_full,
                image_grid_thw,
                video_grid_thw,
                seqlens,
                continuous_img_pad,
                continuous_vid_pad,
                input_ids_searching,
                image_nums,
                video_nums,
            }),
            paged_attn_meta,
            flash_meta,
        });
        Ok(InputProcessorOutput {
            inputs,
            seq_indices,
        })
    }
}

impl Qwen2VLImageProcessor {
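    /// Compute a target (height, width) that is a multiple of `factor`, clamped so the
    /// total pixel count stays within `[min_pixels, max_pixels]` while approximately
    /// preserving the aspect ratio.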
    fn smart_resize(
        &self,
        height: usize,
        width: usize,
        factor: usize,
        min_pixels: usize,
        max_pixels: usize,
    ) -> candle_core::Result<(usize, usize)> {
        if height < factor || width < factor {
            candle_core::bail!(
                "height:{} or width:{} must be larger than factor:{}",
                height,
                width,
                factor
            );
        } else if (height.max(width) as f64 / height.min(width) as f64) > 200.0 {
            candle_core::bail!(
                "absolute aspect ratio must be smaller than 200, got {:.2}",
                height.max(width) as f64 / height.min(width) as f64
            );
        }

        let mut h_bar = (height as f64 / factor as f64).round() as usize * factor;
        let mut w_bar = (width as f64 / factor as f64).round() as usize * factor;

        if h_bar * w_bar > max_pixels {
            let beta = ((height * width) as f64 / max_pixels as f64).sqrt();
            h_bar = ((height as f64 / beta / factor as f64).floor() as usize) * factor;
            w_bar = ((width as f64 / beta / factor as f64).floor() as usize) * factor;
        } else if h_bar * w_bar < min_pixels {
            let beta = (min_pixels as f64 / (height * width) as f64).sqrt();
            h_bar = ((height as f64 * beta / factor as f64).ceil() as usize) * factor;
            w_bar = ((width as f64 * beta / factor as f64).ceil() as usize) * factor;
        }

        Ok((h_bar, w_bar))
    }

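    /// Resize, normalize, and patchify a set of frames (a single image or the frames
    /// of one video), returning the flattened patches together with the
    /// (temporal, height, width) grid dimensions.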
    fn preprocess_inner(
        &self,
        images: Vec<DynamicImage>,
        config: &PreProcessorConfig,
        device: &Device,
        (mut height, mut width): (u32, u32),
    ) -> candle_core::Result<(Tensor, (u32, u32, u32))> {
        let mut processed_images = Vec::new();

        for mut image in images {
            // `resize_exact` takes (new_width, new_height).
            image = image.resize_exact(
                width,
                height,
                config
                    .resampling
                    .map(|resample| Some(resample).to_filter())
                    .unwrap_or(Ok(FilterType::CatmullRom))?,
            );
            image = DynamicImage::ImageRgb8(image.to_rgb8());
            if config.do_resize.is_none() || config.do_resize.is_some_and(|x| x) {
                let (resized_height, resized_width) = self.smart_resize(
                    height as usize,
                    width as usize,
                    config.patch_size.context("Require `patch_size`.")?
                        * config.merge_size.context("Require `merge_size`")?,
                    config.min_pixels.context("Require `min_pixels`")?,
                    config.max_pixels.context("Require `max_pixels`")?,
                )?;
                height = resized_height as u32;
                width = resized_width as u32;
                image = image.resize_exact(
                    resized_width as u32,
                    resized_height as u32,
                    config
                        .resampling
                        .map(|resample| Some(resample).to_filter())
                        .unwrap_or(Ok(FilterType::CatmullRom))?,
                );
            }

            let to_tensor_rescale = Transforms {
                input: &ToTensor,
                inner_transforms: &[],
            };
            let image = image.apply(to_tensor_rescale, device)?;

            let transforms = TensorTransforms {
                inner_transforms: &[&Normalize {
                    mean: config.image_mean.unwrap_or(Self::DEFAULT_MEAN).to_vec(),
                    std: config.image_std.unwrap_or(Self::DEFAULT_STD).to_vec(),
                }],
            };
            let image = <Tensor as ApplyTensorTransforms>::apply(&image, transforms, device)?;

            processed_images.push(image);
        }

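        // Stack the frames, repeat a single frame to fill `temporal_patch_size`, then
        // rearrange into flattened patches of shape
        // (grid_t * grid_h * grid_w, channels * temporal_patch_size * patch_size^2).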
        let mut patches = Tensor::stack(&processed_images, 0)?;
        let temporal_patch_size = config
            .temporal_patch_size
            .context("Require `temporal_patch_size`")?;
        let patch_size = config.patch_size.context("Require `patch_size`")?;
        let merge_size = config.merge_size.context("Require `merge_size`")?;
        if patches.dim(0)? == 1 {
            patches = patches.repeat((temporal_patch_size, 1, 1, 1))?;
        }
        let channel = patches.dim(1)?;
        let grid_t = patches.dim(0)? / temporal_patch_size;
        let grid_h = height as usize / patch_size;
        let grid_w = width as usize / patch_size;
        patches = patches.reshape(&[
            grid_t,
            temporal_patch_size,
            channel,
            grid_h / merge_size,
            merge_size,
            patch_size,
            grid_w / merge_size,
            merge_size,
            patch_size,
        ])?;
        patches = patches.permute([0, 3, 6, 4, 7, 2, 1, 5, 8])?;
        let flattened_patches = patches.reshape((
            grid_t * grid_h * grid_w,
            channel * temporal_patch_size * patch_size * patch_size,
        ))?;

        Ok((
            flattened_patches,
            (grid_t as u32, grid_h as u32, grid_w as u32),
        ))
    }
}

impl ImagePreProcessor for Qwen2VLImageProcessor {
    const DEFAULT_MEAN: [f64; 3] = [0.48145466, 0.4578275, 0.40821073];
    const DEFAULT_STD: [f64; 3] = [0.26862954, 0.26130258, 0.27577711];

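    /// Preprocess a batch of images, or (if there are none) a batch of videos: every
    /// item is resized to the batch's maximum dimensions, patchified via
    /// `preprocess_inner`, and the per-item (t, h, w) grids are stacked into
    /// `image_grid_thw`/`video_grid_thw`.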
    fn preprocess(
        &self,
        mut images: Vec<DynamicImage>,
        videos: Vec<Vec<DynamicImage>>,
        config: &PreProcessorConfig,
        device: &Device,
        (_, _): (usize, usize),
    ) -> candle_core::Result<PreprocessedImages> {
        let mut pixel_values = Vec::new();
        let mut vision_grid_thw = Vec::new();

        if !images.is_empty() {
            if let Some(max_edge) = self.max_edge {
                images = mistralrs_vision::pad_to_max_edge(&images, max_edge);
            }

            let mut height = 0;
            let mut width = 0;
            for image in &images {
                let (w, h) = image.dimensions();
                if w > width {
                    width = w;
                }
                if h > height {
                    height = h;
                }
            }

            for image in images {
                let (patches, (t, h, w)) =
                    self.preprocess_inner(vec![image], config, device, (height, width))?;
                pixel_values.push(patches);
                vision_grid_thw.push(Tensor::new(&[t, h, w], &Device::Cpu)?);
            }
            let pixel_values = Tensor::stack(&pixel_values, 0)?;
            let vision_grid_thw = Tensor::stack(&vision_grid_thw, 0)?;
            return Ok(PreprocessedImages {
                pixel_values,
                pixel_attention_mask: None,
                image_sizes: None,
                num_img_tokens: None,
                aspect_ratio_ids: None,
                aspect_ratio_mask: None,
                num_tiles: None,
                image_grid_thw: Some(vision_grid_thw),
                video_grid_thw: None,
                rows: None,
                cols: None,
                pixel_values_list: None,
                tgt_sizes: None,
                image_sizes_all: None,
                num_crops: None,
            });
        }

        if !videos.is_empty() {
            let mut height = 0;
            let mut width = 0;
            for video in &videos {
                let (w, h) = video[0].dimensions();
                if w > width {
                    width = w;
                }
                if h > height {
                    height = h;
                }
            }

            for frames in videos {
                let (patches, (t, h, w)) =
                    self.preprocess_inner(frames, config, device, (height, width))?;
                pixel_values.push(patches);
                vision_grid_thw.push(Tensor::new(&[t, h, w], &Device::Cpu)?);
            }
            let pixel_values = Tensor::stack(&pixel_values, 0)?;
            let vision_grid_thw = Tensor::stack(&vision_grid_thw, 0)?;
            return Ok(PreprocessedImages {
                pixel_values,
                pixel_attention_mask: None,
                image_sizes: None,
                num_img_tokens: None,
                aspect_ratio_ids: None,
                aspect_ratio_mask: None,
                num_tiles: None,
                image_grid_thw: None,
                video_grid_thw: Some(vision_grid_thw),
                rows: None,
                cols: None,
                pixel_values_list: None,
                tgt_sizes: None,
                image_sizes_all: None,
                num_crops: None,
            });
        }
        unreachable!()
    }
}