mistralrs_core/vision_models/idefics3/inputs_processor.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::{any::Any, cmp, collections::HashMap, sync::Arc};

use candle_core::{Device, Result, Tensor};
use image::{imageops::FilterType, DynamicImage, GenericImageView};
use mistralrs_vision::{ApplyTransforms, Normalize, Rescale, ToTensorNoNorm, Transforms};
use tokenizers::Tokenizer;

use crate::{
    device_map::DeviceMapper,
    pipeline::{
        text_models_inputs_processor::{
            self, get_completion_input, get_prompt_input, PagedAttentionMeta,
        },
        InputProcessorOutput, InputsProcessor, InputsProcessorType, MessagesAction, Processor,
    },
    sequence::Sequence,
    vision_models::ModelInputs,
};

use crate::vision_models::{
    image_processor::{ImagePreProcessor, PreprocessedImages},
    preprocessor_config::{PreProcessorConfig, ToFilter},
    processor_config::ProcessorConfig,
};

// 4k resolution as absolute maximum
const MAX_IMAGE_SIZE: usize = 4096;
const FAKE_IMAGE_TOKEN: &str = "<fake_token_around_image>";
const IMAGE_TOKEN: &str = "<image>";
const GLOBAL_IMAGE_TOKEN: &str = "<global-img>";

pub struct Idefics3ImageProcessor {
    max_edge: Option<u32>,
    image_seq_len: usize,
}

pub struct Idefics3Processor {
    config: ProcessorConfig,
    max_edge: Option<u32>,
}

impl Idefics3Processor {
    pub fn new(
        config: ProcessorConfig,
        _preprocessor_config: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Self {
        Self { config, max_edge }
    }
}

impl Processor for Idefics3Processor {
    fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
        // Default image_seq_len is 169.
        Arc::new(Idefics3ImageProcessor {
            max_edge: self.max_edge,
            image_seq_len: self.config.image_seq_len.unwrap_or(169),
        })
    }

    fn get_special_tokens(&self) -> &[&'static str] {
        &["<fake_token_around_image>", "<image>", "<end_of_utterance>"]
    }

    fn template_action(&self) -> MessagesAction {
        MessagesAction::Keep
    }
}

fn get_image_prompt_string(n_rows: usize, n_cols: usize, image_seq_len: usize) -> String {
    if n_rows == 0 && n_cols == 0 {
        format!(
            "{FAKE_IMAGE_TOKEN}{GLOBAL_IMAGE_TOKEN}{}{FAKE_IMAGE_TOKEN}",
            IMAGE_TOKEN.repeat(image_seq_len)
        )
    } else {
        let mut text_split_images = String::new();
        for n_h in 0..n_rows {
            for n_w in 0..n_cols {
                text_split_images.push_str(&format!(
                    "{FAKE_IMAGE_TOKEN}<row_{}_col_{}>{}",
                    n_h + 1,
                    n_w + 1,
                    IMAGE_TOKEN.repeat(image_seq_len)
                ));
            }
            text_split_images.push('\n');
        }
        format!(
            "{text_split_images}\n{FAKE_IMAGE_TOKEN}{GLOBAL_IMAGE_TOKEN}{}{FAKE_IMAGE_TOKEN}",
            IMAGE_TOKEN.repeat(image_seq_len)
        )
    }
}
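
// A minimal illustrative sketch of the prompt strings the helper above produces; the
// module name, test names, and the small seq_len/grid values are arbitrary example
// inputs, not anything mandated by the model.
#[cfg(test)]
mod image_prompt_string_examples {
    use super::*;

    #[test]
    fn unsplit_image_uses_only_the_global_tile() {
        // With no row/col splits, the prompt is just the global image wrapped in fake tokens.
        let prompt = get_image_prompt_string(0, 0, 2);
        assert_eq!(
            prompt,
            format!(
                "{FAKE_IMAGE_TOKEN}{GLOBAL_IMAGE_TOKEN}{}{FAKE_IMAGE_TOKEN}",
                IMAGE_TOKEN.repeat(2)
            )
        );
    }

    #[test]
    fn split_image_gets_one_block_per_tile_plus_a_global_tile() {
        // A 2x2 split with image_seq_len = 1: four tile blocks plus the trailing global block.
        let prompt = get_image_prompt_string(2, 2, 1);
        assert_eq!(prompt.matches(IMAGE_TOKEN).count(), 2 * 2 + 1);
        assert!(prompt.contains("<row_1_col_1>"));
        assert!(prompt.contains("<row_2_col_2>"));
        assert!(prompt.ends_with(FAKE_IMAGE_TOKEN));
    }
}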

impl InputsProcessor for Idefics3ImageProcessor {
    fn get_type(&self) -> InputsProcessorType {
        InputsProcessorType::Vision
    }
    fn process_inputs(
        &self,
        tokenizer: Option<Arc<Tokenizer>>,
        input_seqs: &mut [&mut Sequence],
        is_prompt: bool,
        is_xlora: bool,
        device: &Device,
        no_kv_cache: bool,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        other_config: Option<Arc<dyn Any>>,
        mut paged_attn_metadata: Option<PagedAttentionMeta>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> anyhow::Result<InputProcessorOutput> {
        if is_xlora {
            return Err(anyhow::Error::msg(
                "Cannot make inputs for X-LoRA vision model.",
            ));
        }
        if no_kv_cache {
            return Err(anyhow::Error::msg("Vision model must have kv cache."));
        }
        let Some(tokenizer) = tokenizer else {
            return Err(anyhow::Error::msg(
                "Idefics3ImageProcessor requires a specified tokenizer.",
            ));
        };

        let config = other_config.expect("Need a PreProcessorConfig config.");
        let config: &PreProcessorConfig = config.downcast_ref().expect("Downcast failed.");

        let has_images = input_seqs.iter().all(|seq| seq.has_images());

        let (pixel_values, pixel_attention_mask) = if has_images {
            let mut pixel_values_accum = Vec::new();
            let mut pixel_attention_mask_accum = Vec::new();
            for seq in input_seqs.iter_mut() {
                let PreprocessedImages {
                    pixel_values,
                    pixel_attention_mask,
                    image_sizes: _,
                    num_img_tokens: _,
                    aspect_ratio_ids: _,
                    aspect_ratio_mask: _,
                    num_tiles: _,
                    image_grid_thw: _,
                    video_grid_thw: _,
                    rows,
                    cols,
                    pixel_values_list: _,
                    tgt_sizes: _,
                    image_sizes_all: _,
                    num_crops: _,
                } = self
                    .preprocess(
                        seq.take_images()
                            .expect("Need to have images by this point."),
                        vec![],
                        config,
                        device,
                        (usize::MAX, usize::MAX), // Don't use it here...
                    )
                    .expect("Preprocessing failed");
                pixel_values_accum.push(pixel_values.unsqueeze(0).unwrap());
                pixel_attention_mask_accum
                    .push(pixel_attention_mask.unwrap().unsqueeze(0).unwrap());

                if !seq.multimodal.has_changed_prompt {
                    let detok = tokenizer
                        .decode(seq.get_toks(), false)
                        .expect("Detokenization failed!");

                    let mut image_prompt_strings = Vec::new();
                    for (n_rows, n_cols) in rows.unwrap().into_iter().zip(cols.unwrap().into_iter())
                    {
                        let image_prompt_string =
                            get_image_prompt_string(n_rows, n_cols, self.image_seq_len);
                        image_prompt_strings.push(image_prompt_string);
                    }

                    let split_sample = detok.split(IMAGE_TOKEN).collect::<Vec<_>>();
                    let mut sample = split_sample
                        .first()
                        .expect("The image token <image> should be present in the text.")
                        .to_string();
                    for (i, image_prompt_string) in image_prompt_strings.into_iter().enumerate() {
                        sample.push_str(&format!(
                            "{image_prompt_string}{}",
                            split_sample
                                .get(i + 1)
                                .expect("Incorrect chat template. Use the one provided in `chat_templates` with the `--chat-template`/`chat_template` settings.")
                        ));
                    }

                    seq.set_initial_prompt(sample.clone());
                    let toks = tokenizer
                        .encode_fast(sample, false)
                        .expect("Tokenization failed!");

                    let ids = toks.get_ids().to_vec();
                    seq.set_toks_and_reallocate(ids, paged_attn_metadata.as_mut());
                    seq.multimodal.has_changed_prompt = true;
                }
            }

            (
                Some(Tensor::cat(&pixel_values_accum, 0).unwrap()),
                Some(Tensor::cat(&pixel_attention_mask_accum, 0).unwrap()),
            )
        } else {
            (None, None)
        };

        let text_models_inputs_processor::InnerInputProcessorOutput {
            inputs:
                text_models_inputs_processor::InputMetadata {
                    input,
                    positions,
                    context_lens,
                    position_ids,
                    paged_attn_meta,
                    flash_meta,
                },
            seq_indices,
        } = if is_prompt {
            get_prompt_input(
                input_seqs
                    .iter()
                    .map(|seq| seq.get_toks())
                    .collect::<Vec<_>>(),
                input_seqs,
                device,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                mapper,
            )
            .unwrap()
        } else {
            get_completion_input(
                input_seqs
                    .iter()
                    .map(|seq| seq.get_toks())
                    .collect::<Vec<_>>(),
                input_seqs,
                device,
                no_kv_cache,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                mapper,
            )
            .unwrap()
        };

        let inputs: Box<dyn Any> = Box::new(ModelInputs {
            input_ids: input,
            seqlen_offsets: positions,
            context_lens,
            position_ids,
            pixel_values,
            model_specific_args: Box::new(pixel_attention_mask),
            paged_attn_meta,
            flash_meta,
        });
        Ok(InputProcessorOutput {
            inputs,
            seq_indices,
        })
    }
}

// Calculate output size after resizing, rescaling to max length
fn resize_output_size_rescale_to_max_len(
    height: usize,
    width: usize,
    min_len: Option<usize>,
    max_len: Option<usize>,
) -> (usize, usize) {
    let min_len = min_len.unwrap_or(1);
    let max_len = max_len.unwrap_or_else(|| cmp::max(height, width));
    let aspect_ratio = width as f32 / height as f32;
    let (mut height, mut width) = (height, width);

    if width >= height {
        width = max_len;
        height = (width as f32 / aspect_ratio).round() as usize;
        if height % 2 != 0 {
            height += 1;
        }
    } else {
        height = max_len;
        width = (height as f32 * aspect_ratio).round() as usize;
        if width % 2 != 0 {
            width += 1;
        }
    }

    height = cmp::max(height, min_len);
    width = cmp::max(width, min_len);

    (height, width)
}
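
// A minimal illustrative sketch of the rescale helper above: the longer side is pinned to
// `max_len`, the shorter side follows the aspect ratio and is bumped to an even number.
// The module name and the 100x200/364 values are arbitrary examples.
#[cfg(test)]
mod rescale_to_max_len_examples {
    use super::*;

    #[test]
    fn landscape_image_is_pinned_to_the_max_length_on_its_wide_side() {
        // 100x200 (h x w) rescaled to a longest side of 364: width becomes 364,
        // height follows the 2:1 aspect ratio (182, already even).
        assert_eq!(
            resize_output_size_rescale_to_max_len(100, 200, None, Some(364)),
            (182, 364)
        );
    }
}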

// Calculate output size after resizing, scaling below upper bound
fn resize_output_size_scale_below_upper_bound(
    height: usize,
    width: usize,
    max_len: Option<usize>,
) -> (usize, usize) {
    let max_len = max_len.unwrap_or_else(|| cmp::max(height, width));
    let aspect_ratio = width as f32 / height as f32;
    let (mut height, mut width) = (height, width);

    if width >= height && width > max_len {
        width = max_len;
        height = (width as f32 / aspect_ratio).round() as usize;
    } else if height > width && height > max_len {
        height = max_len;
        width = (height as f32 * aspect_ratio).round() as usize;
    }

    height = cmp::max(height, 1);
    width = cmp::max(width, 1);

    (height, width)
}
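
// A minimal illustrative sketch of the upper-bound clamp above: sizes already inside the
// bound pass through unchanged, and an oversized width is clamped while the height keeps
// the aspect ratio. The concrete sizes are arbitrary examples of ours.
#[cfg(test)]
mod scale_below_upper_bound_examples {
    use super::*;

    #[test]
    fn small_sizes_pass_through_and_oversized_widths_are_clamped() {
        // Already within the 4096 bound: unchanged.
        assert_eq!(
            resize_output_size_scale_below_upper_bound(182, 364, Some(MAX_IMAGE_SIZE)),
            (182, 364)
        );
        // 2000x5000 (h x w): width clamped to 4096, height follows the 2.5:1 ratio.
        assert_eq!(
            resize_output_size_scale_below_upper_bound(2000, 5000, Some(MAX_IMAGE_SIZE)),
            (1638, 4096)
        );
    }
}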

/// Given the image size (h, w) and the maximum side length, calculate the output image
/// dimensions, preserving aspect ratio while respecting the minimum and maximum lengths.
fn get_resize_output_image_size(
    (h, w): (usize, usize),
    resolution_max_side: usize,
) -> (usize, usize) {
    let (h, w) = resize_output_size_rescale_to_max_len(h, w, None, Some(resolution_max_side));
    resize_output_size_scale_below_upper_bound(h, w, Some(MAX_IMAGE_SIZE))
}
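
// A minimal illustrative sketch of the composed resize above: the longest side is rescaled
// to `resolution_max_side`, then clamped under MAX_IMAGE_SIZE (a no-op for small images).
// The 100x200/364 values are arbitrary examples.
#[cfg(test)]
mod get_resize_output_image_size_examples {
    use super::*;

    #[test]
    fn longest_side_is_rescaled_to_the_requested_resolution() {
        assert_eq!(get_resize_output_image_size((100, 200), 364), (182, 364));
    }
}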

fn resize_for_vision_encoder(
    (h, w): (usize, usize),
    vision_encoder_max_size: usize,
) -> (usize, usize) {
    let aspect_ratio = w as f32 / h as f32;

    let (new_h, new_w) = if w >= h {
        let new_w = ((w as f32 / vision_encoder_max_size as f32).ceil()
            * vision_encoder_max_size as f32) as usize;
        let mut new_h = (new_w as f32 / aspect_ratio) as usize;
        new_h = ((new_h as f32 / vision_encoder_max_size as f32).ceil()
            * vision_encoder_max_size as f32) as usize;
        (new_h, new_w)
    } else {
        let new_h = ((h as f32 / vision_encoder_max_size as f32).ceil()
            * vision_encoder_max_size as f32) as usize;
        let mut new_w = (new_h as f32 * aspect_ratio) as usize;
        new_w = ((new_w as f32 / vision_encoder_max_size as f32).ceil()
            * vision_encoder_max_size as f32) as usize;
        (new_h, new_w)
    };

    (new_h, new_w)
}
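
// A minimal illustrative sketch of the encoder-friendly resize above: both sides are
// rounded up to multiples of `vision_encoder_max_size`, so the exact aspect ratio is not
// preserved. The cases mirror the worked examples in the comment inside `preprocess`
// below; the concrete numbers are arbitrary examples of ours.
#[cfg(test)]
mod resize_for_vision_encoder_examples {
    use super::*;

    #[test]
    fn both_sides_are_rounded_up_to_multiples_of_the_encoder_size() {
        // Anything that fits in one 364x364 tile snaps to exactly (364, 364)...
        assert_eq!(resize_for_vision_encoder((100, 200), 364), (364, 364));
        // ...while a width of 365 spills into a second column of tiles: (364, 728).
        assert_eq!(resize_for_vision_encoder((11, 365), 364), (364, 728));
    }
}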

fn resize(
    image: &DynamicImage,
    size: &HashMap<String, u32>,
    resampling: FilterType,
) -> Result<DynamicImage> {
    let (h, w) = if size.contains_key("longest_edge") {
        get_resize_output_image_size(
            (image.dimensions().1 as usize, image.dimensions().0 as usize),
            size["longest_edge"] as usize,
        )
    } else if size.contains_key("height") && size.contains_key("width") {
        (size["height"] as usize, size["width"] as usize)
    } else {
        candle_core::bail!("Size must be a map of `longest_edge` or `height` and `width`.");
    };

    Ok(image.resize_exact(w as u32, h as u32, resampling))
}

/// Returns: frames, num_splits_h, num_splits_w
fn split_image(
    image: &DynamicImage,
    longest_edge: usize,
) -> Result<(Vec<DynamicImage>, usize, usize)> {
    let (width, height) = image.dimensions();
    let mut frames = Vec::new();

    if width > longest_edge as u32 || height > longest_edge as u32 {
        let num_splits_h = (height as f64 / (longest_edge as f64)).ceil() as usize;
        let num_splits_w = (width as f64 / (longest_edge as f64)).ceil() as usize;

        let optimal_height = (height as f64 / num_splits_h as f64).ceil() as u32;
        let optimal_width = (width as f64 / num_splits_w as f64).ceil() as u32;

        for r in 0..num_splits_h {
            for c in 0..num_splits_w {
                let start_x = (c as u32) * optimal_width;
                let start_y = (r as u32) * optimal_height;

                let end_x = std::cmp::min(start_x + optimal_width, width);
                let end_y = std::cmp::min(start_y + optimal_height, height);

                // Crop the image
                let cropped_image =
                    image.crop_imm(start_x, start_y, end_x - start_x, end_y - start_y);
                frames.push(cropped_image);
            }
        }

        // Resize the original image to match `longest_edge`, used as the global image
        let resized_image = resize(
            image,
            &HashMap::from([
                ("height".to_string(), longest_edge as u32),
                ("width".to_string(), longest_edge as u32),
            ]),
            FilterType::Lanczos3,
        )?;
        frames.push(resized_image);

        Ok((frames, num_splits_h, num_splits_w))
    } else {
        frames.push(image.clone());
        Ok((frames, 0, 0))
    }
}
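
// A minimal illustrative sketch of the tiling rule above: an 800x600 image split with a
// 364-pixel longest edge yields a 2x3 grid of crops plus one resized global frame, while a
// small image is returned whole with (rows, cols) = (0, 0). The blank test images and the
// module name are arbitrary examples of ours.
#[cfg(test)]
mod split_image_examples {
    use super::*;

    #[test]
    fn oversized_image_is_tiled_and_a_global_frame_is_appended() {
        let image = DynamicImage::new_rgb8(800, 600); // (width, height)
        let (frames, rows, cols) = split_image(&image, 364).unwrap();
        // ceil(600 / 364) = 2 rows, ceil(800 / 364) = 3 cols, plus the global frame.
        assert_eq!((rows, cols), (2, 3));
        assert_eq!(frames.len(), 2 * 3 + 1);
    }

    #[test]
    fn small_image_is_left_whole() {
        let image = DynamicImage::new_rgb8(100, 100);
        let (frames, rows, cols) = split_image(&image, 364).unwrap();
        assert_eq!((frames.len(), rows, cols), (1, 0, 0));
    }
}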

impl ImagePreProcessor for Idefics3ImageProcessor {
    #[allow(clippy::excessive_precision)]
    const DEFAULT_MEAN: [f64; 3] = [0.48145466, 0.4578275, 0.40821073];
    #[allow(clippy::excessive_precision)]
    const DEFAULT_STD: [f64; 3] = [0.26862954, 0.26130258, 0.27577711];

    fn preprocess(
        &self,
        mut images: Vec<DynamicImage>,
        videos: Vec<Vec<DynamicImage>>,
        config: &PreProcessorConfig,
        device: &Device,
        (_bs, _max_num_images): (usize, usize),
    ) -> Result<PreprocessedImages> {
        assert!(videos.is_empty());

        let mut patch_masks = Vec::new();
        let mut pixel_values = Vec::new();

        if let Some(max_edge) = self.max_edge {
            images = mistralrs_vision::pad_to_max_edge(&images, max_edge);
        }

        for image in images.iter_mut() {
            // Convert to rgb
            if config.do_convert_rgb.is_some_and(|x| x) {
                *image = DynamicImage::ImageRgb8(image.to_rgb8());
            }

            // Resize
            if config.do_resize.is_some_and(|x| x) {
                *image = resize(
                    image,
                    config.size.as_ref().unwrap(),
                    config.resampling.to_filter()?,
                )?;
            }
        }

        let mut image_rows = Vec::new();
        let mut image_cols = Vec::new();
        let mut new_images = Vec::new();
        let max_image_size = config
            .max_image_size
            .clone()
            .unwrap_or_else(|| HashMap::from([("longest_edge".to_string(), 364)]));
        if config.do_image_splitting.unwrap_or(true) {
            // We first resize both height and width of each image to the nearest max_image_size multiple, disregarding the aspect ratio
            // for size=(10, max_image_size) -> rescaled_size=(max_image_size, max_image_size)
            // for size=(11, max_image_size+1) -> rescaled_size=(max_image_size, max_image_size*2)
            for image in images.iter_mut() {
                let (new_h, new_w) = resize_for_vision_encoder(
                    (image.dimensions().1 as usize, image.dimensions().0 as usize),
                    max_image_size["longest_edge"] as usize,
                );

                *image =
                    image.resize_exact(new_w as u32, new_h as u32, config.resampling.to_filter()?);

                let (split_image_array, rows, cols) =
                    split_image(image, max_image_size["longest_edge"] as usize)?;
                new_images.extend(split_image_array.into_iter());
                image_rows.push(rows);
                image_cols.push(cols);
            }
        } else {
            // We square the images to max_image_size
            for image in images.iter_mut() {
                new_images.push(resize(
                    image,
                    &HashMap::from([
                        ("height".to_string(), max_image_size["longest_edge"]),
                        ("width".to_string(), max_image_size["longest_edge"]),
                    ]),
                    FilterType::Lanczos3,
                )?);
            }
            image_rows = vec![0; images.len()];
            image_cols = vec![0; images.len()];
        }
        images = new_images;

        let mut max_h = 0;
        let mut max_w = 0;
        for image in &images {
            let (w, h) = image.dimensions();
            if w > max_w {
                max_w = w;
            }
            if h > max_h {
                max_h = h;
            }
        }

        for image in images.iter_mut() {
            let transforms = Transforms {
                input: &ToTensorNoNorm,
                inner_transforms: &[
                    &config
                        .do_rescale
                        .is_some_and(|x| x)
                        .then_some(())
                        .map(|_| Rescale {
                            factor: config.rescale_factor,
                        }),
                    &config
                        .do_normalize
                        .is_some_and(|x| x)
                        .then_some(())
                        .map(|_| Normalize {
                            mean: config.image_mean.unwrap_or(Self::DEFAULT_MEAN).to_vec(),
                            std: config.image_std.unwrap_or(Self::DEFAULT_STD).to_vec(),
                        }),
                ],
            };

            let mut image = image.apply(transforms, device)?;
            // Pad images, calculating attention mask.
            if config.do_pad.is_some_and(|x| x) {
                let (_c, h, w) = image.dims3()?;
                let padded = mistralrs_vision::pad(&image, max_h as usize, max_w as usize)?;
                let mask = mistralrs_vision::make_pixel_mask(&padded, h, w)?;
                patch_masks.push(mask.unsqueeze(0)?);
                image = padded;
            }

            // Get pixel values
            pixel_values.push(image.unsqueeze(0)?);
        }

        Ok(PreprocessedImages {
            pixel_values: Tensor::cat(&pixel_values, 0)?,
            pixel_attention_mask: Some(Tensor::cat(&patch_masks, 0)?),
            image_sizes: None,
            num_img_tokens: None,
            aspect_ratio_ids: None,
            aspect_ratio_mask: None,
            num_tiles: None,
            image_grid_thw: None,
            video_grid_thw: None,
            rows: Some(image_rows),
            cols: Some(image_cols),
            pixel_values_list: None,
            tgt_sizes: None,
            image_sizes_all: None,
            num_crops: None,
        })
    }
}