mistralrs_core/vision_models/qwen2vl/inputs_processor.rs

use std::{any::Any, sync::Arc};

use anyhow::Result;
use candle_core::{Context, Device, IndexOp, Tensor};
use image::{imageops::FilterType, DynamicImage, GenericImageView};
use mistralrs_vision::{
    ApplyTensorTransforms, ApplyTransforms, Normalize, TensorTransforms, ToTensor, Transforms,
};
use tokenizers::Tokenizer;

use crate::{
    device_map::DeviceMapper,
    pipeline::{
        text_models_inputs_processor::{
            self, get_completion_input, get_prompt_input, PagedAttentionMeta,
        },
        InputProcessorOutput, InputsProcessor, InputsProcessorType, MessagesAction, Processor,
    },
    sequence::Sequence,
    vision_models::{
        image_processor::{ImagePreProcessor, PreprocessedImages},
        preprocessor_config::{PreProcessorConfig, ToFilter},
        ModelInputs,
    },
};

use super::Qwen2VLVisionSpecificArgs;

// Input processor
struct Qwen2VLImageProcessor {
    max_edge: Option<u32>,
}
// Processor
pub struct Qwen2VLProcessor {
    max_edge: Option<u32>,
}

impl Qwen2VLProcessor {
    pub const VISION_START: &str = "<|vision_start|>";
    pub const VISION_END: &str = "<|vision_end|>";
    pub const IMAGE_PAD: &str = "<|image_pad|>";
    pub const VIDEO_PAD: &str = "<|video_pad|>";
    pub const PLACEHOLDER: &str = "<|placeholder|>";

    pub fn new(max_edge: Option<u32>) -> Self {
        Self { max_edge }
    }
}
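
// A minimal usage sketch (hypothetical caller code, not part of this file):
//
// let processor = Qwen2VLProcessor::new(Some(1024)); // 1024 is an illustrative max edge
// let inputs_processor = processor.inputs_processor();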

impl Processor for Qwen2VLProcessor {
    fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
        Arc::new(Qwen2VLImageProcessor {
            max_edge: self.max_edge,
        })
    }

    fn get_special_tokens(&self) -> &[&'static str] {
        &[Self::IMAGE_PAD, Self::VIDEO_PAD, Self::PLACEHOLDER]
    }

    fn template_action(&self) -> MessagesAction {
        MessagesAction::FlattenOnlyText
    }
}

fn replace_first_occurrence(text: &str, to_replace: &str, replacement: &str) -> String {
    if let Some(pos) = text.find(to_replace) {
        let mut result = text.to_string();
        result.replace_range(pos..pos + to_replace.len(), replacement);
        result
    } else {
        text.to_string()
    }
}
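
// A small sketch of the helper's expected behavior as a unit test; the test
// name and sample strings are illustrative, not taken from the source.
#[cfg(test)]
mod replace_first_occurrence_tests {
    use super::replace_first_occurrence;

    #[test]
    fn replaces_only_the_first_match() {
        let out = replace_first_occurrence("a<pad>b<pad>", "<pad>", "XY");
        assert_eq!(out, "aXYb<pad>");

        // No occurrence: the input is returned unchanged.
        assert_eq!(replace_first_occurrence("abc", "<pad>", "XY"), "abc");
    }
}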

fn find_sequences(nums: &[u32], needle: u32) -> Vec<(usize, usize)> {
    let mut sequences = Vec::new();
    let mut start = None;

    for (i, &num) in nums.iter().enumerate() {
        if num == needle {
            if start.is_none() {
                start = Some(i);
            }
        } else if let Some(s) = start {
            sequences.push((s, i));
            start = None;
        }
    }

    if let Some(s) = start {
        sequences.push((s, nums.len()));
    }

    sequences
}
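
// A sketch of `find_sequences` on a toy slice (illustrative values): runs of
// the needle are reported as half-open (start, end) index pairs, including a
// run that extends to the end of the slice.
#[cfg(test)]
mod find_sequences_tests {
    use super::find_sequences;

    #[test]
    fn finds_contiguous_runs() {
        assert_eq!(find_sequences(&[7, 7, 1, 7], 7), vec![(0, 2), (3, 4)]);
        assert_eq!(find_sequences(&[1, 2, 3], 7), Vec::<(usize, usize)>::new());
    }
}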

// Returns, for each occurrence of `needle`, the index just past it
// (match index + needle length).
fn find_substring_indices(haystack: &str, needle: &str) -> Vec<usize> {
    let mut indices = Vec::new();
    let mut start = 0;

    while let Some(pos) = haystack[start..].find(needle) {
        let index = start + pos;
        indices.push(index + needle.len());
        start = index + needle.len(); // Move past the last found occurrence
    }

    indices
}
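
// A sketch of `find_substring_indices` (illustrative values): for each match
// it records the index just past the needle, i.e. where the following text
// begins.
#[cfg(test)]
mod find_substring_indices_tests {
    use super::find_substring_indices;

    #[test]
    fn returns_end_indices_of_matches() {
        assert_eq!(find_substring_indices("ab<v>cd<v>", "<v>"), vec![5, 10]);
    }
}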

impl InputsProcessor for Qwen2VLImageProcessor {
    fn get_type(&self) -> InputsProcessorType {
        InputsProcessorType::Vision
    }
    fn process_inputs(
        &self,
        tokenizer: Option<Arc<Tokenizer>>,
        input_seqs: &mut [&mut Sequence],
        is_prompt: bool,
        is_xlora: bool,
        device: &Device,
        no_kv_cache: bool,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        other_config: Option<Arc<dyn Any>>,
        mut paged_attn_metadata: Option<PagedAttentionMeta>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Result<InputProcessorOutput> {
        if is_xlora {
            return Err(anyhow::Error::msg(
                "Cannot make inputs for X-LoRA vision model.",
            ));
        }
        if no_kv_cache {
            return Err(anyhow::Error::msg("Vision model must have kv cache."));
        }
        if input_seqs.len() != 1 {
            return Err(anyhow::Error::msg("Qwen2-VL only supports batch size = 1"));
        }
        let Some(tokenizer) = tokenizer else {
            return Err(anyhow::Error::msg(
                "Qwen2VLImageProcessor requires a specified tokenizer.",
            ));
        };

        let text_models_inputs_processor::InnerInputProcessorOutput {
            inputs:
                text_models_inputs_processor::InputMetadata {
                    input,
                    positions,
                    context_lens,
                    position_ids,
                    paged_attn_meta,
                    flash_meta,
                },
            seq_indices,
        } = if is_prompt {
            get_prompt_input(
                input_seqs
                    .iter()
                    .map(|seq| seq.get_toks())
                    .collect::<Vec<_>>(),
                input_seqs,
                device,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                mapper,
            )
            .unwrap()
        } else {
            get_completion_input(
                input_seqs
                    .iter()
                    .map(|seq| seq.get_toks())
                    .collect::<Vec<_>>(),
                input_seqs,
                device,
                no_kv_cache,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                mapper,
            )
            .unwrap()
        };
        let config = other_config.expect("Need a PreProcessorConfig config.");
        let config: &PreProcessorConfig = config.downcast_ref().expect("Downcast failed.");

        let has_images = input_seqs.iter().all(|seq| seq.has_images());

        let (
            new_input,
            pixel_values,
            image_grid_thw,
            video_grid_thw,
            continuous_img_pad,
            continuous_vid_pad,
            input_ids_searching,
            image_nums,
            video_nums,
        ) = if has_images {
            let mut pixel_values_accum = Vec::new();
            let mut image_grid_thw_accum = Vec::new();
            let mut video_grid_thw_accum = Vec::new();

            let mut detok_seqs = tokenizer
                .decode_batch(
                    &input_seqs
                        .iter()
                        .map(|seq| seq.get_toks())
                        .collect::<Vec<_>>(),
                    false,
                )
                .expect("Detokenization failed!");

            for seq in input_seqs.iter_mut() {
                let (pixel_values, image_grid_thw, video_grid_thw) =
                    if let Some(cached_pixel_values) = &seq.multimodal.cached_pixel_values {
                        (
                            cached_pixel_values.clone(),
                            seq.multimodal.cached_img_thw.clone(),
                            seq.multimodal.cached_vid_thw.clone(),
                        )
                    } else {
                        let PreprocessedImages {
                            pixel_values,
                            pixel_attention_mask: _,
                            image_sizes: _,
                            num_img_tokens: _,
                            aspect_ratio_ids: _,
                            aspect_ratio_mask: _,
                            num_tiles: _,
                            image_grid_thw,
                            video_grid_thw,
                            rows: _,
                            cols: _,
                            pixel_values_list: _,
                            tgt_sizes: _,
                            image_sizes_all: _,
                            num_crops: _,
                        } = self
                            .preprocess(
                                seq.clone_images()
                                    .expect("Need to have images by this point."),
                                vec![],
                                config,
                                device,
                                (usize::MAX, usize::MAX), // Don't use it here...
                            )
                            .expect("Preprocessing failed");

                        seq.multimodal.cached_pixel_values = Some(pixel_values.clone());
                        seq.multimodal.cached_img_thw = image_grid_thw.clone();
                        seq.multimodal.cached_vid_thw = video_grid_thw.clone();
                        (pixel_values, image_grid_thw, video_grid_thw)
                    };

                pixel_values_accum.push(pixel_values.unsqueeze(0).unwrap());
                image_grid_thw_accum.push(image_grid_thw);
                video_grid_thw_accum.push(video_grid_thw);
            }

            let image_grid_thw_accum = if image_grid_thw_accum.iter().any(|img| img.is_none()) {
                None
            } else {
                Some(
                    image_grid_thw_accum
                        .into_iter()
                        .map(|img| img.unwrap())
                        .collect::<Vec<_>>(),
                )
            };

            let video_grid_thw_accum = if video_grid_thw_accum.iter().any(|vid| vid.is_none()) {
                None
            } else {
                Some(
                    video_grid_thw_accum
                        .into_iter()
                        .map(|vid| vid.unwrap())
                        .collect::<Vec<_>>(),
                )
            };

            if is_prompt {
                if let Some(ref image_grid_thw_accum) = image_grid_thw_accum {
                    let merge_length = config.merge_size.expect("Require `merge_size`").pow(2);
290                    for ((batch, text), seq) in
291                        detok_seqs.iter_mut().enumerate().zip(input_seqs.iter_mut())
292                    {
293                        if seq.multimodal.has_changed_prompt {
294                            continue;
295                        }
296                        let mut index = 0;
297                        while text.contains(Qwen2VLProcessor::IMAGE_PAD) {
298                            *text = replace_first_occurrence(
299                                text,
300                                Qwen2VLProcessor::IMAGE_PAD,
301                                &Qwen2VLProcessor::PLACEHOLDER.repeat(
302                                    image_grid_thw_accum[batch]
303                                        .i(index)
304                                        .unwrap()
305                                        .to_vec1::<u32>()
306                                        .unwrap()
307                                        .iter()
308                                        .product::<u32>()
309                                        as usize
310                                        / merge_length,
311                                ),
312                            );
313                            index += 1;
314                        }
315                        *text = text
316                            .replace(Qwen2VLProcessor::PLACEHOLDER, Qwen2VLProcessor::IMAGE_PAD);
317                    }
318                }
319
320                if let Some(ref video_grid_thw_accum) = video_grid_thw_accum {
321                    let merge_length = config.merge_size.expect("Require `merge_size").pow(2);
322                    let mut index = 0;
323                    for ((batch, text), seq) in
324                        detok_seqs.iter_mut().enumerate().zip(input_seqs.iter_mut())
325                    {
326                        if seq.multimodal.has_changed_prompt {
327                            continue;
328                        }
329                        while text.contains(Qwen2VLProcessor::VIDEO_PAD) {
330                            *text = replace_first_occurrence(
331                                text,
332                                Qwen2VLProcessor::VIDEO_PAD,
333                                &Qwen2VLProcessor::PLACEHOLDER.repeat(
334                                    video_grid_thw_accum[batch]
335                                        .i(index)
336                                        .unwrap()
337                                        .to_vec1::<u32>()
338                                        .unwrap()
339                                        .iter()
340                                        .product::<u32>()
341                                        as usize
342                                        / merge_length,
343                                ),
344                            );
345                            index += 1;
346                        }
347                        *text = text
348                            .replace(Qwen2VLProcessor::PLACEHOLDER, Qwen2VLProcessor::VIDEO_PAD);
349                    }
350                }
351            }

            let mut all_ids = Vec::new();
            let mut all_continuous_img_pad = Vec::new();
            let mut all_continuous_vid_pad = Vec::new();
            for (detok, seq) in detok_seqs.into_iter().zip(input_seqs.iter_mut()) {
                let toks = tokenizer
                    .encode_fast(detok.clone(), false)
                    .expect("Tokenization failed!");
                let ids = toks.get_ids().to_vec();

                if !seq.multimodal.has_changed_prompt {
                    seq.set_initial_prompt(detok.clone());

                    seq.set_toks_and_reallocate(ids.clone(), paged_attn_metadata.as_mut());
                    seq.multimodal.has_changed_prompt = true;
                }
                all_ids.push(ids.clone());

                let img_pad = tokenizer
                    .encode_fast(Qwen2VLProcessor::IMAGE_PAD, false)
                    .expect("Tokenization failed!")
                    .get_ids()
                    .to_vec();
                let continuous_img_pad = find_sequences(&ids, img_pad[0]);
                all_continuous_img_pad.push(continuous_img_pad);

                let vid_pad = tokenizer
                    .encode_fast(Qwen2VLProcessor::VIDEO_PAD, false)
                    .expect("Tokenization failed!")
                    .get_ids()
                    .to_vec();
                let continuous_vid_pad = find_sequences(&ids, vid_pad[0]);
                all_continuous_vid_pad.push(continuous_vid_pad);
            }

            let mut input_ids_searching = Vec::new();
            let mut image_nums = Vec::new();
            let mut video_nums = Vec::new();
            for (seq, ids) in input_seqs.iter().zip(&all_ids) {
                let prompt = seq.get_initial_prompt();
                let match_indices = find_substring_indices(prompt, Qwen2VLProcessor::VISION_START);
                image_nums.push(
                    match_indices
                        .iter()
                        .filter(|&&idx| {
                            prompt[idx..idx + Qwen2VLProcessor::IMAGE_PAD.len()]
                                == *Qwen2VLProcessor::IMAGE_PAD
                        })
                        .count(),
                );
                video_nums.push(
                    match_indices
                        .iter()
                        .filter(|&&idx| {
                            prompt[idx..idx + Qwen2VLProcessor::VIDEO_PAD.len()]
                                == *Qwen2VLProcessor::VIDEO_PAD
                        })
                        .count(),
                );

                input_ids_searching.push(ids.to_vec());
            }

            let mut all_ids_new = Vec::new();
            let max_len = all_ids.iter().map(|ids| ids.len()).max().unwrap();
            for ids in all_ids {
                let pad = max_len - ids.len();
                all_ids_new
                    .push(Tensor::new([ids, vec![0; pad]].concat(), input.device()).unwrap());
            }

            (
                Some(Tensor::stack(&all_ids_new, 0).unwrap()),
                Some(Tensor::cat(&pixel_values_accum, 0).unwrap()),
                image_grid_thw_accum.map(|img| Tensor::cat(&img, 0).unwrap()),
                video_grid_thw_accum.map(|vid| Tensor::cat(&vid, 0).unwrap()),
                all_continuous_img_pad,
                all_continuous_vid_pad,
                input_ids_searching,
                image_nums,
                video_nums,
            )
        } else {
            (
                None,
                None,
                None,
                None,
                vec![],
                vec![],
                vec![vec![]; input_seqs.len()],
                vec![0; input_seqs.len()],
                vec![0; input_seqs.len()],
            )
        };

        let (input, input_ids_full) = match (new_input, is_prompt) {
            (Some(new_input), true) => (new_input.clone(), new_input),
            (Some(new_input), false) => (input, new_input),
            (None, _) => (input.clone(), input.clone()),
        };
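        // On a prompt step, the freshly expanded token ids drive the forward
        // pass; on a completion step, the incremental `input` is kept and the
        // expanded ids only travel along as `input_ids_full` (presumably for
        // the model's downstream multimodal position computation).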

        let pixel_values = if is_prompt { pixel_values } else { None };

        let seqlens = input_seqs.iter().map(|seq| seq.len()).collect::<Vec<_>>();

        let inputs: Box<dyn Any> = Box::new(ModelInputs {
            input_ids: input,
            seqlen_offsets: positions,
            context_lens,
            position_ids,
            pixel_values,
            model_specific_args: Box::new(Qwen2VLVisionSpecificArgs {
                input_ids_full,
                image_grid_thw,
                video_grid_thw,
                seqlens,
                continuous_img_pad,
                continuous_vid_pad,
                input_ids_searching,
                image_nums,
                video_nums,
            }),
            paged_attn_meta,
            flash_meta,
        });
        Ok(InputProcessorOutput {
            inputs,
            seq_indices,
        })
    }
}

impl Qwen2VLImageProcessor {
    fn smart_resize(
        &self,
        height: usize,
        width: usize,
        factor: usize,
        min_pixels: usize,
        max_pixels: usize,
    ) -> candle_core::Result<(usize, usize)> {
        if height < factor || width < factor {
            candle_core::bail!(
                "height:{} or width:{} must be larger than factor:{}",
                height,
                width,
                factor
            );
        } else if (height.max(width) as f64 / height.min(width) as f64) > 200.0 {
            candle_core::bail!(
                "absolute aspect ratio must be smaller than 200, got {:.2}",
                height.max(width) as f64 / height.min(width) as f64
            );
        }

        let mut h_bar = (height as f64 / factor as f64).round() as usize * factor;
        let mut w_bar = (width as f64 / factor as f64).round() as usize * factor;

        if h_bar * w_bar > max_pixels {
            let beta = ((height * width) as f64 / max_pixels as f64).sqrt();
            h_bar = ((height as f64 / beta / factor as f64).floor() as usize) * factor;
            w_bar = ((width as f64 / beta / factor as f64).floor() as usize) * factor;
        } else if h_bar * w_bar < min_pixels {
            let beta = (min_pixels as f64 / (height * width) as f64).sqrt();
            h_bar = ((height as f64 * beta / factor as f64).ceil() as usize) * factor;
            w_bar = ((width as f64 * beta / factor as f64).ceil() as usize) * factor;
        }

        Ok((h_bar, w_bar))
    }

    // Returns the flattened patches together with the (t, h, w) patch grid.
    fn preprocess_inner(
        &self,
        images: Vec<DynamicImage>,
        config: &PreProcessorConfig,
        device: &Device,
        (mut height, mut width): (u32, u32),
    ) -> candle_core::Result<(Tensor, (u32, u32, u32))> {
        let mut processed_images = Vec::new();

        for mut image in images {
            // Note: the image crate's `resize_exact` takes (width, height).
            image = image.resize_exact(
                width,
                height,
                config
                    .resampling
                    .map(|resample| Some(resample).to_filter())
                    .unwrap_or(Ok(FilterType::CatmullRom))?,
            );
            image = DynamicImage::ImageRgb8(image.to_rgb8());
            if config.do_resize.is_none() || config.do_resize.is_some_and(|x| x) {
                let (resized_height, resized_width) = self.smart_resize(
                    height as usize,
                    width as usize,
                    config.patch_size.context("Require `patch_size`.")?
                        * config.merge_size.context("Require `merge_size`")?,
                    config.min_pixels.context("Require `min_pixels`")?,
                    config.max_pixels.context("Require `max_pixels`")?,
                )?;
                height = resized_height as u32;
                width = resized_width as u32;
                image = image.resize_exact(
                    resized_width as u32,
                    resized_height as u32,
                    config
                        .resampling
                        .map(|resample| Some(resample).to_filter())
                        .unwrap_or(Ok(FilterType::CatmullRom))?,
                );
            }

            let to_tensor_rescale = Transforms {
                input: &ToTensor,
                inner_transforms: &[],
            };
            let image = image.apply(to_tensor_rescale, device)?;

            let transforms = TensorTransforms {
                inner_transforms: &[&Normalize {
                    mean: config.image_mean.unwrap_or(Self::DEFAULT_MEAN).to_vec(),
                    std: config.image_std.unwrap_or(Self::DEFAULT_STD).to_vec(),
                }],
            };
            let image = <Tensor as ApplyTensorTransforms>::apply(&image, transforms, device)?;

            processed_images.push(image);
        }

        let mut patches = Tensor::stack(&processed_images, 0)?;
        let temporal_patch_size = config
            .temporal_patch_size
            .context("Require `temporal_patch_size`")?;
        let patch_size = config.patch_size.context("Require `patch_size`")?;
        let merge_size = config.merge_size.context("Require `merge_size`")?;
        // A single image is repeated along the temporal axis so that it fills
        // one temporal patch.
        if patches.dim(0)? == 1 {
            patches = patches.repeat((temporal_patch_size, 1, 1, 1))?;
        }
        let channel = patches.dim(1)?;
        let grid_t = patches.dim(0)? / temporal_patch_size;
        let grid_h = height as usize / patch_size;
        let grid_w = width as usize / patch_size;
        patches = patches.reshape(&[
            grid_t,
            temporal_patch_size,
            channel,
            grid_h / merge_size,
            merge_size,
            patch_size,
            grid_w / merge_size,
            merge_size,
            patch_size,
        ])?;
        patches = patches.permute([0, 3, 6, 4, 7, 2, 1, 5, 8])?;
        let flattened_patches = patches.reshape((
            grid_t * grid_h * grid_w,
            channel * temporal_patch_size * patch_size * patch_size,
        ))?;
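
        // Shape sketch (illustrative numbers, not from the source): a 224x224
        // RGB image with patch_size = 14, merge_size = 2, temporal_patch_size = 2
        // gives grid_t = 1 and grid_h = grid_w = 16, so `flattened_patches` has
        // shape (1 * 16 * 16, 3 * 2 * 14 * 14) = (256, 1176).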

        Ok((
            flattened_patches,
            (grid_t as u32, grid_h as u32, grid_w as u32),
        ))
    }
}
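
// A sketch of `smart_resize` behavior as a unit test; the numbers (a factor of
// patch_size 14 * merge_size 2 = 28 and the pixel bounds) are illustrative,
// not taken from the source.
#[cfg(test)]
mod smart_resize_tests {
    use super::Qwen2VLImageProcessor;

    #[test]
    fn rounds_to_factor_and_respects_max_pixels() {
        let p = Qwen2VLImageProcessor { max_edge: None };

        // Already a multiple of the factor and within bounds: unchanged.
        let (h, w) = p.smart_resize(224, 224, 28, 56 * 56, 1024 * 1024).unwrap();
        assert_eq!((h, w), (224, 224));

        // Too many pixels: scaled down under max_pixels, staying a multiple of 28.
        let (h, w) = p.smart_resize(1000, 1000, 28, 56 * 56, 50_000).unwrap();
        assert_eq!((h, w), (196, 196));
        assert!(h * w <= 50_000);

        // Inputs smaller than the factor are rejected.
        assert!(p.smart_resize(10, 224, 28, 56 * 56, 50_000).is_err());
    }
}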

impl ImagePreProcessor for Qwen2VLImageProcessor {
    const DEFAULT_MEAN: [f64; 3] = [0.48145466, 0.4578275, 0.40821073];
    const DEFAULT_STD: [f64; 3] = [0.26862954, 0.26130258, 0.27577711];

    fn preprocess(
        &self,
        mut images: Vec<DynamicImage>,
        videos: Vec<Vec<DynamicImage>>,
        config: &PreProcessorConfig,
        device: &Device,
        (_, _): (usize, usize),
    ) -> candle_core::Result<PreprocessedImages> {
        let mut pixel_values = Vec::new();
        let mut vision_grid_thw = Vec::new();

        if !images.is_empty() {
            if let Some(max_edge) = self.max_edge {
                images = mistralrs_vision::pad_to_max_edge(&images, max_edge);
            }

            // All images are resized to the largest height and width in the batch.
            let mut height = 0;
            let mut width = 0;
            for image in &images {
                let (w, h) = image.dimensions();
                if w > width {
                    width = w;
                }
                if h > height {
                    height = h;
                }
            }

            for image in images {
                let (patches, (t, h, w)) =
                    self.preprocess_inner(vec![image], config, device, (height, width))?;
                pixel_values.push(patches);
                vision_grid_thw.push(Tensor::new(&[t, h, w], &Device::Cpu)?);
            }
            let pixel_values = Tensor::stack(&pixel_values, 0)?;
            let vision_grid_thw = Tensor::stack(&vision_grid_thw, 0)?;
            return Ok(PreprocessedImages {
                pixel_values,
                pixel_attention_mask: None,
                image_sizes: None,
                num_img_tokens: None,
                aspect_ratio_ids: None,
                aspect_ratio_mask: None,
                num_tiles: None,
                image_grid_thw: Some(vision_grid_thw),
                video_grid_thw: None,
                rows: None,
                cols: None,
                pixel_values_list: None,
                tgt_sizes: None,
                image_sizes_all: None,
                num_crops: None,
            });
        }

        if !videos.is_empty() {
            // All videos are resized to the largest first-frame height and width.
            let mut height = 0;
            let mut width = 0;
            for frames in &videos {
                let (w, h) = frames[0].dimensions();
                if w > width {
                    width = w;
                }
                if h > height {
                    height = h;
                }
            }

            for frames in videos {
                let (patches, (t, h, w)) =
                    self.preprocess_inner(frames, config, device, (height, width))?;
                pixel_values.push(patches);
                vision_grid_thw.push(Tensor::new(&[t, h, w], &Device::Cpu)?);
            }
            let pixel_values = Tensor::stack(&pixel_values, 0)?;
            let vision_grid_thw = Tensor::stack(&vision_grid_thw, 0)?;
            return Ok(PreprocessedImages {
                pixel_values,
                pixel_attention_mask: None,
                image_sizes: None,
                num_img_tokens: None,
                aspect_ratio_ids: None,
                aspect_ratio_mask: None,
                num_tiles: None,
                image_grid_thw: None,
                video_grid_thw: Some(vision_grid_thw),
                rows: None,
                cols: None,
                pixel_values_list: None,
                tgt_sizes: None,
                image_sizes_all: None,
                num_crops: None,
            });
        }
        unreachable!()
    }
}