mistralrs_core/vision_models/qwen2vl/inputs_processor.rs

use std::{any::Any, sync::Arc};

use anyhow::Result;
use candle_core::{Context, Device, IndexOp, Tensor};
use image::{imageops::FilterType, DynamicImage, GenericImageView};
use mistralrs_vision::{
    ApplyTensorTransforms, ApplyTransforms, Normalize, TensorTransforms, ToTensor, Transforms,
};
use tokenizers::Tokenizer;

use crate::{
    device_map::DeviceMapper,
    pipeline::{
        text_models_inputs_processor::{
            self, get_completion_input, get_prompt_input, PagedAttentionMeta,
        },
        InputProcessorOutput, InputsProcessor, InputsProcessorType, MessagesAction, Processor,
    },
    sequence::Sequence,
    vision_models::{
        image_processor::{ImagePreProcessor, PreprocessedImages},
        preprocessor_config::{PreProcessorConfig, ToFilter},
        ModelInputs,
    },
};

use super::Qwen2VLVisionSpecificArgs;

// Input processor
struct Qwen2VLImageProcessor {
    max_edge: Option<u32>,
}
// Processor
pub struct Qwen2VLProcessor {
    max_edge: Option<u32>,
}

impl Qwen2VLProcessor {
    pub const VISION_START: &str = "<|vision_start|>";
    pub const VISION_END: &str = "<|vision_end|>";
    pub const IMAGE_PAD: &str = "<|image_pad|>";
    pub const VIDEO_PAD: &str = "<|video_pad|>";
    pub const PLACEHOLDER: &str = "<|placeholder|>";

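    /// Create a new processor. If `max_edge` is set, it is passed through to the image
    /// preprocessor, which applies `mistralrs_vision::pad_to_max_edge` to the input images.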
    pub fn new(max_edge: Option<u32>) -> Self {
        Self { max_edge }
    }
}

impl Processor for Qwen2VLProcessor {
    fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
        Arc::new(Qwen2VLImageProcessor {
            max_edge: self.max_edge,
        })
    }

    fn get_special_tokens(&self) -> &[&'static str] {
        &[Self::IMAGE_PAD, Self::VIDEO_PAD, Self::PLACEHOLDER]
    }

    fn template_action(&self) -> MessagesAction {
        MessagesAction::FlattenOnlyText
    }
}

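/// Replace only the first occurrence of `to_replace` in `text`; later occurrences are left
/// untouched. Returns the input unchanged if there is no match.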
fn replace_first_occurrence(text: &str, to_replace: &str, replacement: &str) -> String {
    if let Some(pos) = text.find(to_replace) {
        let mut result = text.to_string();
        result.replace_range(pos..pos + to_replace.len(), replacement);
        result
    } else {
        text.to_string()
    }
}

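/// Find contiguous runs of `needle` in `nums`, returned as half-open `(start, end)` index
/// ranges; e.g. `find_sequences(&[7, 3, 3, 3, 9, 3], 3)` yields `[(1, 4), (5, 6)]`.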
fn find_sequences(nums: &[u32], needle: u32) -> Vec<(usize, usize)> {
    let mut sequences = Vec::new();
    let mut start = None;

    for (i, &num) in nums.iter().enumerate() {
        if num == needle {
            if start.is_none() {
                start = Some(i);
            }
        } else if let Some(s) = start {
            sequences.push((s, i));
            start = None;
        }
    }

    if let Some(s) = start {
        sequences.push((s, nums.len()));
    }

    sequences
}

// For each occurrence of `needle` in `haystack`, return the byte index just past its end.
fn find_substring_indices(haystack: &str, needle: &str) -> Vec<usize> {
    let mut indices = Vec::new();
    let mut start = 0;

    while let Some(pos) = haystack[start..].find(needle) {
        let index = start + pos;
        indices.push(index + needle.len());
        start = index + needle.len(); // Move past the last found occurrence
    }

    indices
}

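// On the prompt pass this processor preprocesses (and caches) the images of each sequence,
// rewrites the prompt so every image/video marker occupies the correct number of pad tokens,
// and re-tokenizes the result before building the model inputs.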
impl InputsProcessor for Qwen2VLImageProcessor {
    fn get_type(&self) -> InputsProcessorType {
        InputsProcessorType::Vision
    }
    fn process_inputs(
        &self,
        tokenizer: Option<Arc<Tokenizer>>,
        input_seqs: &mut [&mut Sequence],
        is_prompt: bool,
        is_xlora: bool,
        device: &Device,
        no_kv_cache: bool,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        other_config: Option<Arc<dyn Any>>,
        mut paged_attn_metadata: Option<PagedAttentionMeta>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Result<InputProcessorOutput> {
        if is_xlora {
            return Err(anyhow::Error::msg(
                "Cannot make inputs for X-LoRA vision model.",
            ));
        }
        if no_kv_cache {
            return Err(anyhow::Error::msg("Vision model must have kv cache."));
        }
        let Some(tokenizer) = tokenizer else {
            return Err(anyhow::Error::msg(
                "Qwen2VLImageProcessor requires a specified tokenizer.",
            ));
        };

        let text_models_inputs_processor::InnerInputProcessorOutput {
            inputs:
                text_models_inputs_processor::InputMetadata {
                    input,
                    positions,
                    context_lens,
                    position_ids,
                    paged_attn_meta,
                    flash_meta,
                },
            seq_indices,
        } = if is_prompt {
            get_prompt_input(
                input_seqs
                    .iter()
                    .map(|seq| seq.get_toks())
                    .collect::<Vec<_>>(),
                input_seqs,
                device,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                mapper,
            )
            .unwrap()
        } else {
            get_completion_input(
                input_seqs
                    .iter()
                    .map(|seq| seq.get_toks())
                    .collect::<Vec<_>>(),
                input_seqs,
                device,
                no_kv_cache,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                mapper,
            )
            .unwrap()
        };
        let config = other_config.expect("Need a PreProcessorConfig config.");
        let config: &PreProcessorConfig = config.downcast_ref().expect("Downcast failed.");

        let has_images = input_seqs.iter().all(|seq| seq.has_images());

        let (
            new_input,
            pixel_values,
            image_grid_thw,
            video_grid_thw,
            continuous_img_pad,
            continuous_vid_pad,
            input_ids_searching,
            image_nums,
            video_nums,
        ) = if has_images {
            let mut pixel_values_accum = Vec::new();
            let mut image_grid_thw_accum = Vec::new();
            let mut video_grid_thw_accum = Vec::new();

            let mut detok_seqs = tokenizer
                .decode_batch(
                    &input_seqs
                        .iter()
                        .map(|seq| seq.get_toks())
                        .collect::<Vec<_>>(),
                    false,
                )
                .expect("Detokenization failed!");

            for seq in input_seqs.iter_mut() {
                let (pixel_values, image_grid_thw, video_grid_thw) =
                    if let Some(cached_pixel_values) = &seq.multimodal.cached_pixel_values {
                        (
                            cached_pixel_values.clone(),
                            seq.multimodal.cached_img_thw.clone(),
                            seq.multimodal.cached_vid_thw.clone(),
                        )
                    } else {
                        let PreprocessedImages {
                            pixel_values,
                            pixel_attention_mask: _,
                            image_sizes: _,
                            num_img_tokens: _,
                            aspect_ratio_ids: _,
                            aspect_ratio_mask: _,
                            num_tiles: _,
                            image_grid_thw,
                            video_grid_thw,
                            rows: _,
                            cols: _,
                            pixel_values_list: _,
                            tgt_sizes: _,
                            image_sizes_all: _,
                            num_crops: _,
                        } = self
                            .preprocess(
                                seq.clone_images()
                                    .expect("Need to have images by this point."),
                                vec![],
                                config,
                                device,
                                (usize::MAX, usize::MAX), // Don't use it here...
                            )
                            .expect("Preprocessing failed");

                        seq.multimodal.cached_pixel_values = Some(pixel_values.clone());
                        seq.multimodal.cached_img_thw = image_grid_thw.clone();
                        seq.multimodal.cached_vid_thw = video_grid_thw.clone();
                        (pixel_values, image_grid_thw, video_grid_thw)
                    };

                pixel_values_accum.push(pixel_values.unsqueeze(0).unwrap());
                image_grid_thw_accum.push(image_grid_thw); //.map(|img| img.unsqueeze(0).unwrap()));
                video_grid_thw_accum.push(video_grid_thw); //.map(|vid| vid.unsqueeze(0).unwrap()));
            }

            let image_grid_thw_accum = if image_grid_thw_accum.iter().any(|img| img.is_none()) {
                None
            } else {
                Some(
                    image_grid_thw_accum
                        .into_iter()
                        .map(|img| img.unwrap())
                        .collect::<Vec<_>>(),
                )
            };

            let video_grid_thw_accum = if video_grid_thw_accum.iter().any(|img| img.is_none()) {
                None
            } else {
                Some(
                    video_grid_thw_accum
                        .into_iter()
                        .map(|img| img.unwrap())
                        .collect::<Vec<_>>(),
                )
            };

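            // Expand each IMAGE_PAD/VIDEO_PAD marker into one placeholder per merged vision
            // patch (t * h * w / merge_size^2), then swap the placeholders back to pad tokens
            // so each image/video occupies the right number of positions after re-tokenization.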
            if is_prompt {
                if let Some(ref image_grid_thw_accum) = image_grid_thw_accum {
                    let merge_length = config.merge_size.expect("Require `merge_size`").pow(2);
                    for ((batch, text), seq) in
                        detok_seqs.iter_mut().enumerate().zip(input_seqs.iter_mut())
                    {
                        if seq.multimodal.has_changed_prompt {
                            continue;
                        }
                        let mut index = 0;
                        while text.contains(Qwen2VLProcessor::IMAGE_PAD) {
                            *text = replace_first_occurrence(
                                text,
                                Qwen2VLProcessor::IMAGE_PAD,
                                &Qwen2VLProcessor::PLACEHOLDER.repeat(
                                    image_grid_thw_accum[batch]
                                        .i(index)
                                        .unwrap()
                                        .to_vec1::<u32>()
                                        .unwrap()
                                        .iter()
                                        .product::<u32>()
                                        as usize
                                        / merge_length,
                                ),
                            );
                            index += 1;
                        }
                        *text = text
                            .replace(Qwen2VLProcessor::PLACEHOLDER, Qwen2VLProcessor::IMAGE_PAD);
                    }
                }

                if let Some(ref video_grid_thw_accum) = video_grid_thw_accum {
                    let merge_length = config.merge_size.expect("Require `merge_size`").pow(2);
                    for ((batch, text), seq) in
                        detok_seqs.iter_mut().enumerate().zip(input_seqs.iter_mut())
                    {
                        if seq.multimodal.has_changed_prompt {
                            continue;
                        }
                        // Reset per sequence, mirroring the image branch above.
                        let mut index = 0;
                        while text.contains(Qwen2VLProcessor::VIDEO_PAD) {
                            *text = replace_first_occurrence(
                                text,
                                Qwen2VLProcessor::VIDEO_PAD,
                                &Qwen2VLProcessor::PLACEHOLDER.repeat(
                                    video_grid_thw_accum[batch]
                                        .i(index)
                                        .unwrap()
                                        .to_vec1::<u32>()
                                        .unwrap()
                                        .iter()
                                        .product::<u32>()
                                        as usize
                                        / merge_length,
                                ),
                            );
                            index += 1;
                        }
                        *text = text
                            .replace(Qwen2VLProcessor::PLACEHOLDER, Qwen2VLProcessor::VIDEO_PAD);
                    }
                }
            }

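            // Re-tokenize the expanded prompts, update each sequence's tokens on the first
            // pass, and record the contiguous runs of image/video pad token ids.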
            let mut all_ids = Vec::new();
            let mut all_continuous_img_pad = Vec::new();
            let mut all_continuous_vid_pad = Vec::new();
            for (detok, seq) in detok_seqs.into_iter().zip(input_seqs.iter_mut()) {
                let toks = tokenizer
                    .encode_fast(detok.clone(), false)
                    .expect("Tokenization failed!");
                let ids = toks.get_ids().to_vec();

                if !seq.multimodal.has_changed_prompt {
                    seq.set_initial_prompt(detok.clone());

                    seq.set_toks_and_reallocate(ids.clone(), paged_attn_metadata.as_mut());
                    seq.multimodal.has_changed_prompt = true;
                }
                all_ids.push(ids.clone());

                let img_pad = tokenizer
                    .encode_fast(Qwen2VLProcessor::IMAGE_PAD, false)
                    .expect("Tokenization failed!")
                    .get_ids()
                    .to_vec();
                let continuous_img_pad = find_sequences(&ids, img_pad[0]);
                all_continuous_img_pad.push(continuous_img_pad);

                let vid_pad = tokenizer
                    .encode_fast(Qwen2VLProcessor::VIDEO_PAD, false)
                    .expect("Tokenization failed!")
                    .get_ids()
                    .to_vec();
                let continuous_vid_pad = find_sequences(&ids, vid_pad[0]);
                all_continuous_vid_pad.push(continuous_vid_pad);
            }

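            // Count how many vision_start markers are followed by an image pad vs. a video
            // pad, so the model knows how many images and videos each sequence contains.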
            let mut input_ids_searching = Vec::new();
            let mut image_nums = Vec::new();
            let mut video_nums = Vec::new();
            for (seq, ids) in input_seqs.iter().zip(&all_ids) {
                let prompt = seq.get_initial_prompt();
                let match_indices = find_substring_indices(prompt, Qwen2VLProcessor::VISION_START);
                image_nums.push(
                    match_indices
                        .iter()
                        .filter(|&&idx| {
                            prompt[idx..idx + Qwen2VLProcessor::IMAGE_PAD.len()]
                                == *Qwen2VLProcessor::IMAGE_PAD
                        })
                        .count(),
                );
                video_nums.push(
                    match_indices
                        .iter()
                        .filter(|&&idx| {
                            prompt[idx..idx + Qwen2VLProcessor::VIDEO_PAD.len()]
                                == *Qwen2VLProcessor::VIDEO_PAD
                        })
                        .count(),
                );

                input_ids_searching.push(ids.to_vec());
            }

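            // Right-pad every sequence with token id 0 to the batch maximum length and stack
            // into a single (batch, seq_len) tensor of input ids.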
            let mut all_ids_new = Vec::new();
            let max_len = all_ids.iter().map(|ids| ids.len()).max().unwrap();
            for ids in all_ids {
                let pad = max_len - ids.len();
                all_ids_new
                    .push(Tensor::new([ids, vec![0; pad]].concat(), input.device()).unwrap());
            }

            (
                Some(Tensor::stack(&all_ids_new, 0).unwrap()),
                Some(Tensor::cat(&pixel_values_accum, 0).unwrap()),
                image_grid_thw_accum.map(|img| Tensor::cat(&img, 0).unwrap()),
                video_grid_thw_accum.map(|vid| Tensor::cat(&vid, 0).unwrap()),
                all_continuous_img_pad,
                all_continuous_vid_pad,
                input_ids_searching,
                image_nums,
                video_nums,
            )
        } else {
            (
                None,
                None,
                None,
                None,
                vec![],
                vec![],
                vec![vec![]; input_seqs.len()],
                vec![0; input_seqs.len()],
                vec![0; input_seqs.len()],
            )
        };

        let (input, input_ids_full) = match (new_input, is_prompt) {
            (Some(new_input), true) => (new_input.clone(), new_input),
            (Some(new_input), false) => (input, new_input),
            (None, _) => (input.clone(), input.clone()),
        };

        let pixel_values = if is_prompt { pixel_values } else { None };

        let seqlens = input_seqs.iter().map(|seq| seq.len()).collect::<Vec<_>>();

        let inputs: Box<dyn Any> = Box::new(ModelInputs {
            input_ids: input,
            seqlen_offsets: positions,
            context_lens,
            position_ids,
            pixel_values,
            model_specific_args: Box::new(Qwen2VLVisionSpecificArgs {
                input_ids_full,
                image_grid_thw,
                video_grid_thw,
                seqlens,
                continuous_img_pad,
                continuous_vid_pad,
                input_ids_searching,
                image_nums,
                video_nums,
            }),
            paged_attn_meta,
            flash_meta,
        });
        Ok(InputProcessorOutput {
            inputs,
            seq_indices,
        })
    }
}

impl Qwen2VLImageProcessor {
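    /// Round `height`/`width` to multiples of `factor` while keeping the resulting area
    /// within `[min_pixels, max_pixels]` and approximately preserving the aspect ratio.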
    fn smart_resize(
        &self,
        height: usize,
        width: usize,
        factor: usize,
        min_pixels: usize,
        max_pixels: usize,
    ) -> candle_core::Result<(usize, usize)> {
        if height < factor || width < factor {
            candle_core::bail!(
                "height:{} or width:{} must be larger than factor:{}",
                height,
                width,
                factor
            );
        } else if (height.max(width) as f64 / height.min(width) as f64) > 200.0 {
            candle_core::bail!(
                "absolute aspect ratio must be smaller than 200, got {:.2}",
                height.max(width) as f64 / height.min(width) as f64
            );
        }

        let mut h_bar = (height as f64 / factor as f64).round() as usize * factor;
        let mut w_bar = (width as f64 / factor as f64).round() as usize * factor;

        if h_bar * w_bar > max_pixels {
            let beta = ((height * width) as f64 / max_pixels as f64).sqrt();
            h_bar = ((height as f64 / beta / factor as f64).floor() as usize) * factor;
            w_bar = ((width as f64 / beta / factor as f64).floor() as usize) * factor;
        } else if h_bar * w_bar < min_pixels {
            let beta = (min_pixels as f64 / (height * width) as f64).sqrt();
            h_bar = ((height as f64 * beta / factor as f64).ceil() as usize) * factor;
            w_bar = ((width as f64 * beta / factor as f64).ceil() as usize) * factor;
        }

        Ok((h_bar, w_bar))
    }

    // Returns the flattened patches and the (grid_t, grid_h, grid_w) shape.
    fn preprocess_inner(
        &self,
        images: Vec<DynamicImage>,
        config: &PreProcessorConfig,
        device: &Device,
        (mut height, mut width): (u32, u32),
    ) -> candle_core::Result<(Tensor, (u32, u32, u32))> {
        let mut processed_images = Vec::new();

        for mut image in images {
            image = image.resize_exact(
                height,
                width,
                config
                    .resampling
                    .map(|resample| Some(resample).to_filter())
                    .unwrap_or(Ok(FilterType::CatmullRom))?,
            );
            image = DynamicImage::ImageRgb8(image.to_rgb8());
            if config.do_resize.is_none() || config.do_resize.is_some_and(|x| x) {
                let (resized_height, resized_width) = self.smart_resize(
                    height as usize,
                    width as usize,
                    config.patch_size.context("Require `patch_size`.")?
                        * config.merge_size.context("Require `merge_size`")?,
                    config.min_pixels.context("Require `min_pixels`")?,
                    config.max_pixels.context("Require `max_pixels`")?,
                )?;
                height = resized_height as u32;
                width = resized_width as u32;
                image = image.resize_exact(
                    resized_width as u32,
                    resized_height as u32,
                    config
                        .resampling
                        .map(|resample| Some(resample).to_filter())
                        .unwrap_or(Ok(FilterType::CatmullRom))?,
                );
            }

            let to_tensor_rescale = Transforms {
                input: &ToTensor,
                inner_transforms: &[],
            };
            let image = image.apply(to_tensor_rescale, device)?;

            let transforms = TensorTransforms {
                inner_transforms: &[&Normalize {
                    mean: config.image_mean.unwrap_or(Self::DEFAULT_MEAN).to_vec(),
                    std: config.image_std.unwrap_or(Self::DEFAULT_STD).to_vec(),
                }],
            };
            let image = <Tensor as ApplyTensorTransforms>::apply(&image, transforms, device)?;

            processed_images.push(image);
        }

        let mut patches = Tensor::stack(&processed_images, 0)?;
        let temporal_patch_size = config
            .temporal_patch_size
            .context("Require `temporal_patch_size`")?;
        let patch_size = config.patch_size.context("Require `patch_size`")?;
        let merge_size = config.merge_size.context("Require `merge_size`")?;
        // A single frame (a still image rather than a video): repeat it along the temporal axis.
        if patches.dim(0)? == 1 {
            patches = patches.repeat((temporal_patch_size, 1, 1, 1))?;
        }
        let channel = patches.dim(1)?;
        let grid_t = patches.dim(0)? / temporal_patch_size;
        let grid_h = height as usize / patch_size;
        let grid_w = width as usize / patch_size;
        patches = patches.reshape(&[
            grid_t,
            temporal_patch_size,
            channel,
            grid_h / merge_size,
            merge_size,
            patch_size,
            grid_w / merge_size,
            merge_size,
            patch_size,
        ])?;
        patches = patches.permute([0, 3, 6, 4, 7, 2, 1, 5, 8])?;
        let flattened_patches = patches.reshape((
            grid_t * grid_h * grid_w,
            channel * temporal_patch_size * patch_size * patch_size,
        ))?;

        Ok((
            flattened_patches,
            (grid_t as u32, grid_h as u32, grid_w as u32),
        ))
    }
}

impl ImagePreProcessor for Qwen2VLImageProcessor {
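    // Normalization defaults: the OpenAI CLIP image mean/std.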
    const DEFAULT_MEAN: [f64; 3] = [0.48145466, 0.4578275, 0.40821073];
    const DEFAULT_STD: [f64; 3] = [0.26862954, 0.26130258, 0.27577711];

    fn preprocess(
        &self,
        mut images: Vec<DynamicImage>,
        videos: Vec<Vec<DynamicImage>>,
        config: &PreProcessorConfig,
        device: &Device,
        (_, _): (usize, usize),
    ) -> candle_core::Result<PreprocessedImages> {
        let mut pixel_values = Vec::new();
        let mut vision_grid_thw = Vec::new();

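        // Images and videos are handled in mutually exclusive branches: whichever is
        // non-empty is resized to the largest (height, width) in the batch and converted
        // to flattened patches plus a per-item (t, h, w) grid.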
        if !images.is_empty() {
            if let Some(max_edge) = self.max_edge {
                images = mistralrs_vision::pad_to_max_edge(&images, max_edge);
            }

            let mut height = 0;
            let mut width = 0;
            for image in &images {
                let (w, h) = image.dimensions();
                if w > width {
                    width = w;
                }
                if h > height {
                    height = h;
                }
            }

            for image in images {
                let (patches, (t, h, w)) =
                    self.preprocess_inner(vec![image], config, device, (height, width))?;
                pixel_values.push(patches);
                vision_grid_thw.push(Tensor::new(&[t, h, w], &Device::Cpu)?);
            }
            let pixel_values = Tensor::stack(&pixel_values, 0)?;
            let vision_grid_thw = Tensor::stack(&vision_grid_thw, 0)?;
            return Ok(PreprocessedImages {
                pixel_values,
                pixel_attention_mask: None,
                image_sizes: None,
                num_img_tokens: None,
                aspect_ratio_ids: None,
                aspect_ratio_mask: None,
                num_tiles: None,
                image_grid_thw: Some(vision_grid_thw),
                video_grid_thw: None,
                rows: None,
                cols: None,
                pixel_values_list: None,
                tgt_sizes: None,
                image_sizes_all: None,
                num_crops: None,
            });
        }

        if !videos.is_empty() {
            let mut height = 0;
            let mut width = 0;
            for image in &videos {
                let (w, h) = image[0].dimensions();
                if w > width {
                    width = w;
                }
                if h > height {
                    height = h;
                }
            }

            for images in videos {
                let (patches, (t, h, w)) =
                    self.preprocess_inner(images, config, device, (height, width))?;
                pixel_values.push(patches);
                vision_grid_thw.push(Tensor::new(&[t, h, w], &Device::Cpu)?);
            }
            let pixel_values = Tensor::stack(&pixel_values, 0)?;
            let vision_grid_thw = Tensor::stack(&vision_grid_thw, 0)?;
            return Ok(PreprocessedImages {
                pixel_values,
                pixel_attention_mask: None,
                image_sizes: None,
                num_img_tokens: None,
                aspect_ratio_ids: None,
                aspect_ratio_mask: None,
                num_tiles: None,
                image_grid_thw: None,
                video_grid_thw: Some(vision_grid_thw),
                rows: None,
                cols: None,
                pixel_values_list: None,
                tgt_sizes: None,
                image_sizes_all: None,
                num_crops: None,
            });
        }
        unreachable!()
    }
}