use std::{
    any::Any,
    num::NonZeroUsize,
    sync::{Arc, RwLock},
};

use anyhow::Result;
use candle_core::{Context, Device, IndexOp, Tensor};
use image::{imageops::FilterType, DynamicImage, GenericImageView};
use mistralrs_vision::{
    ApplyTensorTransforms, ApplyTransforms, Normalize, TensorTransforms, ToTensor, Transforms,
};
use tokenizers::Tokenizer;
use tracing::warn;

use crate::{
    device_map::DeviceMapper,
    pipeline::{
        text_models_inputs_processor::{
            self, get_completion_input, get_prompt_input, PagedAttentionMeta,
        },
        InputProcessorOutput, InputsProcessor, InputsProcessorType, MessagesAction, Processor,
    },
    sequence::Sequence,
    vision_models::{
        image_processor::{ImagePreProcessor, PreprocessedImages},
        preprocessor_config::{PreProcessorConfig, ToFilter},
        ModelInputs,
    },
};

use super::Qwen2VLVisionSpecificArgs;

/// Image/video input processor for Qwen2-VL. Preprocesses media into patch
/// tensors and expands the image/video pad tokens in the prompt accordingly.
struct Qwen2VLImageProcessor {
    // Written during preprocessing; read when expanding image/video pad tokens.
    merge_size: RwLock<Option<usize>>,
    max_edge: Option<u32>,
}

/// Prompt processor for Qwen2-VL; creates the [`Qwen2VLImageProcessor`] used to
/// build model inputs.
pub struct Qwen2VLProcessor {
    max_edge: Option<u32>,
}

impl Qwen2VLProcessor {
    pub const VISION_START: &str = "<|vision_start|>";
    pub const VISION_END: &str = "<|vision_end|>";
    pub const IMAGE_PAD: &str = "<|image_pad|>";
    pub const VIDEO_PAD: &str = "<|video_pad|>";
    pub const PLACEHOLDER: &str = "<|placeholder|>";

    pub fn new(max_edge: Option<u32>) -> Self {
        Self { max_edge }
    }
}

impl Processor for Qwen2VLProcessor {
    fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
        Arc::new(Qwen2VLImageProcessor {
            merge_size: RwLock::new(None),
            max_edge: self.max_edge,
        })
    }

    fn get_special_tokens(&self) -> &[&'static str] {
        &[Self::IMAGE_PAD, Self::VIDEO_PAD, Self::PLACEHOLDER]
    }

    fn template_action(&self) -> MessagesAction {
        MessagesAction::FlattenOnlyText
    }
}

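/// Replace only the first occurrence of `to_replace` in `text`; later
/// occurrences are left untouched. Returns the input unchanged if there is no match.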
fn replace_first_occurrence(text: &str, to_replace: &str, replacement: &str) -> String {
    if let Some(pos) = text.find(to_replace) {
        let mut result = text.to_string();
        result.replace_range(pos..pos + to_replace.len(), replacement);
        result
    } else {
        text.to_string()
    }
}

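/// Find all maximal runs of `needle` in `nums`, returned as half-open
/// `(start, end)` index ranges.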
fn find_sequences(nums: &[u32], needle: u32) -> Vec<(usize, usize)> {
    let mut sequences = Vec::new();
    let mut start = None;

    for (i, &num) in nums.iter().enumerate() {
        if num == needle {
            if start.is_none() {
                start = Some(i);
            }
        } else if let Some(s) = start {
            sequences.push((s, i));
            start = None;
        }
    }

    if let Some(s) = start {
        sequences.push((s, nums.len()));
    }

    sequences
}

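/// Find every occurrence of `needle` in `haystack`, returning for each match the
/// byte index immediately *after* it (used below to check which pad token follows
/// a vision-start marker).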
fn find_substring_indices(haystack: &str, needle: &str) -> Vec<usize> {
    let mut indices = Vec::new();
    let mut start = 0;

    while let Some(pos) = haystack[start..].find(needle) {
        let index = start + pos;
        indices.push(index + needle.len());
        start = index + needle.len();
    }

    indices
}

impl InputsProcessor for Qwen2VLImageProcessor {
    fn get_type(&self) -> InputsProcessorType {
        InputsProcessorType::Vision
    }
    fn process_inputs(
        &self,
        tokenizer: Option<Arc<Tokenizer>>,
        input_seqs: &mut [&mut Sequence],
        is_prompt: bool,
        is_xlora: bool,
        device: &Device,
        no_kv_cache: bool,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        other_config: Option<Arc<dyn Any>>,
        mut paged_attn_metadata: Option<PagedAttentionMeta<'_>>,
        prompt_chunksize: Option<NonZeroUsize>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Box<dyn Iterator<Item = Result<InputProcessorOutput>>> {
        if is_xlora {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Cannot make inputs for X-LoRA vision model.",
            ))));
        }
        if no_kv_cache {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Vision model must have kv cache.",
            ))));
        }
        if prompt_chunksize.is_some() {
            warn!("`prompt_chunksize` is set. Qwen2-VL does not support prompt batching.");
        }
        if input_seqs.len() != 1 {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Qwen2-VL only supports batch size = 1",
            ))));
        }
        let Some(tokenizer) = tokenizer else {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Qwen2VLImageProcessor requires a specified tokenizer.",
            ))));
        };

        let text_models_inputs_processor::InnerInputProcessorOutput {
            inputs:
                text_models_inputs_processor::InputMetadata {
                    input,
                    positions,
                    context_lens,
                    position_ids,
                    paged_attn_meta,
                    flash_meta,
                },
            seq_indices,
        } = if is_prompt {
            get_prompt_input(
                input_seqs
                    .iter()
                    .map(|seq| seq.get_toks().to_vec())
                    .collect::<Vec<_>>(),
                input_seqs,
                device,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                None,
                mapper,
            )
            .nth(0)
            .unwrap()
            .unwrap()
        } else {
            get_completion_input(
                input_seqs
                    .iter()
                    .map(|seq| seq.get_toks().to_vec())
                    .collect::<Vec<_>>(),
                input_seqs,
                device,
                no_kv_cache,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                None,
                mapper,
            )
            .nth(0)
            .unwrap()
            .unwrap()
        };
        let config = other_config.expect("Need a PreProcessorConfig config.");
        let config: &PreProcessorConfig = config.downcast_ref().expect("Downcast failed.");

        let has_images = input_seqs.iter().all(|seq| seq.has_images());

        let (
            new_input,
            pixel_values,
            image_grid_thw,
            video_grid_thw,
            continuous_img_pad,
            continuous_vid_pad,
            input_ids_searching,
            image_nums,
            video_nums,
        ) = if has_images {
            let mut pixel_values_accum = Vec::new();
            let mut image_grid_thw_accum = Vec::new();
            let mut video_grid_thw_accum = Vec::new();

            let mut detok_seqs = tokenizer
                .decode_batch(
                    &input_seqs
                        .iter()
                        .map(|seq| seq.get_toks())
                        .collect::<Vec<_>>(),
                    false,
                )
                .expect("Detokenization failed!");

            for seq in input_seqs.iter_mut() {
                let (pixel_values, image_grid_thw, video_grid_thw) =
                    if let Some(cached_pixel_values) = &seq.cached_pixel_values {
                        (
                            cached_pixel_values.clone(),
                            seq.cached_img_thw.clone(),
                            seq.cached_vid_thw.clone(),
                        )
                    } else {
                        let PreprocessedImages {
                            pixel_values,
                            pixel_attention_mask: _,
                            image_sizes: _,
                            num_img_tokens: _,
                            aspect_ratio_ids: _,
                            aspect_ratio_mask: _,
                            num_tiles: _,
                            image_grid_thw,
                            video_grid_thw,
                            rows: _,
                            cols: _,
                            pixel_values_list: _,
                            tgt_sizes: _,
                            image_sizes_all: _,
                            num_crops: _,
                        } = self
                            .preprocess(
                                seq.clone_images()
                                    .expect("Need to have images by this point."),
                                vec![],
                                config,
                                device,
                                (usize::MAX, usize::MAX),
                            )
                            .expect("Preprocessing failed");

                        seq.cached_pixel_values = Some(pixel_values.clone());
                        seq.cached_img_thw = image_grid_thw.clone();
                        seq.cached_vid_thw = video_grid_thw.clone();
                        (pixel_values, image_grid_thw, video_grid_thw)
                    };

                pixel_values_accum.push(pixel_values.unsqueeze(0).unwrap());
                image_grid_thw_accum.push(image_grid_thw);
                video_grid_thw_accum.push(video_grid_thw);
            }

            let image_grid_thw_accum = if image_grid_thw_accum.iter().any(|img| img.is_none()) {
                None
            } else {
                Some(
                    image_grid_thw_accum
                        .into_iter()
                        .map(|img| img.unwrap())
                        .collect::<Vec<_>>(),
                )
            };

            let video_grid_thw_accum = if video_grid_thw_accum.iter().any(|img| img.is_none()) {
                None
            } else {
                Some(
                    video_grid_thw_accum
                        .into_iter()
                        .map(|img| img.unwrap())
                        .collect::<Vec<_>>(),
                )
            };

            if is_prompt {
                if let Some(ref image_grid_thw_accum) = image_grid_thw_accum {
                    let merge_length = self.merge_size.read().unwrap().unwrap().pow(2);
                    let mut index = 0;
                    for (batch, text) in detok_seqs.iter_mut().enumerate() {
                        while text.contains(Qwen2VLProcessor::IMAGE_PAD) {
                            *text = replace_first_occurrence(
                                text,
                                Qwen2VLProcessor::IMAGE_PAD,
                                &Qwen2VLProcessor::PLACEHOLDER.repeat(
                                    image_grid_thw_accum[batch]
                                        .i(index)
                                        .unwrap()
                                        .to_vec1::<u32>()
                                        .unwrap()
                                        .iter()
                                        .product::<u32>()
                                        as usize
                                        / merge_length,
                                ),
                            );
                            index += 1;
                        }
                        *text = text
                            .replace(Qwen2VLProcessor::PLACEHOLDER, Qwen2VLProcessor::IMAGE_PAD);
                    }
                }

                if let Some(ref video_grid_thw_accum) = video_grid_thw_accum {
                    let merge_length = self.merge_size.read().unwrap().unwrap().pow(2);
                    let mut index = 0;
                    for (batch, text) in detok_seqs.iter_mut().enumerate() {
                        while text.contains(Qwen2VLProcessor::VIDEO_PAD) {
                            *text = replace_first_occurrence(
                                text,
                                Qwen2VLProcessor::VIDEO_PAD,
                                &Qwen2VLProcessor::PLACEHOLDER.repeat(
                                    video_grid_thw_accum[batch]
                                        .i(index)
                                        .unwrap()
                                        .to_vec1::<u32>()
                                        .unwrap()
                                        .iter()
                                        .product::<u32>()
                                        as usize
                                        / merge_length,
                                ),
                            );
                            index += 1;
                        }
                        *text = text
                            .replace(Qwen2VLProcessor::PLACEHOLDER, Qwen2VLProcessor::VIDEO_PAD);
                    }
                }
            }

            let mut all_ids = Vec::new();
            let mut all_continuous_img_pad = Vec::new();
            let mut all_continuous_vid_pad = Vec::new();
            for (detok, seq) in detok_seqs.into_iter().zip(input_seqs.iter_mut()) {
                seq.set_initial_prompt(detok.clone());

                let toks = tokenizer
                    .encode_fast(detok, false)
                    .expect("Tokenization failed!");

                let ids = toks.get_ids().to_vec();
                all_ids.push(ids.clone());

                let img_pad = tokenizer
                    .encode_fast(Qwen2VLProcessor::IMAGE_PAD, false)
                    .expect("Tokenization failed!")
                    .get_ids()
                    .to_vec();
                let continuous_img_pad = find_sequences(&ids, img_pad[0]);
                all_continuous_img_pad.push(continuous_img_pad);

                let vid_pad = tokenizer
                    .encode_fast(Qwen2VLProcessor::VIDEO_PAD, false)
                    .expect("Tokenization failed!")
                    .get_ids()
                    .to_vec();
                let continuous_vid_pad = find_sequences(&ids, vid_pad[0]);
                all_continuous_vid_pad.push(continuous_vid_pad);

                seq.set_toks_and_reallocate(ids, paged_attn_metadata.as_mut());
            }

            let mut input_ids_searching = Vec::new();
            let mut image_nums = Vec::new();
            let mut video_nums = Vec::new();
            for seq in input_seqs.iter() {
                let prompt = seq.get_initial_prompt();
                let match_indices = find_substring_indices(prompt, Qwen2VLProcessor::VISION_START);
                image_nums.push(
                    match_indices
                        .iter()
                        .filter(|&&idx| {
                            prompt[idx..idx + Qwen2VLProcessor::IMAGE_PAD.len()]
                                == *Qwen2VLProcessor::IMAGE_PAD
                        })
                        .count(),
                );
                video_nums.push(
                    match_indices
                        .iter()
                        .filter(|&&idx| {
                            prompt[idx..idx + Qwen2VLProcessor::VIDEO_PAD.len()]
                                == *Qwen2VLProcessor::VIDEO_PAD
                        })
                        .count(),
                );

                let ids = tokenizer
                    .encode_fast(prompt, false)
                    .expect("Tokenization failed!");

                input_ids_searching.push(ids.get_ids().to_vec());
            }

            let mut all_ids_new = Vec::new();
            let max_len = all_ids.iter().map(|ids| ids.len()).max().unwrap();
            for ids in all_ids {
                let pad = max_len - ids.len();
                all_ids_new
                    .push(Tensor::new([ids, vec![0; pad]].concat(), input.device()).unwrap());
            }

            (
                Some(Tensor::stack(&all_ids_new, 0).unwrap()),
                Some(Tensor::cat(&pixel_values_accum, 0).unwrap()),
                image_grid_thw_accum.map(|img| Tensor::cat(&img, 0).unwrap()),
                video_grid_thw_accum.map(|vid| Tensor::cat(&vid, 0).unwrap()),
                all_continuous_img_pad,
                all_continuous_vid_pad,
                input_ids_searching,
                image_nums,
                video_nums,
            )
        } else {
            (
                None,
                None,
                None,
                None,
                vec![],
                vec![],
                vec![vec![]; input_seqs.len()],
                vec![0; input_seqs.len()],
                vec![0; input_seqs.len()],
            )
        };

        let (input, input_ids_full) = match (new_input, is_prompt) {
            (Some(new_input), true) => (new_input.clone(), new_input),
            (Some(new_input), false) => (input, new_input),
            (None, _) => (input.clone(), input.clone()),
        };

        let pixel_values = if is_prompt { pixel_values } else { None };

        let seqlens = input_seqs
            .iter()
            .map(|seq| seq.prompt_tokens())
            .collect::<Vec<_>>();

        let inputs: Box<dyn Any> = Box::new(ModelInputs {
            input_ids: input,
            seqlen_offsets: positions,
            context_lens,
            position_ids,
            pixel_values,
            model_specific_args: Box::new(Qwen2VLVisionSpecificArgs {
                input_ids_full,
                image_grid_thw,
                video_grid_thw,
                seqlens,
                continuous_img_pad,
                continuous_vid_pad,
                input_ids_searching,
                image_nums,
                video_nums,
            }),
            paged_attn_meta,
            flash_meta,
        });
        Box::new(std::iter::once(Ok(InputProcessorOutput {
            inputs,
            seq_indices,
        })))
    }
}

impl Qwen2VLImageProcessor {
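    /// Choose an output `(height, width)` that is a multiple of `factor` and whose
    /// pixel count lies within `[min_pixels, max_pixels]`, while staying close to the
    /// original aspect ratio. Errors if either side is smaller than `factor` or the
    /// aspect ratio exceeds 200.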
    fn smart_resize(
        &self,
        height: usize,
        width: usize,
        factor: usize,
        min_pixels: usize,
        max_pixels: usize,
    ) -> candle_core::Result<(usize, usize)> {
        if height < factor || width < factor {
            candle_core::bail!(
                "height:{} or width:{} must be larger than factor:{}",
                height,
                width,
                factor
            );
        } else if (height.max(width) as f64 / height.min(width) as f64) > 200.0 {
            candle_core::bail!(
                "absolute aspect ratio must be smaller than 200, got {:.2}",
                height.max(width) as f64 / height.min(width) as f64
            );
        }

        let mut h_bar = (height as f64 / factor as f64).round() as usize * factor;
        let mut w_bar = (width as f64 / factor as f64).round() as usize * factor;

        if h_bar * w_bar > max_pixels {
            let beta = ((height * width) as f64 / max_pixels as f64).sqrt();
            h_bar = ((height as f64 / beta / factor as f64).floor() as usize) * factor;
            w_bar = ((width as f64 / beta / factor as f64).floor() as usize) * factor;
        } else if h_bar * w_bar < min_pixels {
            let beta = (min_pixels as f64 / (height * width) as f64).sqrt();
            h_bar = ((height as f64 * beta / factor as f64).ceil() as usize) * factor;
            w_bar = ((width as f64 * beta / factor as f64).ceil() as usize) * factor;
        }

        Ok((h_bar, w_bar))
    }

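    /// Resize and normalize a set of frames, then patchify them into the flattened
    /// `(grid_t * grid_h * grid_w, channel * temporal_patch_size * patch_size^2)`
    /// layout, returning the patches together with the `(grid_t, grid_h, grid_w)` shape.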
    fn preprocess_inner(
        &self,
        images: Vec<DynamicImage>,
        config: &PreProcessorConfig,
        device: &Device,
        (mut height, mut width): (u32, u32),
    ) -> candle_core::Result<(Tensor, (u32, u32, u32))> {
        let mut processed_images = Vec::new();

        for mut image in images {
            image = image.resize_exact(
                height,
                width,
                config
                    .resampling
                    .map(|resample| Some(resample).to_filter())
                    .unwrap_or(Ok(FilterType::CatmullRom))?,
            );
            image = DynamicImage::ImageRgb8(image.to_rgb8());
            if config.do_resize.is_none() || config.do_resize.is_some_and(|x| x) {
                let (resized_height, resized_width) = self.smart_resize(
                    height as usize,
                    width as usize,
                    config.patch_size.context("Require `patch_size`.")?
                        * config.merge_size.context("Require `merge_size`")?,
                    config.min_pixels.context("Require `min_pixels`")?,
                    config.max_pixels.context("Require `max_pixels`")?,
                )?;
                height = resized_height as u32;
                width = resized_width as u32;
                image = image.resize_exact(
                    resized_width as u32,
                    resized_height as u32,
                    config
                        .resampling
                        .map(|resample| Some(resample).to_filter())
                        .unwrap_or(Ok(FilterType::CatmullRom))?,
                );
            }

            let to_tensor_rescale = Transforms {
                input: &ToTensor,
                inner_transforms: &[],
            };
            let image = image.apply(to_tensor_rescale, device)?;

            let transforms = TensorTransforms {
                inner_transforms: &[&Normalize {
                    mean: config.image_mean.unwrap_or(Self::DEFAULT_MEAN).to_vec(),
                    std: config.image_std.unwrap_or(Self::DEFAULT_STD).to_vec(),
                }],
            };
            let image = <Tensor as ApplyTensorTransforms>::apply(&image, transforms, device)?;

            processed_images.push(image);
        }

        let mut patches = Tensor::stack(&processed_images, 0)?;
        let temporal_patch_size = config
            .temporal_patch_size
            .context("Require `temporal_patch_size`")?;
        let patch_size = config.patch_size.context("Require `patch_size`")?;
        let merge_size = config.merge_size.context("Require `merge_size`")?;
        *self.merge_size.write().unwrap() = Some(merge_size);

        if patches.dim(0)? == 1 {
            patches = patches.repeat((temporal_patch_size, 1, 1, 1))?;
        }
        let channel = patches.dim(1)?;
        let grid_t = patches.dim(0)? / temporal_patch_size;
        let grid_h = height as usize / patch_size;
        let grid_w = width as usize / patch_size;
        patches = patches.reshape(&[
            grid_t,
            temporal_patch_size,
            channel,
            grid_h / merge_size,
            merge_size,
            patch_size,
            grid_w / merge_size,
            merge_size,
            patch_size,
        ])?;
        patches = patches.permute([0, 3, 6, 4, 7, 2, 1, 5, 8])?;
        let flattened_patches = patches.reshape((
            grid_t * grid_h * grid_w,
            channel * temporal_patch_size * patch_size * patch_size,
        ))?;

        Ok((
            flattened_patches,
            (grid_t as u32, grid_h as u32, grid_w as u32),
        ))
    }
}

impl ImagePreProcessor for Qwen2VLImageProcessor {
    const DEFAULT_MEAN: [f64; 3] = [0.48145466, 0.4578275, 0.40821073];
    const DEFAULT_STD: [f64; 3] = [0.26862954, 0.26130258, 0.27577711];

    fn preprocess(
        &self,
        mut images: Vec<DynamicImage>,
        videos: Vec<Vec<DynamicImage>>,
        config: &PreProcessorConfig,
        device: &Device,
        (_, _): (usize, usize),
    ) -> candle_core::Result<PreprocessedImages> {
        let mut pixel_values = Vec::new();
        let mut vision_grid_thw = Vec::new();

        if !images.is_empty() {
            if let Some(max_edge) = self.max_edge {
                images = mistralrs_vision::pad_to_max_edge(&images, max_edge);
            }

            let mut height = 0;
            let mut width = 0;
            for image in &images {
                let (w, h) = image.dimensions();
                if w > width {
                    width = w;
                }
                if h > height {
                    height = h;
                }
            }

            for image in images {
                let (patches, (t, h, w)) =
                    self.preprocess_inner(vec![image], config, device, (height, width))?;
                pixel_values.push(patches);
                vision_grid_thw.push(Tensor::new(&[t, h, w], &Device::Cpu)?);
            }
            let pixel_values = Tensor::stack(&pixel_values, 0)?;
            let vision_grid_thw = Tensor::stack(&vision_grid_thw, 0)?;
            return Ok(PreprocessedImages {
                pixel_values,
                pixel_attention_mask: None,
                image_sizes: None,
                num_img_tokens: None,
                aspect_ratio_ids: None,
                aspect_ratio_mask: None,
                num_tiles: None,
                image_grid_thw: Some(vision_grid_thw),
                video_grid_thw: None,
                rows: None,
                cols: None,
                pixel_values_list: None,
                tgt_sizes: None,
                image_sizes_all: None,
                num_crops: None,
            });
        }

        if !videos.is_empty() {
            let mut height = 0;
            let mut width = 0;
            for image in &videos {
                let (w, h) = image[0].dimensions();
                if w > width {
                    width = w;
                }
                if h > height {
                    height = h;
                }
            }

            for images in videos {
                let (patches, (t, h, w)) =
                    self.preprocess_inner(images, config, device, (height, width))?;
                pixel_values.push(patches);
                vision_grid_thw.push(Tensor::new(&[t, h, w], &Device::Cpu)?);
            }
            let pixel_values = Tensor::stack(&pixel_values, 0)?;
            let vision_grid_thw = Tensor::stack(&vision_grid_thw, 0)?;
            return Ok(PreprocessedImages {
                pixel_values,
                pixel_attention_mask: None,
                image_sizes: None,
                num_img_tokens: None,
                aspect_ratio_ids: None,
                aspect_ratio_mask: None,
                num_tiles: None,
                image_grid_thw: None,
                video_grid_thw: Some(vision_grid_thw),
                rows: None,
                cols: None,
                pixel_values_list: None,
                tgt_sizes: None,
                image_sizes_all: None,
                num_crops: None,
            });
        }
        unreachable!()
    }
}
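
#[cfg(test)]
mod tests {
    // A small sanity-check sketch (not part of the upstream test suite) for the pure
    // helpers above; the expected values follow directly from the definitions in this file.
    use super::*;

    #[test]
    fn replace_first_occurrence_only_touches_first_match() {
        assert_eq!(
            replace_first_occurrence("a<pad>b<pad>", "<pad>", "X"),
            "aXb<pad>"
        );
        assert_eq!(replace_first_occurrence("abc", "<pad>", "X"), "abc");
    }

    #[test]
    fn find_sequences_returns_half_open_runs() {
        assert_eq!(find_sequences(&[1, 5, 5, 2, 5], 5), vec![(1, 3), (4, 5)]);
        assert!(find_sequences(&[1, 2, 3], 9).is_empty());
    }

    #[test]
    fn find_substring_indices_points_past_each_match() {
        assert_eq!(find_substring_indices("ab<v>cd<v>e", "<v>"), vec![5, 10]);
    }

    #[test]
    fn smart_resize_snaps_to_factor_within_pixel_budget() {
        let processor = Qwen2VLImageProcessor {
            merge_size: RwLock::new(None),
            max_edge: None,
        };
        let (h, w) = processor
            .smart_resize(224, 224, 28, 56 * 56, 1280 * 28 * 28)
            .unwrap();
        assert_eq!((h, w), (224, 224));
        assert_eq!(h % 28, 0);
        assert_eq!(w % 28, 0);
    }
}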