use std::{any::Any, num::NonZeroUsize, sync::Arc};

use anyhow::Result;
use candle_core::{Context, Device, IndexOp, Tensor};
use image::{imageops::FilterType, DynamicImage, GenericImageView};
use mistralrs_vision::{
    ApplyTensorTransforms, ApplyTransforms, Normalize, TensorTransforms, ToTensor, Transforms,
};
use tokenizers::Tokenizer;
use tracing::warn;

use crate::{
    device_map::DeviceMapper,
    pipeline::{
        text_models_inputs_processor::{
            self, get_completion_input, get_prompt_input, PagedAttentionMeta,
        },
        InputProcessorOutput, InputsProcessor, InputsProcessorType, MessagesAction, Processor,
    },
    sequence::Sequence,
    vision_models::{
        image_processor::{ImagePreProcessor, PreprocessedImages},
        preprocessor_config::{PreProcessorConfig, ToFilter},
        ModelInputs,
    },
};

use super::Qwen2VLVisionSpecificArgs;

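/// Builds the per-request vision inputs for Qwen2-VL.
///
/// When `max_edge` is set, incoming images are first run through
/// `mistralrs_vision::pad_to_max_edge` with that value before preprocessing.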
struct Qwen2VLImageProcessor {
    max_edge: Option<u32>,
}

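/// Message processor for Qwen2-VL: declares the vision special tokens and constructs the
/// [`Qwen2VLImageProcessor`] used to build model inputs.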
pub struct Qwen2VLProcessor {
    max_edge: Option<u32>,
}

impl Qwen2VLProcessor {
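    // Special tokens used by the Qwen2-VL chat template. `PLACEHOLDER` is only a scratch token
    // used while expanding image/video pad tokens to their final counts.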
    pub const VISION_START: &str = "<|vision_start|>";
    pub const VISION_END: &str = "<|vision_end|>";
    pub const IMAGE_PAD: &str = "<|image_pad|>";
    pub const VIDEO_PAD: &str = "<|video_pad|>";
    pub const PLACEHOLDER: &str = "<|placeholder|>";

    pub fn new(max_edge: Option<u32>) -> Self {
        Self { max_edge }
    }
}

impl Processor for Qwen2VLProcessor {
    fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
        Arc::new(Qwen2VLImageProcessor {
            max_edge: self.max_edge,
        })
    }

    fn get_special_tokens(&self) -> &[&'static str] {
        &[Self::IMAGE_PAD, Self::VIDEO_PAD, Self::PLACEHOLDER]
    }

    fn template_action(&self) -> MessagesAction {
        MessagesAction::FlattenOnlyText
    }
}

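/// Replace only the first occurrence of `to_replace` in `text`, leaving any later matches
/// untouched.
///
/// ```ignore
/// assert_eq!(replace_first_occurrence("a b a", "a", "c"), "c b a");
/// ```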
fn replace_first_occurrence(text: &str, to_replace: &str, replacement: &str) -> String {
    if let Some(pos) = text.find(to_replace) {
        let mut result = text.to_string();
        result.replace_range(pos..pos + to_replace.len(), replacement);
        result
    } else {
        text.to_string()
    }
}

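/// Find maximal runs of `needle` in `nums`, returned as half-open `(start, end)` index ranges.
///
/// ```ignore
/// assert_eq!(find_sequences(&[1, 5, 5, 2, 5], 5), vec![(1, 3), (4, 5)]);
/// ```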
fn find_sequences(nums: &[u32], needle: u32) -> Vec<(usize, usize)> {
    let mut sequences = Vec::new();
    let mut start = None;

    for (i, &num) in nums.iter().enumerate() {
        if num == needle {
            if start.is_none() {
                start = Some(i);
            }
        } else if let Some(s) = start {
            sequences.push((s, i));
            start = None;
        }
    }

    if let Some(s) = start {
        sequences.push((s, nums.len()));
    }

    sequences
}

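/// Return the byte index just past each occurrence of `needle` in `haystack`, i.e. the position
/// where the text following each match begins.
///
/// ```ignore
/// assert_eq!(find_substring_indices("ab ab", "ab"), vec![2, 5]);
/// ```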
fn find_substring_indices(haystack: &str, needle: &str) -> Vec<usize> {
    let mut indices = Vec::new();
    let mut start = 0;

    while let Some(pos) = haystack[start..].find(needle) {
        let index = start + pos;
        indices.push(index + needle.len());
        start = index + needle.len();
    }

    indices
}

impl InputsProcessor for Qwen2VLImageProcessor {
    fn get_type(&self) -> InputsProcessorType {
        InputsProcessorType::Vision
    }
    fn process_inputs(
        &self,
        tokenizer: Option<Arc<Tokenizer>>,
        input_seqs: &mut [&mut Sequence],
        is_prompt: bool,
        is_xlora: bool,
        device: &Device,
        no_kv_cache: bool,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        other_config: Option<Arc<dyn Any>>,
        mut paged_attn_metadata: Option<PagedAttentionMeta>,
        prompt_chunksize: Option<NonZeroUsize>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Box<dyn Iterator<Item = Result<InputProcessorOutput>>> {
        if is_xlora {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Cannot make inputs for X-LoRA vision model.",
            ))));
        }
        if no_kv_cache {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Vision model must have kv cache.",
            ))));
        }
        if prompt_chunksize.is_some() {
            warn!("`prompt_chunksize` is set. Qwen2-VL does not support prompt batching.");
        }
        if input_seqs.len() != 1 {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Qwen2-VL only supports batch size = 1",
            ))));
        }
        let Some(tokenizer) = tokenizer else {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Qwen2VLImageProcessor requires a specified tokenizer.",
            ))));
        };

        let text_models_inputs_processor::InnerInputProcessorOutput {
            inputs:
                text_models_inputs_processor::InputMetadata {
                    input,
                    positions,
                    context_lens,
                    position_ids,
                    paged_attn_meta,
                    flash_meta,
                },
            seq_indices,
        } = if is_prompt {
            get_prompt_input(
                input_seqs
                    .iter()
                    .map(|seq| seq.get_toks())
                    .collect::<Vec<_>>(),
                input_seqs,
                device,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                None,
                mapper,
            )
            .nth(0)
            .unwrap()
            .unwrap()
        } else {
            get_completion_input(
                input_seqs
                    .iter()
                    .map(|seq| seq.get_toks())
                    .collect::<Vec<_>>(),
                input_seqs,
                device,
                no_kv_cache,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                None,
                mapper,
            )
            .nth(0)
            .unwrap()
            .unwrap()
        };
        let config = other_config.expect("Need a PreProcessorConfig config.");
        let config: &PreProcessorConfig = config.downcast_ref().expect("Downcast failed.");

        let has_images = input_seqs.iter().all(|seq| seq.has_images());

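        // When the sequences carry images/videos: preprocess them (with per-sequence caching),
        // expand each `<|image_pad|>`/`<|video_pad|>` token to the number of visual tokens
        // implied by its grid, re-tokenize the expanded prompt, and record the pad-token runs
        // and counts that are forwarded to the model via `Qwen2VLVisionSpecificArgs`.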
        let (
            new_input,
            pixel_values,
            image_grid_thw,
            video_grid_thw,
            continuous_img_pad,
            continuous_vid_pad,
            input_ids_searching,
            image_nums,
            video_nums,
        ) = if has_images {
            let mut pixel_values_accum = Vec::new();
            let mut image_grid_thw_accum = Vec::new();
            let mut video_grid_thw_accum = Vec::new();

            let mut detok_seqs = tokenizer
                .decode_batch(
                    &input_seqs
                        .iter()
                        .map(|seq| seq.get_toks())
                        .collect::<Vec<_>>(),
                    false,
                )
                .expect("Detokenization failed!");

            for seq in input_seqs.iter_mut() {
                let (pixel_values, image_grid_thw, video_grid_thw) =
                    if let Some(cached_pixel_values) = &seq.multimodal.cached_pixel_values {
                        (
                            cached_pixel_values.clone(),
                            seq.multimodal.cached_img_thw.clone(),
                            seq.multimodal.cached_vid_thw.clone(),
                        )
                    } else {
                        let PreprocessedImages {
                            pixel_values,
                            pixel_attention_mask: _,
                            image_sizes: _,
                            num_img_tokens: _,
                            aspect_ratio_ids: _,
                            aspect_ratio_mask: _,
                            num_tiles: _,
                            image_grid_thw,
                            video_grid_thw,
                            rows: _,
                            cols: _,
                            pixel_values_list: _,
                            tgt_sizes: _,
                            image_sizes_all: _,
                            num_crops: _,
                        } = self
                            .preprocess(
                                seq.clone_images()
                                    .expect("Need to have images by this point."),
                                vec![],
                                config,
                                device,
                                (usize::MAX, usize::MAX),
                            )
                            .expect("Preprocessing failed");

                        seq.multimodal.cached_pixel_values = Some(pixel_values.clone());
                        seq.multimodal.cached_img_thw = image_grid_thw.clone();
                        seq.multimodal.cached_vid_thw = video_grid_thw.clone();
                        (pixel_values, image_grid_thw, video_grid_thw)
                    };

                pixel_values_accum.push(pixel_values.unsqueeze(0).unwrap());
                image_grid_thw_accum.push(image_grid_thw);
                video_grid_thw_accum.push(video_grid_thw);
            }

            let image_grid_thw_accum = if image_grid_thw_accum.iter().any(|img| img.is_none()) {
                None
            } else {
                Some(
                    image_grid_thw_accum
                        .into_iter()
                        .map(|img| img.unwrap())
                        .collect::<Vec<_>>(),
                )
            };

            let video_grid_thw_accum = if video_grid_thw_accum.iter().any(|img| img.is_none()) {
                None
            } else {
                Some(
                    video_grid_thw_accum
                        .into_iter()
                        .map(|img| img.unwrap())
                        .collect::<Vec<_>>(),
                )
            };

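            // On the prompt pass, each pad token is replaced by `prod(grid_thw) / merge_size^2`
            // placeholder tokens (one per post-merge visual token); the placeholders are then
            // swapped back to the pad token so re-tokenization yields the expanded sequence.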
            if is_prompt {
                if let Some(ref image_grid_thw_accum) = image_grid_thw_accum {
                    let merge_length = config.merge_size.expect("Require `merge_size`").pow(2);
                    for ((batch, text), seq) in
                        detok_seqs.iter_mut().enumerate().zip(input_seqs.iter_mut())
                    {
                        if seq.multimodal.has_changed_prompt {
                            continue;
                        }
                        let mut index = 0;
                        while text.contains(Qwen2VLProcessor::IMAGE_PAD) {
                            *text = replace_first_occurrence(
                                text,
                                Qwen2VLProcessor::IMAGE_PAD,
                                &Qwen2VLProcessor::PLACEHOLDER.repeat(
                                    image_grid_thw_accum[batch]
                                        .i(index)
                                        .unwrap()
                                        .to_vec1::<u32>()
                                        .unwrap()
                                        .iter()
                                        .product::<u32>()
                                        as usize
                                        / merge_length,
                                ),
                            );
                            index += 1;
                        }
                        *text = text
                            .replace(Qwen2VLProcessor::PLACEHOLDER, Qwen2VLProcessor::IMAGE_PAD);
                    }
                }

                if let Some(ref video_grid_thw_accum) = video_grid_thw_accum {
                    let merge_length = config.merge_size.expect("Require `merge_size`").pow(2);
                    let mut index = 0;
                    for ((batch, text), seq) in
                        detok_seqs.iter_mut().enumerate().zip(input_seqs.iter_mut())
                    {
                        if seq.multimodal.has_changed_prompt {
                            continue;
                        }
                        while text.contains(Qwen2VLProcessor::VIDEO_PAD) {
                            *text = replace_first_occurrence(
                                text,
                                Qwen2VLProcessor::VIDEO_PAD,
                                &Qwen2VLProcessor::PLACEHOLDER.repeat(
                                    video_grid_thw_accum[batch]
                                        .i(index)
                                        .unwrap()
                                        .to_vec1::<u32>()
                                        .unwrap()
                                        .iter()
                                        .product::<u32>()
                                        as usize
                                        / merge_length,
                                ),
                            );
                            index += 1;
                        }
                        *text = text
                            .replace(Qwen2VLProcessor::PLACEHOLDER, Qwen2VLProcessor::VIDEO_PAD);
                    }
                }
            }

            let mut all_ids = Vec::new();
            let mut all_continuous_img_pad = Vec::new();
            let mut all_continuous_vid_pad = Vec::new();
            for (detok, seq) in detok_seqs.into_iter().zip(input_seqs.iter_mut()) {
                let toks = tokenizer
                    .encode_fast(detok.clone(), false)
                    .expect("Tokenization failed!");
                let ids = toks.get_ids().to_vec();

                if !seq.multimodal.has_changed_prompt {
                    seq.set_initial_prompt(detok.clone());

                    seq.set_toks_and_reallocate(ids.clone(), paged_attn_metadata.as_mut());
                    seq.multimodal.has_changed_prompt = true;
                }
                all_ids.push(ids.clone());

                let img_pad = tokenizer
                    .encode_fast(Qwen2VLProcessor::IMAGE_PAD, false)
                    .expect("Tokenization failed!")
                    .get_ids()
                    .to_vec();
                let continuous_img_pad = find_sequences(&ids, img_pad[0]);
                all_continuous_img_pad.push(continuous_img_pad);

                let vid_pad = tokenizer
                    .encode_fast(Qwen2VLProcessor::VIDEO_PAD, false)
                    .expect("Tokenization failed!")
                    .get_ids()
                    .to_vec();
                let continuous_vid_pad = find_sequences(&ids, vid_pad[0]);
                all_continuous_vid_pad.push(continuous_vid_pad);
            }

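            // Classify each `<|vision_start|>` marker as an image or a video by checking which
            // pad token immediately follows it in the original prompt.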
            let mut input_ids_searching = Vec::new();
            let mut image_nums = Vec::new();
            let mut video_nums = Vec::new();
            for (seq, ids) in input_seqs.iter().zip(&all_ids) {
                let prompt = seq.get_initial_prompt();
                let match_indices = find_substring_indices(prompt, Qwen2VLProcessor::VISION_START);
                image_nums.push(
                    match_indices
                        .iter()
                        .filter(|&&idx| {
                            prompt[idx..idx + Qwen2VLProcessor::IMAGE_PAD.len()]
                                == *Qwen2VLProcessor::IMAGE_PAD
                        })
                        .count(),
                );
                video_nums.push(
                    match_indices
                        .iter()
                        .filter(|&&idx| {
                            prompt[idx..idx + Qwen2VLProcessor::VIDEO_PAD.len()]
                                == *Qwen2VLProcessor::VIDEO_PAD
                        })
                        .count(),
                );

                input_ids_searching.push(ids.to_vec());
            }

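            // Right-pad every id sequence with zeros to the batch maximum so they can be
            // stacked into a single tensor.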
            let mut all_ids_new = Vec::new();
            let max_len = all_ids.iter().map(|ids| ids.len()).max().unwrap();
            for ids in all_ids {
                let pad = max_len - ids.len();
                all_ids_new
                    .push(Tensor::new([ids, vec![0; pad]].concat(), input.device()).unwrap());
            }

            (
                Some(Tensor::stack(&all_ids_new, 0).unwrap()),
                Some(Tensor::cat(&pixel_values_accum, 0).unwrap()),
                image_grid_thw_accum.map(|img| Tensor::cat(&img, 0).unwrap()),
                video_grid_thw_accum.map(|vid| Tensor::cat(&vid, 0).unwrap()),
                all_continuous_img_pad,
                all_continuous_vid_pad,
                input_ids_searching,
                image_nums,
                video_nums,
            )
        } else {
            (
                None,
                None,
                None,
                None,
                vec![],
                vec![],
                vec![vec![]; input_seqs.len()],
                vec![0; input_seqs.len()],
                vec![0; input_seqs.len()],
            )
        };

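        // The re-tokenized ids (with expanded pad tokens) become the model input on the prompt
        // pass; otherwise the original `input` is kept, with the full ids still passed along
        // through `Qwen2VLVisionSpecificArgs` below.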
        let (input, input_ids_full) = match (new_input, is_prompt) {
            (Some(new_input), true) => (new_input.clone(), new_input),
            (Some(new_input), false) => (input, new_input),
            (None, _) => (input.clone(), input.clone()),
        };

        let pixel_values = if is_prompt { pixel_values } else { None };

        let seqlens = input_seqs.iter().map(|seq| seq.len()).collect::<Vec<_>>();

        let inputs: Box<dyn Any> = Box::new(ModelInputs {
            input_ids: input,
            seqlen_offsets: positions,
            context_lens,
            position_ids,
            pixel_values,
            model_specific_args: Box::new(Qwen2VLVisionSpecificArgs {
                input_ids_full,
                image_grid_thw,
                video_grid_thw,
                seqlens,
                continuous_img_pad,
                continuous_vid_pad,
                input_ids_searching,
                image_nums,
                video_nums,
            }),
            paged_attn_meta,
            flash_meta,
        });
        Box::new(std::iter::once(Ok(InputProcessorOutput {
            inputs,
            seq_indices,
        })))
    }
}

impl Qwen2VLImageProcessor {
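    /// Round `height`/`width` to multiples of `factor` such that the total pixel count stays
    /// within `[min_pixels, max_pixels]`, approximately preserving the aspect ratio.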
    fn smart_resize(
        &self,
        height: usize,
        width: usize,
        factor: usize,
        min_pixels: usize,
        max_pixels: usize,
    ) -> candle_core::Result<(usize, usize)> {
        if height < factor || width < factor {
            candle_core::bail!(
                "height:{} or width:{} must be larger than factor:{}",
                height,
                width,
                factor
            );
        } else if (height.max(width) as f64 / height.min(width) as f64) > 200.0 {
            candle_core::bail!(
                "absolute aspect ratio must be smaller than 200, got {:.2}",
                height.max(width) as f64 / height.min(width) as f64
            );
        }

        let mut h_bar = (height as f64 / factor as f64).round() as usize * factor;
        let mut w_bar = (width as f64 / factor as f64).round() as usize * factor;

        if h_bar * w_bar > max_pixels {
            let beta = ((height * width) as f64 / max_pixels as f64).sqrt();
            h_bar = ((height as f64 / beta / factor as f64).floor() as usize) * factor;
            w_bar = ((width as f64 / beta / factor as f64).floor() as usize) * factor;
        } else if h_bar * w_bar < min_pixels {
            let beta = (min_pixels as f64 / (height * width) as f64).sqrt();
            h_bar = ((height as f64 * beta / factor as f64).ceil() as usize) * factor;
            w_bar = ((width as f64 * beta / factor as f64).ceil() as usize) * factor;
        }

        Ok((h_bar, w_bar))
    }

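    /// Resize, normalize, and patchify a batch of frames (a single image, or the frames of one
    /// video), returning the flattened patches together with the `(grid_t, grid_h, grid_w)`
    /// patch grid.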
    fn preprocess_inner(
        &self,
        images: Vec<DynamicImage>,
        config: &PreProcessorConfig,
        device: &Device,
        (mut height, mut width): (u32, u32),
    ) -> candle_core::Result<(Tensor, (u32, u32, u32))> {
        let mut processed_images = Vec::new();

        for mut image in images {
            image = image.resize_exact(
                height,
                width,
                config
                    .resampling
                    .map(|resample| Some(resample).to_filter())
                    .unwrap_or(Ok(FilterType::CatmullRom))?,
            );
            image = DynamicImage::ImageRgb8(image.to_rgb8());
            if config.do_resize.is_none() || config.do_resize.is_some_and(|x| x) {
                let (resized_height, resized_width) = self.smart_resize(
                    height as usize,
                    width as usize,
                    config.patch_size.context("Require `patch_size`.")?
                        * config.merge_size.context("Require `merge_size`")?,
                    config.min_pixels.context("Require `min_pixels`")?,
                    config.max_pixels.context("Require `max_pixels`")?,
                )?;
                height = resized_height as u32;
                width = resized_width as u32;
                image = image.resize_exact(
                    resized_width as u32,
                    resized_height as u32,
                    config
                        .resampling
                        .map(|resample| Some(resample).to_filter())
                        .unwrap_or(Ok(FilterType::CatmullRom))?,
                );
            }

            let to_tensor_rescale = Transforms {
                input: &ToTensor,
                inner_transforms: &[],
            };
            let image = image.apply(to_tensor_rescale, device)?;

            let transforms = TensorTransforms {
                inner_transforms: &[&Normalize {
                    mean: config.image_mean.unwrap_or(Self::DEFAULT_MEAN).to_vec(),
                    std: config.image_std.unwrap_or(Self::DEFAULT_STD).to_vec(),
                }],
            };
            let image = <Tensor as ApplyTensorTransforms>::apply(&image, transforms, device)?;

            processed_images.push(image);
        }

        let mut patches = Tensor::stack(&processed_images, 0)?;
        let temporal_patch_size = config
            .temporal_patch_size
            .context("Require `temporal_patch_size`")?;
        let patch_size = config.patch_size.context("Require `patch_size`")?;
        let merge_size = config.merge_size.context("Require `merge_size`")?;
        if patches.dim(0)? == 1 {
            patches = patches.repeat((temporal_patch_size, 1, 1, 1))?;
        }
        let channel = patches.dim(1)?;
        let grid_t = patches.dim(0)? / temporal_patch_size;
        let grid_h = height as usize / patch_size;
        let grid_w = width as usize / patch_size;
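        // Carve the frames into patches of `temporal_patch_size` frames x `patch_size` x
        // `patch_size` pixels (all channels), ordered so that each `merge_size` x `merge_size`
        // spatial neighborhood stays contiguous, then flatten each patch to one row.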
        patches = patches.reshape(&[
            grid_t,
            temporal_patch_size,
            channel,
            grid_h / merge_size,
            merge_size,
            patch_size,
            grid_w / merge_size,
            merge_size,
            patch_size,
        ])?;
        patches = patches.permute([0, 3, 6, 4, 7, 2, 1, 5, 8])?;
        let flattened_patches = patches.reshape((
            grid_t * grid_h * grid_w,
            channel * temporal_patch_size * patch_size * patch_size,
        ))?;

        Ok((
            flattened_patches,
            (grid_t as u32, grid_h as u32, grid_w as u32),
        ))
    }
}

impl ImagePreProcessor for Qwen2VLImageProcessor {
    const DEFAULT_MEAN: [f64; 3] = [0.48145466, 0.4578275, 0.40821073];
    const DEFAULT_STD: [f64; 3] = [0.26862954, 0.26130258, 0.27577711];

    fn preprocess(
        &self,
        mut images: Vec<DynamicImage>,
        videos: Vec<Vec<DynamicImage>>,
        config: &PreProcessorConfig,
        device: &Device,
        (_, _): (usize, usize),
    ) -> candle_core::Result<PreprocessedImages> {
        let mut pixel_values = Vec::new();
        let mut vision_grid_thw = Vec::new();

        if !images.is_empty() {
            if let Some(max_edge) = self.max_edge {
                images = mistralrs_vision::pad_to_max_edge(&images, max_edge);
            }

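            // All images in the batch are preprocessed against the largest width/height seen,
            // so they end up on a common patch grid.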
            let mut height = 0;
            let mut width = 0;
            for image in &images {
                let (w, h) = image.dimensions();
                if w > width {
                    width = w;
                }
                if h > height {
                    height = h;
                }
            }

            for image in images {
                let (patches, (t, h, w)) =
                    self.preprocess_inner(vec![image], config, device, (height, width))?;
                pixel_values.push(patches);
                vision_grid_thw.push(Tensor::new(&[t, h, w], &Device::Cpu)?);
            }
            let pixel_values = Tensor::stack(&pixel_values, 0)?;
            let vision_grid_thw = Tensor::stack(&vision_grid_thw, 0)?;
            return Ok(PreprocessedImages {
                pixel_values,
                pixel_attention_mask: None,
                image_sizes: None,
                num_img_tokens: None,
                aspect_ratio_ids: None,
                aspect_ratio_mask: None,
                num_tiles: None,
                image_grid_thw: Some(vision_grid_thw),
                video_grid_thw: None,
                rows: None,
                cols: None,
                pixel_values_list: None,
                tgt_sizes: None,
                image_sizes_all: None,
                num_crops: None,
            });
        }

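        // Videos follow the same path, except each video's frames are preprocessed together so
        // they share one temporal grid, and the result is reported as `video_grid_thw`.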
        if !videos.is_empty() {
            let mut height = 0;
            let mut width = 0;
            for image in &videos {
                let (w, h) = image[0].dimensions();
                if w > width {
                    width = w;
                }
                if h > height {
                    height = h;
                }
            }

            for images in videos {
                let (patches, (t, h, w)) =
                    self.preprocess_inner(images, config, device, (height, width))?;
                pixel_values.push(patches);
                vision_grid_thw.push(Tensor::new(&[t, h, w], &Device::Cpu)?);
            }
            let pixel_values = Tensor::stack(&pixel_values, 0)?;
            let vision_grid_thw = Tensor::stack(&vision_grid_thw, 0)?;
            return Ok(PreprocessedImages {
                pixel_values,
                pixel_attention_mask: None,
                image_sizes: None,
                num_img_tokens: None,
                aspect_ratio_ids: None,
                aspect_ratio_mask: None,
                num_tiles: None,
                image_grid_thw: None,
                video_grid_thw: Some(vision_grid_thw),
                rows: None,
                cols: None,
                pixel_values_list: None,
                tgt_sizes: None,
                image_sizes_all: None,
                num_crops: None,
            });
        }
        unreachable!()
    }
}