use std::{
    any::Any,
    num::NonZeroUsize,
    sync::{Arc, RwLock},
};

use anyhow::Result;
use candle_core::{Context, Device, IndexOp, Tensor};
use image::{imageops::FilterType, DynamicImage, GenericImageView};
use mistralrs_vision::{
    ApplyTensorTransforms, ApplyTransforms, Normalize, TensorTransforms, ToTensor, Transforms,
};
use tokenizers::Tokenizer;
use tracing::warn;

use crate::{
    device_map::DeviceMapper,
    pipeline::{
        text_models_inputs_processor::{
            self, get_completion_input, get_prompt_input, PagedAttentionMeta,
        },
        InputProcessorOutput, InputsProcessor, InputsProcessorType, MessagesAction, Processor,
    },
    sequence::Sequence,
    vision_models::{
        image_processor::{ImagePreProcessor, PreprocessedImages},
        preprocessor_config::{PreProcessorConfig, ToFilter},
        ModelInputs,
    },
};

use super::Qwen2VLVisionSpecificArgs;

struct Qwen2VLImageProcessor {
    // Set from the preprocessor config during preprocessing; read when expanding image/video pad tokens.
    merge_size: RwLock<Option<usize>>,
    max_edge: Option<u32>,
}
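/// Processor for Qwen2-VL prompts and images/videos. `max_edge`, when set, is forwarded to the
/// image processor and used in `mistralrs_vision::pad_to_max_edge` before patch extraction.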
pub struct Qwen2VLProcessor {
    max_edge: Option<u32>,
}

impl Qwen2VLProcessor {
    pub const VISION_START: &str = "<|vision_start|>";
    pub const VISION_END: &str = "<|vision_end|>";
    pub const IMAGE_PAD: &str = "<|image_pad|>";
    pub const VIDEO_PAD: &str = "<|video_pad|>";
    pub const PLACEHOLDER: &str = "<|placeholder|>";

    pub fn new(max_edge: Option<u32>) -> Self {
        Self { max_edge }
    }
}

impl Processor for Qwen2VLProcessor {
    fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
        Arc::new(Qwen2VLImageProcessor {
            merge_size: RwLock::new(None),
            max_edge: self.max_edge,
        })
    }

    fn get_special_tokens(&self) -> &[&'static str] {
        &[Self::IMAGE_PAD, Self::VIDEO_PAD, Self::PLACEHOLDER]
    }

    fn template_action(&self) -> MessagesAction {
        MessagesAction::FlattenOnlyText
    }
}

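/// Replace only the first occurrence of `to_replace` in `text` with `replacement`.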
fn replace_first_occurrence(text: &str, to_replace: &str, replacement: &str) -> String {
    if let Some(pos) = text.find(to_replace) {
        let mut result = text.to_string();
        result.replace_range(pos..pos + to_replace.len(), replacement);
        result
    } else {
        text.to_string()
    }
}

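/// Find all maximal runs of `needle` in `nums`, returned as half-open `(start, end)` index pairs.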
fn find_sequences(nums: &[u32], needle: u32) -> Vec<(usize, usize)> {
    let mut sequences = Vec::new();
    let mut start = None;

    for (i, &num) in nums.iter().enumerate() {
        if num == needle {
            if start.is_none() {
                start = Some(i);
            }
        } else if let Some(s) = start {
            sequences.push((s, i));
            start = None;
        }
    }

    if let Some(s) = start {
        sequences.push((s, nums.len()));
    }

    sequences
}

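/// Return the index just past each occurrence of `needle` in `haystack`.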
fn find_substring_indices(haystack: &str, needle: &str) -> Vec<usize> {
    let mut indices = Vec::new();
    let mut start = 0;

    while let Some(pos) = haystack[start..].find(needle) {
        let index = start + pos;
        indices.push(index + needle.len());
        start = index + needle.len();
    }

    indices
}

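// During prompt processing, every `<|image_pad|>`/`<|video_pad|>` in the decoded prompt is expanded
// to `prod(grid_thw) / merge_size^2` pad tokens so the token count matches the vision embeddings.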
impl InputsProcessor for Qwen2VLImageProcessor {
    fn get_type(&self) -> InputsProcessorType {
        InputsProcessorType::Vision
    }
    fn process_inputs(
        &self,
        tokenizer: Option<Arc<Tokenizer>>,
        input_seqs: &mut [&mut Sequence],
        is_prompt: bool,
        is_xlora: bool,
        device: &Device,
        no_kv_cache: bool,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        other_config: Option<Arc<dyn Any>>,
        mut paged_attn_metadata: Option<PagedAttentionMeta<'_>>,
        prompt_chunksize: Option<NonZeroUsize>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Box<dyn Iterator<Item = Result<InputProcessorOutput>>> {
        if is_xlora {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Cannot make inputs for X-LoRA vision model.",
            ))));
        }
        if no_kv_cache {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Vision model must have kv cache.",
            ))));
        }
        if prompt_chunksize.is_some() {
            warn!("`prompt_chunksize` is set. Qwen2-VL does not support prompt batching.");
        }
        if input_seqs.len() != 1 {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Qwen2-VL only supports batch size = 1",
            ))));
        }
        let Some(tokenizer) = tokenizer else {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Qwen2VLImageProcessor requires a specified tokenizer.",
            ))));
        };

        let text_models_inputs_processor::InnerInputProcessorOutput {
            inputs:
                text_models_inputs_processor::InputMetadata {
                    input,
                    positions,
                    context_lens,
                    position_ids,
                    paged_attn_meta,
                    flash_meta,
                },
            seq_indices,
        } = if is_prompt {
            get_prompt_input(
                input_seqs
                    .iter()
                    .map(|seq| seq.get_toks().to_vec())
                    .collect::<Vec<_>>(),
                input_seqs,
                device,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                None,
                mapper,
            )
            .nth(0)
            .unwrap()
            .unwrap()
        } else {
            get_completion_input(
                input_seqs
                    .iter()
                    .map(|seq| seq.get_toks().to_vec())
                    .collect::<Vec<_>>(),
                input_seqs,
                device,
                no_kv_cache,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                None,
                mapper,
            )
            .nth(0)
            .unwrap()
            .unwrap()
        };
        let config = other_config.expect("Need a PreProcessorConfig config.");
        let config: &PreProcessorConfig = config.downcast_ref().expect("Downcast failed.");

        let has_images = input_seqs.iter().all(|seq| seq.has_images());

        let (
            new_input,
            pixel_values,
            image_grid_thw,
            video_grid_thw,
            continuous_img_pad,
            continuous_vid_pad,
            input_ids_searching,
            image_nums,
            video_nums,
        ) = if has_images {
            let mut pixel_values_accum = Vec::new();
            let mut image_grid_thw_accum = Vec::new();
            let mut video_grid_thw_accum = Vec::new();

            let mut detok_seqs = tokenizer
                .decode_batch(
                    &input_seqs
                        .iter()
                        .map(|seq| seq.get_toks())
                        .collect::<Vec<_>>(),
                    false,
                )
                .expect("Detokenization failed!");

            for seq in input_seqs.iter_mut() {
                let (pixel_values, image_grid_thw, video_grid_thw) =
                    if let Some(cached_pixel_values) = &seq.cached_pixel_values {
                        (
                            cached_pixel_values.clone(),
                            seq.cached_img_thw.clone(),
                            seq.cached_vid_thw.clone(),
                        )
                    } else {
                        let PreprocessedImages {
                            pixel_values,
                            pixel_attention_mask: _,
                            image_sizes: _,
                            num_img_tokens: _,
                            aspect_ratio_ids: _,
                            aspect_ratio_mask: _,
                            num_tiles: _,
                            image_grid_thw,
                            video_grid_thw,
                            rows: _,
                            cols: _,
                            pixel_values_list: _,
                            tgt_sizes: _,
                            image_sizes_all: _,
                            num_crops: _,
                        } = self
                            .preprocess(
                                seq.clone_images()
                                    .expect("Need to have images by this point."),
                                vec![],
                                config,
                                device,
                                (usize::MAX, usize::MAX),
                            )
                            .expect("Preprocessing failed");

                        seq.cached_pixel_values = Some(pixel_values.clone());
                        seq.cached_img_thw = image_grid_thw.clone();
                        seq.cached_vid_thw = video_grid_thw.clone();
                        (pixel_values, image_grid_thw, video_grid_thw)
                    };

                pixel_values_accum.push(pixel_values.unsqueeze(0).unwrap());
                image_grid_thw_accum.push(image_grid_thw);
                video_grid_thw_accum.push(video_grid_thw);
            }

            let image_grid_thw_accum = if image_grid_thw_accum.iter().any(|img| img.is_none()) {
                None
            } else {
                Some(
                    image_grid_thw_accum
                        .into_iter()
                        .map(|img| img.unwrap())
                        .collect::<Vec<_>>(),
                )
            };

            let video_grid_thw_accum = if video_grid_thw_accum.iter().any(|img| img.is_none()) {
                None
            } else {
                Some(
                    video_grid_thw_accum
                        .into_iter()
                        .map(|img| img.unwrap())
                        .collect::<Vec<_>>(),
                )
            };

            if is_prompt {
                if let Some(ref image_grid_thw_accum) = image_grid_thw_accum {
                    let merge_length = self.merge_size.read().unwrap().unwrap().pow(2);
                    let mut index = 0;
                    for (batch, text) in detok_seqs.iter_mut().enumerate() {
                        while text.contains(Qwen2VLProcessor::IMAGE_PAD) {
                            *text = replace_first_occurrence(
                                text,
                                Qwen2VLProcessor::IMAGE_PAD,
                                &Qwen2VLProcessor::PLACEHOLDER.repeat(
                                    image_grid_thw_accum[batch]
                                        .i(index)
                                        .unwrap()
                                        .to_vec1::<u32>()
                                        .unwrap()
                                        .iter()
                                        .product::<u32>()
                                        as usize
                                        / merge_length,
                                ),
                            );
                            index += 1;
                        }
                        *text = text
                            .replace(Qwen2VLProcessor::PLACEHOLDER, Qwen2VLProcessor::IMAGE_PAD);
                    }
                }

                if let Some(ref video_grid_thw_accum) = video_grid_thw_accum {
                    let merge_length = self.merge_size.read().unwrap().unwrap().pow(2);
                    let mut index = 0;
                    for (batch, text) in detok_seqs.iter_mut().enumerate() {
                        while text.contains(Qwen2VLProcessor::VIDEO_PAD) {
                            *text = replace_first_occurrence(
                                text,
                                Qwen2VLProcessor::VIDEO_PAD,
                                &Qwen2VLProcessor::PLACEHOLDER.repeat(
                                    video_grid_thw_accum[batch]
                                        .i(index)
                                        .unwrap()
                                        .to_vec1::<u32>()
                                        .unwrap()
                                        .iter()
                                        .product::<u32>()
                                        as usize
                                        / merge_length,
                                ),
                            );
                            index += 1;
                        }
                        *text = text
                            .replace(Qwen2VLProcessor::PLACEHOLDER, Qwen2VLProcessor::VIDEO_PAD);
                    }
                }
            }

            let mut all_ids = Vec::new();
            let mut all_continuous_img_pad = Vec::new();
            let mut all_continuous_vid_pad = Vec::new();
            for (detok, seq) in detok_seqs.into_iter().zip(input_seqs.iter_mut()) {
                let toks = tokenizer
                    .encode_fast(detok.clone(), false)
                    .expect("Tokenization failed!");
                let ids = toks.get_ids().to_vec();

                if !seq.has_changed_prompt {
                    seq.set_initial_prompt(detok.clone());

                    seq.set_toks_and_reallocate(ids.clone(), paged_attn_metadata.as_mut());
                    seq.has_changed_prompt = true;
                }
                all_ids.push(ids.clone());

                let img_pad = tokenizer
                    .encode_fast(Qwen2VLProcessor::IMAGE_PAD, false)
                    .expect("Tokenization failed!")
                    .get_ids()
                    .to_vec();
                let continuous_img_pad = find_sequences(&ids, img_pad[0]);
                all_continuous_img_pad.push(continuous_img_pad);

                let vid_pad = tokenizer
                    .encode_fast(Qwen2VLProcessor::VIDEO_PAD, false)
                    .expect("Tokenization failed!")
                    .get_ids()
                    .to_vec();
                let continuous_vid_pad = find_sequences(&ids, vid_pad[0]);
                all_continuous_vid_pad.push(continuous_vid_pad);
            }

            let mut input_ids_searching = Vec::new();
            let mut image_nums = Vec::new();
            let mut video_nums = Vec::new();
            for seq in input_seqs.iter() {
                let prompt = seq.get_initial_prompt();
                let match_indices = find_substring_indices(prompt, Qwen2VLProcessor::VISION_START);
                image_nums.push(
                    match_indices
                        .iter()
                        .filter(|&&idx| {
                            prompt[idx..idx + Qwen2VLProcessor::IMAGE_PAD.len()]
                                == *Qwen2VLProcessor::IMAGE_PAD
                        })
                        .count(),
                );
                video_nums.push(
                    match_indices
                        .iter()
                        .filter(|&&idx| {
                            prompt[idx..idx + Qwen2VLProcessor::VIDEO_PAD.len()]
                                == *Qwen2VLProcessor::VIDEO_PAD
                        })
                        .count(),
                );

                let ids = tokenizer
                    .encode_fast(prompt, false)
                    .expect("Tokenization failed!");

                input_ids_searching.push(ids.get_ids().to_vec());
            }

            let mut all_ids_new = Vec::new();
            let max_len = all_ids.iter().map(|ids| ids.len()).max().unwrap();
            for ids in all_ids {
                let pad = max_len - ids.len();
                all_ids_new
                    .push(Tensor::new([ids, vec![0; pad]].concat(), input.device()).unwrap());
            }

            (
                Some(Tensor::stack(&all_ids_new, 0).unwrap()),
                Some(Tensor::cat(&pixel_values_accum, 0).unwrap()),
                image_grid_thw_accum.map(|img| Tensor::cat(&img, 0).unwrap()),
                video_grid_thw_accum.map(|vid| Tensor::cat(&vid, 0).unwrap()),
                all_continuous_img_pad,
                all_continuous_vid_pad,
                input_ids_searching,
                image_nums,
                video_nums,
            )
        } else {
            (
                None,
                None,
                None,
                None,
                vec![],
                vec![],
                vec![vec![]; input_seqs.len()],
                vec![0; input_seqs.len()],
                vec![0; input_seqs.len()],
            )
        };

        let (input, input_ids_full) = match (new_input, is_prompt) {
            (Some(new_input), true) => (new_input.clone(), new_input),
            (Some(new_input), false) => (input, new_input),
            (None, _) => (input.clone(), input.clone()),
        };

        let pixel_values = if is_prompt { pixel_values } else { None };

        let seqlens = input_seqs
            .iter()
            .map(|seq| seq.prompt_tokens())
            .collect::<Vec<_>>();

        let inputs: Box<dyn Any> = Box::new(ModelInputs {
            input_ids: input,
            seqlen_offsets: positions,
            context_lens,
            position_ids,
            pixel_values,
            model_specific_args: Box::new(Qwen2VLVisionSpecificArgs {
                input_ids_full,
                image_grid_thw,
                video_grid_thw,
                seqlens,
                continuous_img_pad,
                continuous_vid_pad,
                input_ids_searching,
                image_nums,
                video_nums,
            }),
            paged_attn_meta,
            flash_meta,
        });
        Box::new(std::iter::once(Ok(InputProcessorOutput {
            inputs,
            seq_indices,
        })))
    }
}

impl Qwen2VLImageProcessor {
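    /// Rescale `(height, width)` so that both sides are multiples of `factor` and the total pixel
    /// count lies within `[min_pixels, max_pixels]`, keeping the aspect ratio as close as possible.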
    fn smart_resize(
        &self,
        height: usize,
        width: usize,
        factor: usize,
        min_pixels: usize,
        max_pixels: usize,
    ) -> candle_core::Result<(usize, usize)> {
        if height < factor || width < factor {
            candle_core::bail!(
                "height:{} or width:{} must be larger than factor:{}",
                height,
                width,
                factor
            );
        } else if (height.max(width) as f64 / height.min(width) as f64) > 200.0 {
            candle_core::bail!(
                "absolute aspect ratio must be smaller than 200, got {:.2}",
                height.max(width) as f64 / height.min(width) as f64
            );
        }

        let mut h_bar = (height as f64 / factor as f64).round() as usize * factor;
        let mut w_bar = (width as f64 / factor as f64).round() as usize * factor;

        if h_bar * w_bar > max_pixels {
            let beta = ((height * width) as f64 / max_pixels as f64).sqrt();
            h_bar = ((height as f64 / beta / factor as f64).floor() as usize) * factor;
            w_bar = ((width as f64 / beta / factor as f64).floor() as usize) * factor;
        } else if h_bar * w_bar < min_pixels {
            let beta = (min_pixels as f64 / (height * width) as f64).sqrt();
            h_bar = ((height as f64 * beta / factor as f64).ceil() as usize) * factor;
            w_bar = ((width as f64 * beta / factor as f64).ceil() as usize) * factor;
        }

        Ok((h_bar, w_bar))
    }

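    /// Resize and normalize the frames, then split them into flattened
    /// `channel * temporal_patch_size * patch_size * patch_size` patches, returned together with
    /// the `(grid_t, grid_h, grid_w)` layout.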
    fn preprocess_inner(
        &self,
        images: Vec<DynamicImage>,
        config: &PreProcessorConfig,
        device: &Device,
        (mut height, mut width): (u32, u32),
    ) -> candle_core::Result<(Tensor, (u32, u32, u32))> {
        let mut processed_images = Vec::new();

        for mut image in images {
            image = image.resize_exact(
                height,
                width,
                config
                    .resampling
                    .map(|resample| Some(resample).to_filter())
                    .unwrap_or(Ok(FilterType::CatmullRom))?,
            );
            image = DynamicImage::ImageRgb8(image.to_rgb8());
            if config.do_resize.is_none() || config.do_resize.is_some_and(|x| x) {
                let (resized_height, resized_width) = self.smart_resize(
                    height as usize,
                    width as usize,
                    config.patch_size.context("Require `patch_size`.")?
                        * config.merge_size.context("Require `merge_size`")?,
                    config.min_pixels.context("Require `min_pixels`")?,
                    config.max_pixels.context("Require `max_pixels`")?,
                )?;
                height = resized_height as u32;
                width = resized_width as u32;
                image = image.resize_exact(
                    resized_width as u32,
                    resized_height as u32,
                    config
                        .resampling
                        .map(|resample| Some(resample).to_filter())
                        .unwrap_or(Ok(FilterType::CatmullRom))?,
                );
            }

            let to_tensor_rescale = Transforms {
                input: &ToTensor,
                inner_transforms: &[],
            };
            let image = image.apply(to_tensor_rescale, device)?;

            let transforms = TensorTransforms {
                inner_transforms: &[&Normalize {
                    mean: config.image_mean.unwrap_or(Self::DEFAULT_MEAN).to_vec(),
                    std: config.image_std.unwrap_or(Self::DEFAULT_STD).to_vec(),
                }],
            };
            let image = <Tensor as ApplyTensorTransforms>::apply(&image, transforms, device)?;

            processed_images.push(image);
        }

        let mut patches = Tensor::stack(&processed_images, 0)?;
        let temporal_patch_size = config
            .temporal_patch_size
            .context("Require `temporal_patch_size`")?;
        let patch_size = config.patch_size.context("Require `patch_size`")?;
        let merge_size = config.merge_size.context("Require `merge_size`")?;
        *self.merge_size.write().unwrap() = Some(merge_size);

        if patches.dim(0)? == 1 {
            // A single frame: repeat it to fill the temporal patch dimension.
            patches = patches.repeat((temporal_patch_size, 1, 1, 1))?;
        }
        let channel = patches.dim(1)?;
        let grid_t = patches.dim(0)? / temporal_patch_size;
        let grid_h = height as usize / patch_size;
        let grid_w = width as usize / patch_size;
        patches = patches.reshape(&[
            grid_t,
            temporal_patch_size,
            channel,
            grid_h / merge_size,
            merge_size,
            patch_size,
            grid_w / merge_size,
            merge_size,
            patch_size,
        ])?;
        patches = patches.permute([0, 3, 6, 4, 7, 2, 1, 5, 8])?;
        let flattened_patches = patches.reshape((
            grid_t * grid_h * grid_w,
            channel * temporal_patch_size * patch_size * patch_size,
        ))?;

        Ok((
            flattened_patches,
            (grid_t as u32, grid_h as u32, grid_w as u32),
        ))
    }
}

impl ImagePreProcessor for Qwen2VLImageProcessor {
    const DEFAULT_MEAN: [f64; 3] = [0.48145466, 0.4578275, 0.40821073];
    const DEFAULT_STD: [f64; 3] = [0.26862954, 0.26130258, 0.27577711];

    fn preprocess(
        &self,
        mut images: Vec<DynamicImage>,
        videos: Vec<Vec<DynamicImage>>,
        config: &PreProcessorConfig,
        device: &Device,
        (_, _): (usize, usize),
    ) -> candle_core::Result<PreprocessedImages> {
        let mut pixel_values = Vec::new();
        let mut vision_grid_thw = Vec::new();

        if !images.is_empty() {
            if let Some(max_edge) = self.max_edge {
                images = mistralrs_vision::pad_to_max_edge(&images, max_edge);
            }

            let mut height = 0;
            let mut width = 0;
            for image in &images {
                let (w, h) = image.dimensions();
                if w > width {
                    width = w;
                }
                if h > height {
                    height = h;
                }
            }

            for image in images {
                let (patches, (t, h, w)) =
                    self.preprocess_inner(vec![image], config, device, (height, width))?;
                pixel_values.push(patches);
                vision_grid_thw.push(Tensor::new(&[t, h, w], &Device::Cpu)?);
            }
            let pixel_values = Tensor::stack(&pixel_values, 0)?;
            let vision_grid_thw = Tensor::stack(&vision_grid_thw, 0)?;
            return Ok(PreprocessedImages {
                pixel_values,
                pixel_attention_mask: None,
                image_sizes: None,
                num_img_tokens: None,
                aspect_ratio_ids: None,
                aspect_ratio_mask: None,
                num_tiles: None,
                image_grid_thw: Some(vision_grid_thw),
                video_grid_thw: None,
                rows: None,
                cols: None,
                pixel_values_list: None,
                tgt_sizes: None,
                image_sizes_all: None,
                num_crops: None,
            });
        }

        if !videos.is_empty() {
            let mut height = 0;
            let mut width = 0;
            for video in &videos {
                let (w, h) = video[0].dimensions();
                if w > width {
                    width = w;
                }
                if h > height {
                    height = h;
                }
            }

            for images in videos {
                let (patches, (t, h, w)) =
                    self.preprocess_inner(images, config, device, (height, width))?;
                pixel_values.push(patches);
                vision_grid_thw.push(Tensor::new(&[t, h, w], &Device::Cpu)?);
            }
            let pixel_values = Tensor::stack(&pixel_values, 0)?;
            let vision_grid_thw = Tensor::stack(&vision_grid_thw, 0)?;
            return Ok(PreprocessedImages {
                pixel_values,
                pixel_attention_mask: None,
                image_sizes: None,
                num_img_tokens: None,
                aspect_ratio_ids: None,
                aspect_ratio_mask: None,
                num_tiles: None,
                image_grid_thw: None,
                video_grid_thw: Some(vision_grid_thw),
                rows: None,
                cols: None,
                pixel_values_list: None,
                tgt_sizes: None,
                image_sizes_all: None,
                num_crops: None,
            });
        }
        unreachable!()
    }
}