use std::{any::Any, sync::Arc};

use anyhow::Result;
use candle_core::{Context, Device, IndexOp, Tensor};
use image::{imageops::FilterType, DynamicImage, GenericImageView};
use mistralrs_vision::{
    ApplyTensorTransforms, ApplyTransforms, Normalize, TensorTransforms, ToTensor, Transforms,
};
use tokenizers::Tokenizer;

use crate::{
    device_map::DeviceMapper,
    pipeline::{
        text_models_inputs_processor::{
            self, get_completion_input, get_prompt_input, PagedAttentionMeta,
        },
        InputProcessorOutput, InputsProcessor, InputsProcessorType, MessagesAction, Processor,
    },
    sequence::Sequence,
    vision_models::{
        image_processor::{ImagePreProcessor, PreprocessedImages},
        preprocessor_config::{PreProcessorConfig, ToFilter},
        ModelInputs,
    },
};

use super::Qwen2VLVisionSpecificArgs;
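/// Image and video pre-processor for Qwen2-VL: resizes inputs with `smart_resize`,
/// normalizes them, and flattens them into vision patches plus a `(t, h, w)` grid
/// descriptor for the vision model.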
struct Qwen2VLImageProcessor {
    max_edge: Option<u32>,
}
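
/// User-facing processor for Qwen2-VL. It owns the special vision tokens and builds
/// the [`Qwen2VLImageProcessor`] used when inputs are prepared for the model.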
pub struct Qwen2VLProcessor {
    max_edge: Option<u32>,
}

impl Qwen2VLProcessor {
    pub const VISION_START: &str = "<|vision_start|>";
    pub const VISION_END: &str = "<|vision_end|>";
    pub const IMAGE_PAD: &str = "<|image_pad|>";
    pub const VIDEO_PAD: &str = "<|video_pad|>";
    pub const PLACEHOLDER: &str = "<|placeholder|>";

    pub fn new(max_edge: Option<u32>) -> Self {
        Self { max_edge }
    }
}

impl Processor for Qwen2VLProcessor {
    fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
        Arc::new(Qwen2VLImageProcessor {
            max_edge: self.max_edge,
        })
    }

    fn get_special_tokens(&self) -> &[&'static str] {
        &[Self::IMAGE_PAD, Self::VIDEO_PAD, Self::PLACEHOLDER]
    }

    fn template_action(&self) -> MessagesAction {
        MessagesAction::FlattenOnlyText
    }
}
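
/// Replaces only the first occurrence of `to_replace` in `text`, leaving later
/// occurrences untouched; returns `text` unchanged if there is no match.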
fn replace_first_occurrence(text: &str, to_replace: &str, replacement: &str) -> String {
    if let Some(pos) = text.find(to_replace) {
        let mut result = text.to_string();
        result.replace_range(pos..pos + to_replace.len(), replacement);
        result
    } else {
        text.to_string()
    }
}
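
/// Finds maximal runs of `needle` in `nums` and returns them as half-open
/// `(start, end)` index ranges, e.g. one range per contiguous block of pad token ids.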
fn find_sequences(nums: &[u32], needle: u32) -> Vec<(usize, usize)> {
    let mut sequences = Vec::new();
    let mut start = None;

    for (i, &num) in nums.iter().enumerate() {
        if num == needle {
            if start.is_none() {
                start = Some(i);
            }
        } else if let Some(s) = start {
            sequences.push((s, i));
            start = None;
        }
    }

    if let Some(s) = start {
        sequences.push((s, nums.len()));
    }

    sequences
}
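
/// Returns, for every occurrence of `needle` in `haystack`, the byte index just past
/// the end of that occurrence (i.e. where the text following the match begins).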
fn find_substring_indices(haystack: &str, needle: &str) -> Vec<usize> {
    let mut indices = Vec::new();
    let mut start = 0;

    while let Some(pos) = haystack[start..].find(needle) {
        let index = start + pos;
        indices.push(index + needle.len());
        start = index + needle.len();
    }

    indices
}
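
// `process_inputs` below expands each `<|image_pad|>` / `<|video_pad|>` marker in the
// detokenized prompt into `t * h * w / merge_size^2` pad tokens for the corresponding
// image or video grid, re-tokenizes the result, and records where the contiguous pad
// runs sit so the model can place the vision features.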
impl InputsProcessor for Qwen2VLImageProcessor {
    fn get_type(&self) -> InputsProcessorType {
        InputsProcessorType::Vision
    }
    fn process_inputs(
        &self,
        tokenizer: Option<Arc<Tokenizer>>,
        input_seqs: &mut [&mut Sequence],
        is_prompt: bool,
        is_xlora: bool,
        device: &Device,
        no_kv_cache: bool,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        other_config: Option<Arc<dyn Any>>,
        mut paged_attn_metadata: Option<PagedAttentionMeta>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Result<InputProcessorOutput> {
        if is_xlora {
            return Err(anyhow::Error::msg(
                "Cannot make inputs for X-LoRA vision model.",
            ));
        }
        if no_kv_cache {
            return Err(anyhow::Error::msg("Vision model must have kv cache."));
        }
        if input_seqs.len() != 1 {
            return Err(anyhow::Error::msg("Qwen2-VL only supports batch size = 1"));
        }
        let Some(tokenizer) = tokenizer else {
            return Err(anyhow::Error::msg(
                "Qwen2VLImageProcessor requires a specified tokenizer.",
            ));
        };

        let text_models_inputs_processor::InnerInputProcessorOutput {
            inputs:
                text_models_inputs_processor::InputMetadata {
                    input,
                    positions,
                    context_lens,
                    position_ids,
                    paged_attn_meta,
                    flash_meta,
                },
            seq_indices,
        } = if is_prompt {
            get_prompt_input(
                input_seqs
                    .iter()
                    .map(|seq| seq.get_toks())
                    .collect::<Vec<_>>(),
                input_seqs,
                device,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                mapper,
            )
            .unwrap()
        } else {
            get_completion_input(
                input_seqs
                    .iter()
                    .map(|seq| seq.get_toks())
                    .collect::<Vec<_>>(),
                input_seqs,
                device,
                no_kv_cache,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                mapper,
            )
            .unwrap()
        };
        let config = other_config.expect("Need a PreProcessorConfig config.");
        let config: &PreProcessorConfig = config.downcast_ref().expect("Downcast failed.");

        let has_images = input_seqs.iter().all(|seq| seq.has_images());

        let (
            new_input,
            pixel_values,
            image_grid_thw,
            video_grid_thw,
            continuous_img_pad,
            continuous_vid_pad,
            input_ids_searching,
            image_nums,
            video_nums,
        ) = if has_images {
            let mut pixel_values_accum = Vec::new();
            let mut image_grid_thw_accum = Vec::new();
            let mut video_grid_thw_accum = Vec::new();

            let mut detok_seqs = tokenizer
                .decode_batch(
                    &input_seqs
                        .iter()
                        .map(|seq| seq.get_toks())
                        .collect::<Vec<_>>(),
                    false,
                )
                .expect("Detokenization failed!");
            for seq in input_seqs.iter_mut() {
                let (pixel_values, image_grid_thw, video_grid_thw) =
                    if let Some(cached_pixel_values) = &seq.multimodal.cached_pixel_values {
                        (
                            cached_pixel_values.clone(),
                            seq.multimodal.cached_img_thw.clone(),
                            seq.multimodal.cached_vid_thw.clone(),
                        )
                    } else {
                        let PreprocessedImages {
                            pixel_values,
                            pixel_attention_mask: _,
                            image_sizes: _,
                            num_img_tokens: _,
                            aspect_ratio_ids: _,
                            aspect_ratio_mask: _,
                            num_tiles: _,
                            image_grid_thw,
                            video_grid_thw,
                            rows: _,
                            cols: _,
                            pixel_values_list: _,
                            tgt_sizes: _,
                            image_sizes_all: _,
                            num_crops: _,
                        } = self
                            .preprocess(
                                seq.clone_images()
                                    .expect("Need to have images by this point."),
                                vec![],
                                config,
                                device,
                                (usize::MAX, usize::MAX),
                            )
                            .expect("Preprocessing failed");

                        seq.multimodal.cached_pixel_values = Some(pixel_values.clone());
                        seq.multimodal.cached_img_thw = image_grid_thw.clone();
                        seq.multimodal.cached_vid_thw = video_grid_thw.clone();
                        (pixel_values, image_grid_thw, video_grid_thw)
                    };

                pixel_values_accum.push(pixel_values.unsqueeze(0).unwrap());
                image_grid_thw_accum.push(image_grid_thw);
                video_grid_thw_accum.push(video_grid_thw);
            }

            let image_grid_thw_accum = if image_grid_thw_accum.iter().any(|img| img.is_none()) {
                None
            } else {
                Some(
                    image_grid_thw_accum
                        .into_iter()
                        .map(|img| img.unwrap())
                        .collect::<Vec<_>>(),
                )
            };

            let video_grid_thw_accum = if video_grid_thw_accum.iter().any(|img| img.is_none()) {
                None
            } else {
                Some(
                    video_grid_thw_accum
                        .into_iter()
                        .map(|img| img.unwrap())
                        .collect::<Vec<_>>(),
                )
            };
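
            // On the prompt pass, each `<|image_pad|>` / `<|video_pad|>` marker is expanded to
            // `t * h * w / merge_size^2` pad tokens (the number of vision tokens for that grid).
            // `<|placeholder|>` is used as a scratch token so already expanded pads are not
            // matched again by the `while` loops below.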
            if is_prompt {
                if let Some(ref image_grid_thw_accum) = image_grid_thw_accum {
                    let merge_length = config.merge_size.expect("Require `merge_size`").pow(2);
                    for ((batch, text), seq) in
                        detok_seqs.iter_mut().enumerate().zip(input_seqs.iter_mut())
                    {
                        if seq.multimodal.has_changed_prompt {
                            continue;
                        }
                        let mut index = 0;
                        while text.contains(Qwen2VLProcessor::IMAGE_PAD) {
                            *text = replace_first_occurrence(
                                text,
                                Qwen2VLProcessor::IMAGE_PAD,
                                &Qwen2VLProcessor::PLACEHOLDER.repeat(
                                    image_grid_thw_accum[batch]
                                        .i(index)
                                        .unwrap()
                                        .to_vec1::<u32>()
                                        .unwrap()
                                        .iter()
                                        .product::<u32>() as usize
                                        / merge_length,
                                ),
                            );
                            index += 1;
                        }
                        *text = text
                            .replace(Qwen2VLProcessor::PLACEHOLDER, Qwen2VLProcessor::IMAGE_PAD);
                    }
                }

                if let Some(ref video_grid_thw_accum) = video_grid_thw_accum {
                    let merge_length = config.merge_size.expect("Require `merge_size`").pow(2);
                    let mut index = 0;
                    for ((batch, text), seq) in
                        detok_seqs.iter_mut().enumerate().zip(input_seqs.iter_mut())
                    {
                        if seq.multimodal.has_changed_prompt {
                            continue;
                        }
                        while text.contains(Qwen2VLProcessor::VIDEO_PAD) {
                            *text = replace_first_occurrence(
                                text,
                                Qwen2VLProcessor::VIDEO_PAD,
                                &Qwen2VLProcessor::PLACEHOLDER.repeat(
                                    video_grid_thw_accum[batch]
                                        .i(index)
                                        .unwrap()
                                        .to_vec1::<u32>()
                                        .unwrap()
                                        .iter()
                                        .product::<u32>() as usize
                                        / merge_length,
                                ),
                            );
                            index += 1;
                        }
                        *text = text
                            .replace(Qwen2VLProcessor::PLACEHOLDER, Qwen2VLProcessor::VIDEO_PAD);
                    }
                }
            }

            let mut all_ids = Vec::new();
            let mut all_continuous_img_pad = Vec::new();
            let mut all_continuous_vid_pad = Vec::new();
            for (detok, seq) in detok_seqs.into_iter().zip(input_seqs.iter_mut()) {
                let toks = tokenizer
                    .encode_fast(detok.clone(), false)
                    .expect("Tokenization failed!");
                let ids = toks.get_ids().to_vec();

                if !seq.multimodal.has_changed_prompt {
                    seq.set_initial_prompt(detok.clone());

                    seq.set_toks_and_reallocate(ids.clone(), paged_attn_metadata.as_mut());
                    seq.multimodal.has_changed_prompt = true;
                }
                all_ids.push(ids.clone());

                let img_pad = tokenizer
                    .encode_fast(Qwen2VLProcessor::IMAGE_PAD, false)
                    .expect("Tokenization failed!")
                    .get_ids()
                    .to_vec();
                let continuous_img_pad = find_sequences(&ids, img_pad[0]);
                all_continuous_img_pad.push(continuous_img_pad);

                let vid_pad = tokenizer
                    .encode_fast(Qwen2VLProcessor::VIDEO_PAD, false)
                    .expect("Tokenization failed!")
                    .get_ids()
                    .to_vec();
                let continuous_vid_pad = find_sequences(&ids, vid_pad[0]);
                all_continuous_vid_pad.push(continuous_vid_pad);
            }

            let mut input_ids_searching = Vec::new();
            let mut image_nums = Vec::new();
            let mut video_nums = Vec::new();
            for (seq, ids) in input_seqs.iter().zip(&all_ids) {
                let prompt = seq.get_initial_prompt();
                let match_indices = find_substring_indices(prompt, Qwen2VLProcessor::VISION_START);
                image_nums.push(
                    match_indices
                        .iter()
                        .filter(|&&idx| {
                            prompt[idx..idx + Qwen2VLProcessor::IMAGE_PAD.len()]
                                == *Qwen2VLProcessor::IMAGE_PAD
                        })
                        .count(),
                );
                video_nums.push(
                    match_indices
                        .iter()
                        .filter(|&&idx| {
                            prompt[idx..idx + Qwen2VLProcessor::VIDEO_PAD.len()]
                                == *Qwen2VLProcessor::VIDEO_PAD
                        })
                        .count(),
                );

                input_ids_searching.push(ids.to_vec());
            }

            let mut all_ids_new = Vec::new();
            let max_len = all_ids.iter().map(|ids| ids.len()).max().unwrap();
            for ids in all_ids {
                let pad = max_len - ids.len();
                all_ids_new
                    .push(Tensor::new([ids, vec![0; pad]].concat(), input.device()).unwrap());
            }

            (
                Some(Tensor::stack(&all_ids_new, 0).unwrap()),
                Some(Tensor::cat(&pixel_values_accum, 0).unwrap()),
                image_grid_thw_accum.map(|img| Tensor::cat(&img, 0).unwrap()),
                video_grid_thw_accum.map(|vid| Tensor::cat(&vid, 0).unwrap()),
                all_continuous_img_pad,
                all_continuous_vid_pad,
                input_ids_searching,
                image_nums,
                video_nums,
            )
        } else {
            (
                None,
                None,
                None,
                None,
                vec![],
                vec![],
                vec![vec![]; input_seqs.len()],
                vec![0; input_seqs.len()],
                vec![0; input_seqs.len()],
            )
        };

        let (input, input_ids_full) = match (new_input, is_prompt) {
            (Some(new_input), true) => (new_input.clone(), new_input),
            (Some(new_input), false) => (input, new_input),
            (None, _) => (input.clone(), input.clone()),
        };

        let pixel_values = if is_prompt { pixel_values } else { None };

        let seqlens = input_seqs.iter().map(|seq| seq.len()).collect::<Vec<_>>();

        let inputs: Box<dyn Any> = Box::new(ModelInputs {
            input_ids: input,
            seqlen_offsets: positions,
            context_lens,
            position_ids,
            pixel_values,
            model_specific_args: Box::new(Qwen2VLVisionSpecificArgs {
                input_ids_full,
                image_grid_thw,
                video_grid_thw,
                seqlens,
                continuous_img_pad,
                continuous_vid_pad,
                input_ids_searching,
                image_nums,
                video_nums,
            }),
            paged_attn_meta,
            flash_meta,
        });
        Ok(InputProcessorOutput {
            inputs,
            seq_indices,
        })
    }
}

impl Qwen2VLImageProcessor {
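    /// Rounds `(height, width)` to the nearest multiples of `factor` while keeping the
    /// total pixel count within `[min_pixels, max_pixels]` and preserving the aspect
    /// ratio as closely as possible. For example, `(100, 200)` with `factor = 28` rounds
    /// to `(112, 196)` when that product stays inside the pixel bounds.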
    fn smart_resize(
        &self,
        height: usize,
        width: usize,
        factor: usize,
        min_pixels: usize,
        max_pixels: usize,
    ) -> candle_core::Result<(usize, usize)> {
        if height < factor || width < factor {
            candle_core::bail!(
                "height:{} or width:{} must be larger than factor:{}",
                height,
                width,
                factor
            );
        } else if (height.max(width) as f64 / height.min(width) as f64) > 200.0 {
            candle_core::bail!(
                "absolute aspect ratio must be smaller than 200, got {:.2}",
                height.max(width) as f64 / height.min(width) as f64
            );
        }

        let mut h_bar = (height as f64 / factor as f64).round() as usize * factor;
        let mut w_bar = (width as f64 / factor as f64).round() as usize * factor;

        if h_bar * w_bar > max_pixels {
            let beta = ((height * width) as f64 / max_pixels as f64).sqrt();
            h_bar = ((height as f64 / beta / factor as f64).floor() as usize) * factor;
            w_bar = ((width as f64 / beta / factor as f64).floor() as usize) * factor;
        } else if h_bar * w_bar < min_pixels {
            let beta = (min_pixels as f64 / (height * width) as f64).sqrt();
            h_bar = ((height as f64 * beta / factor as f64).ceil() as usize) * factor;
            w_bar = ((width as f64 * beta / factor as f64).ceil() as usize) * factor;
        }

        Ok((h_bar, w_bar))
    }
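
    /// Resizes and normalizes a batch of frames, then flattens them into Qwen2-VL vision
    /// patches of shape `(grid_t * grid_h * grid_w, channel * temporal_patch_size * patch_size^2)`,
    /// returned together with the `(grid_t, grid_h, grid_w)` grid.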
    fn preprocess_inner(
        &self,
        images: Vec<DynamicImage>,
        config: &PreProcessorConfig,
        device: &Device,
        (mut height, mut width): (u32, u32),
    ) -> candle_core::Result<(Tensor, (u32, u32, u32))> {
        let mut processed_images = Vec::new();

        for mut image in images {
            image = image.resize_exact(
                height,
                width,
                config
                    .resampling
                    .map(|resample| Some(resample).to_filter())
                    .unwrap_or(Ok(FilterType::CatmullRom))?,
            );
            image = DynamicImage::ImageRgb8(image.to_rgb8());
            if config.do_resize.is_none() || config.do_resize.is_some_and(|x| x) {
                let (resized_height, resized_width) = self.smart_resize(
                    height as usize,
                    width as usize,
                    config.patch_size.context("Require `patch_size`.")?
                        * config.merge_size.context("Require `merge_size`")?,
                    config.min_pixels.context("Require `min_pixels`")?,
                    config.max_pixels.context("Require `max_pixels`")?,
                )?;
                height = resized_height as u32;
                width = resized_width as u32;
                image = image.resize_exact(
                    resized_width as u32,
                    resized_height as u32,
                    config
                        .resampling
                        .map(|resample| Some(resample).to_filter())
                        .unwrap_or(Ok(FilterType::CatmullRom))?,
                );
            }

            let to_tensor_rescale = Transforms {
                input: &ToTensor,
                inner_transforms: &[],
            };
            let image = image.apply(to_tensor_rescale, device)?;

            let transforms = TensorTransforms {
                inner_transforms: &[&Normalize {
                    mean: config.image_mean.unwrap_or(Self::DEFAULT_MEAN).to_vec(),
                    std: config.image_std.unwrap_or(Self::DEFAULT_STD).to_vec(),
                }],
            };
            let image = <Tensor as ApplyTensorTransforms>::apply(&image, transforms, device)?;

            processed_images.push(image);
        }

        let mut patches = Tensor::stack(&processed_images, 0)?;
        let temporal_patch_size = config
            .temporal_patch_size
            .context("Require `temporal_patch_size`")?;
        let patch_size = config.patch_size.context("Require `patch_size`")?;
        let merge_size = config.merge_size.context("Require `merge_size`")?;
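        // A single image is repeated `temporal_patch_size` times so images and videos share the
        // same layout; each input then contributes `grid_t * grid_h * grid_w` patches, and the
        // reshape/permute below flattens every patch into a row of
        // `channel * temporal_patch_size * patch_size * patch_size` values.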
        if patches.dim(0)? == 1 {
            patches = patches.repeat((temporal_patch_size, 1, 1, 1))?;
        }
        let channel = patches.dim(1)?;
        let grid_t = patches.dim(0)? / temporal_patch_size;
        let grid_h = height as usize / patch_size;
        let grid_w = width as usize / patch_size;
        patches = patches.reshape(&[
            grid_t,
            temporal_patch_size,
            channel,
            grid_h / merge_size,
            merge_size,
            patch_size,
            grid_w / merge_size,
            merge_size,
            patch_size,
        ])?;
        patches = patches.permute([0, 3, 6, 4, 7, 2, 1, 5, 8])?;
        let flattened_patches = patches.reshape((
            grid_t * grid_h * grid_w,
            channel * temporal_patch_size * patch_size * patch_size,
        ))?;

        Ok((
            flattened_patches,
            (grid_t as u32, grid_h as u32, grid_w as u32),
        ))
    }
}

impl ImagePreProcessor for Qwen2VLImageProcessor {
    const DEFAULT_MEAN: [f64; 3] = [0.48145466, 0.4578275, 0.40821073];
    const DEFAULT_STD: [f64; 3] = [0.26862954, 0.26130258, 0.27577711];

    fn preprocess(
        &self,
        mut images: Vec<DynamicImage>,
        videos: Vec<Vec<DynamicImage>>,
        config: &PreProcessorConfig,
        device: &Device,
        (_, _): (usize, usize),
    ) -> candle_core::Result<PreprocessedImages> {
        let mut pixel_values = Vec::new();
        let mut vision_grid_thw = Vec::new();

        if !images.is_empty() {
            if let Some(max_edge) = self.max_edge {
                images = mistralrs_vision::pad_to_max_edge(&images, max_edge);
            }

            let mut height = 0;
            let mut width = 0;
            for image in &images {
                let (w, h) = image.dimensions();
                if w > width {
                    width = w;
                }
                if h > height {
                    height = h;
                }
            }

            for image in images {
                let (patches, (t, h, w)) =
                    self.preprocess_inner(vec![image], config, device, (height, width))?;
                pixel_values.push(patches);
                vision_grid_thw.push(Tensor::new(&[t, h, w], &Device::Cpu)?);
            }
            let pixel_values = Tensor::stack(&pixel_values, 0)?;
            let vision_grid_thw = Tensor::stack(&vision_grid_thw, 0)?;
            return Ok(PreprocessedImages {
                pixel_values,
                pixel_attention_mask: None,
                image_sizes: None,
                num_img_tokens: None,
                aspect_ratio_ids: None,
                aspect_ratio_mask: None,
                num_tiles: None,
                image_grid_thw: Some(vision_grid_thw),
                video_grid_thw: None,
                rows: None,
                cols: None,
                pixel_values_list: None,
                tgt_sizes: None,
                image_sizes_all: None,
                num_crops: None,
            });
        }

        if !videos.is_empty() {
            let mut height = 0;
            let mut width = 0;
            for frames in &videos {
                let (w, h) = frames[0].dimensions();
                if w > width {
                    width = w;
                }
                if h > height {
                    height = h;
                }
            }

            for images in videos {
                let (patches, (t, h, w)) =
                    self.preprocess_inner(images, config, device, (height, width))?;
                pixel_values.push(patches);
                vision_grid_thw.push(Tensor::new(&[t, h, w], &Device::Cpu)?);
            }
            let pixel_values = Tensor::stack(&pixel_values, 0)?;
            let vision_grid_thw = Tensor::stack(&vision_grid_thw, 0)?;
            return Ok(PreprocessedImages {
                pixel_values,
                pixel_attention_mask: None,
                image_sizes: None,
                num_img_tokens: None,
                aspect_ratio_ids: None,
                aspect_ratio_mask: None,
                num_tiles: None,
                image_grid_thw: None,
                video_grid_thw: Some(vision_grid_thw),
                rows: None,
                cols: None,
                pixel_values_list: None,
                tgt_sizes: None,
                image_sizes_all: None,
                num_crops: None,
            });
        }
        unreachable!()
    }
}
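
// A minimal sanity-check sketch for the pure string/token helpers above; the expected
// values are worked out by hand from the function definitions.
#[cfg(test)]
mod helper_tests {
    use super::{find_sequences, find_substring_indices, replace_first_occurrence};

    #[test]
    fn replaces_only_the_first_occurrence() {
        let text = "a<|image_pad|>b<|image_pad|>";
        let out = replace_first_occurrence(text, "<|image_pad|>", "XX");
        assert_eq!(out, "aXXb<|image_pad|>");
        // No match leaves the input untouched.
        assert_eq!(replace_first_occurrence("abc", "zzz", "XX"), "abc");
    }

    #[test]
    fn finds_contiguous_token_runs() {
        // Runs of `5` are reported as half-open (start, end) ranges.
        assert_eq!(find_sequences(&[1, 5, 5, 2, 5], 5), vec![(1, 3), (4, 5)]);
        assert_eq!(find_sequences(&[7, 7], 9), Vec::<(usize, usize)>::new());
    }

    #[test]
    fn reports_indices_just_past_each_match() {
        // Each reported index points to the byte immediately after a match of the needle.
        assert_eq!(find_substring_indices("abcabc", "abc"), vec![3, 6]);
        assert_eq!(find_substring_indices("no match here", "xyz"), Vec::<usize>::new());
    }
}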