mistralrs_core/vision_models/idefics3/inputs_processor.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::{any::Any, cmp, collections::HashMap, num::NonZeroUsize, sync::Arc};

use candle_core::{Device, Result, Tensor};
use image::{imageops::FilterType, DynamicImage, GenericImageView};
use mistralrs_vision::{ApplyTransforms, Normalize, Rescale, ToTensorNoNorm, Transforms};
use tokenizers::Tokenizer;
use tracing::warn;

use crate::{
    device_map::DeviceMapper,
    pipeline::{
        text_models_inputs_processor::{
            self, get_completion_input, get_prompt_input, PagedAttentionMeta,
        },
        InputProcessorOutput, InputsProcessor, InputsProcessorType, MessagesAction, Processor,
    },
    sequence::Sequence,
    vision_models::ModelInputs,
};

use crate::vision_models::{
    image_processor::{ImagePreProcessor, PreprocessedImages},
    preprocessor_config::{PreProcessorConfig, ToFilter},
    processor_config::ProcessorConfig,
};

// 4k resolution as absolute maximum
const MAX_IMAGE_SIZE: usize = 4096;
const FAKE_IMAGE_TOKEN: &str = "<fake_token_around_image>";
const IMAGE_TOKEN: &str = "<image>";
const GLOBAL_IMAGE_TOKEN: &str = "<global-img>";

pub struct Idefics3ImageProcessor {
    max_edge: Option<u32>,
    image_seq_len: usize,
}

pub struct Idefics3Processor {
    config: ProcessorConfig,
    max_edge: Option<u32>,
}

impl Idefics3Processor {
    pub fn new(
        config: ProcessorConfig,
        _preprocessor_config: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Self {
        Self { config, max_edge }
    }
}

impl Processor for Idefics3Processor {
    fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
        // Default image_seq_len is 169.
        Arc::new(Idefics3ImageProcessor {
            max_edge: self.max_edge,
            image_seq_len: self.config.image_seq_len.unwrap_or(169),
        })
    }

    fn get_special_tokens(&self) -> &[&'static str] {
        &["<fake_token_around_image>", "<image>", "<end_of_utterance>"]
    }

    fn template_action(&self) -> MessagesAction {
        MessagesAction::Keep
    }
}

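/// Build the expanded text that replaces a single `<image>` placeholder in the prompt.
///
/// For a split image, each of the `n_rows` x `n_cols` tiles is marked with
/// `<row_i_col_j>` followed by `image_seq_len` `<image>` tokens, and the global image
/// block is appended last. For an unsplit image (`n_rows == 0 && n_cols == 0`) only the
/// global block is produced, e.g. with `image_seq_len = 2`:
/// `<fake_token_around_image><global-img><image><image><fake_token_around_image>`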
fn get_image_prompt_string(n_rows: usize, n_cols: usize, image_seq_len: usize) -> String {
    if n_rows == 0 && n_cols == 0 {
        format!(
            "{FAKE_IMAGE_TOKEN}{GLOBAL_IMAGE_TOKEN}{}{FAKE_IMAGE_TOKEN}",
            IMAGE_TOKEN.repeat(image_seq_len)
        )
    } else {
        let mut text_split_images = String::new();
        for n_h in 0..n_rows {
            for n_w in 0..n_cols {
                text_split_images.push_str(&format!(
                    "{FAKE_IMAGE_TOKEN}<row_{}_col_{}>{}",
                    n_h + 1,
                    n_w + 1,
                    IMAGE_TOKEN.repeat(image_seq_len)
                ));
            }
            text_split_images.push('\n');
        }
        format!(
            "{text_split_images}\n{FAKE_IMAGE_TOKEN}{GLOBAL_IMAGE_TOKEN}{}{FAKE_IMAGE_TOKEN}",
            IMAGE_TOKEN.repeat(image_seq_len)
        )
    }
}

impl InputsProcessor for Idefics3ImageProcessor {
    fn get_type(&self) -> InputsProcessorType {
        InputsProcessorType::Vision
    }
    fn process_inputs(
        &self,
        tokenizer: Option<Arc<Tokenizer>>,
        input_seqs: &mut [&mut Sequence],
        is_prompt: bool,
        is_xlora: bool,
        device: &Device,
        no_kv_cache: bool,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        other_config: Option<Arc<dyn Any>>,
        mut paged_attn_metadata: Option<PagedAttentionMeta<'_>>,
        prompt_chunksize: Option<NonZeroUsize>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Box<dyn Iterator<Item = anyhow::Result<InputProcessorOutput>>> {
        if is_xlora {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Cannot make inputs for X-LoRA vision model.",
            ))));
        }
        if no_kv_cache {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Vision model must have kv cache.",
            ))));
        }
        // TODO(EricLBuehler): support this? Would require some handling of image tokens.
        if prompt_chunksize.is_some() {
            warn!("`prompt_chunksize` is set. Idefics 3 does not support prompt chunking.");
        }
        let Some(tokenizer) = tokenizer else {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Idefics3ImageProcessor requires a specified tokenizer.",
            ))));
        };

        let config = other_config.expect("Need a PreProcessorConfig config.");
        let config: &PreProcessorConfig = config.downcast_ref().expect("Downcast failed.");

        let has_images = input_seqs.iter().all(|seq| seq.has_images());

        let (pixel_values, pixel_attention_mask) = if has_images {
            let mut pixel_values_accum = Vec::new();
            let mut pixel_attention_mask_accum = Vec::new();
            for seq in input_seqs.iter_mut() {
                let PreprocessedImages {
                    pixel_values,
                    pixel_attention_mask,
                    image_sizes: _,
                    num_img_tokens: _,
                    aspect_ratio_ids: _,
                    aspect_ratio_mask: _,
                    num_tiles: _,
                    image_grid_thw: _,
                    video_grid_thw: _,
                    rows,
                    cols,
                    pixel_values_list: _,
                    tgt_sizes: _,
                    image_sizes_all: _,
                    num_crops: _,
                } = self
                    .preprocess(
                        seq.take_images()
                            .expect("Need to have images by this point."),
                        vec![],
                        config,
                        device,
                        (usize::MAX, usize::MAX), // Don't use it here...
                    )
                    .expect("Preprocessing failed");
                pixel_values_accum.push(pixel_values.unsqueeze(0).unwrap());
                pixel_attention_mask_accum
                    .push(pixel_attention_mask.unwrap().unsqueeze(0).unwrap());

                let detok = tokenizer
                    .decode(seq.get_toks(), false)
                    .expect("Detokenization failed!");

                let mut image_prompt_strings = Vec::new();
                for (n_rows, n_cols) in rows.unwrap().into_iter().zip(cols.unwrap().into_iter()) {
                    let image_prompt_string =
                        get_image_prompt_string(n_rows, n_cols, self.image_seq_len);
                    image_prompt_strings.push(image_prompt_string);
                }

                let split_sample = detok.split(IMAGE_TOKEN).collect::<Vec<_>>();
                let mut sample = split_sample
                    .first()
                    .expect("The image token <image> should be present in the text.")
                    .to_string();
                for (i, image_prompt_string) in image_prompt_strings.into_iter().enumerate() {
                    // Interleave: expanded image prompt i, then the text that followed image token i.
                    sample.push_str(&format!("{image_prompt_string}{}", split_sample[i + 1]));
                }

                if !seq.has_changed_prompt {
                    seq.set_initial_prompt(sample.clone());
                    let toks = tokenizer
                        .encode_fast(sample, false)
                        .expect("Tokenization failed!");

                    let ids = toks.get_ids().to_vec();
                    seq.set_toks_and_reallocate(ids, paged_attn_metadata.as_mut());
                    seq.has_changed_prompt = true;
                }
            }

            (
                Some(Tensor::cat(&pixel_values_accum, 0).unwrap()),
                Some(Tensor::cat(&pixel_attention_mask_accum, 0).unwrap()),
            )
        } else {
            (None, None)
        };

        let text_models_inputs_processor::InnerInputProcessorOutput {
            inputs:
                text_models_inputs_processor::InputMetadata {
                    input,
                    positions,
                    context_lens,
                    position_ids,
                    paged_attn_meta,
                    flash_meta,
                },
            seq_indices,
        } = if is_prompt {
            get_prompt_input(
                input_seqs
                    .iter()
                    .map(|seq| seq.get_toks().to_vec())
                    .collect::<Vec<_>>(),
                input_seqs,
                device,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                None, // TODO: evaluate if it is possible to batch this
                mapper,
            )
            .nth(0)
            .unwrap()
            .unwrap()
        } else {
            get_completion_input(
                input_seqs
                    .iter()
                    .map(|seq| seq.get_toks().to_vec())
                    .collect::<Vec<_>>(),
                input_seqs,
                device,
                no_kv_cache,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                None, // TODO: evaluate if it is possible to batch this
                mapper,
            )
            .nth(0)
            .unwrap()
            .unwrap()
        };

        let inputs: Box<dyn Any> = Box::new(ModelInputs {
            input_ids: input,
            seqlen_offsets: positions,
            context_lens,
            position_ids,
            pixel_values,
            model_specific_args: Box::new(pixel_attention_mask),
            paged_attn_meta,
            flash_meta,
        });
        Box::new(std::iter::once(Ok(InputProcessorOutput {
            inputs,
            seq_indices,
        })))
    }
}

// Calculate output size after resizing, rescaling to max length
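// e.g. (height=480, width=640) with max_len=364: the longer side becomes 364 and the
// shorter side becomes round(364 / (640/480)) = 273, bumped to 274 to stay even.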
fn resize_output_size_rescale_to_max_len(
    height: usize,
    width: usize,
    min_len: Option<usize>,
    max_len: Option<usize>,
) -> (usize, usize) {
    let min_len = min_len.unwrap_or(1);
    let max_len = max_len.unwrap_or_else(|| cmp::max(height, width));
    let aspect_ratio = width as f32 / height as f32;
    let (mut height, mut width) = (height, width);

    if width >= height {
        width = max_len;
        height = (width as f32 / aspect_ratio).round() as usize;
        if height % 2 != 0 {
            height += 1;
        }
    } else {
        height = max_len;
        width = (height as f32 * aspect_ratio).round() as usize;
        if width % 2 != 0 {
            width += 1;
        }
    }

    height = cmp::max(height, min_len);
    width = cmp::max(width, min_len);

    (height, width)
}

// Calculate output size after resizing, scaling below upper bound
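// e.g. (height=2000, width=8000) with max_len=4096 becomes (1024, 4096); sizes already
// within the bound are returned unchanged.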
fn resize_output_size_scale_below_upper_bound(
    height: usize,
    width: usize,
    max_len: Option<usize>,
) -> (usize, usize) {
    let max_len = max_len.unwrap_or_else(|| cmp::max(height, width));
    let aspect_ratio = width as f32 / height as f32;
    let (mut height, mut width) = (height, width);

    if width >= height && width > max_len {
        width = max_len;
        height = (width as f32 / aspect_ratio).round() as usize;
    } else if height > width && height > max_len {
        height = max_len;
        width = (height as f32 * aspect_ratio).round() as usize;
    }

    height = cmp::max(height, 1);
    width = cmp::max(width, 1);

    (height, width)
}

/// Given the image size (h, w) and the maximum allowed side length, calculate the output
/// dimensions, preserving the aspect ratio while respecting the length constraints.
fn get_resize_output_image_size(
    (h, w): (usize, usize),
    resolution_max_side: usize,
) -> (usize, usize) {
    let (h, w) = resize_output_size_rescale_to_max_len(h, w, None, Some(resolution_max_side));
    resize_output_size_scale_below_upper_bound(h, w, Some(MAX_IMAGE_SIZE))
}

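// Make both sides multiples of `vision_encoder_max_size`: the longer side is rounded up
// to the next multiple, the shorter side is derived from the aspect ratio and then also
// rounded up. e.g. (h=300, w=500) with 364: the width rounds up to 728, the height
// 728 / (500/300) = 436 then rounds up to 728, giving (728, 728).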
fn resize_for_vision_encoder(
    (h, w): (usize, usize),
    vision_encoder_max_size: usize,
) -> (usize, usize) {
    let aspect_ratio = w as f32 / h as f32;

    let (new_h, new_w) = if w >= h {
        let new_w = ((w as f32 / vision_encoder_max_size as f32).ceil()
            * vision_encoder_max_size as f32) as usize;
        let mut new_h = (new_w as f32 / aspect_ratio) as usize;
        new_h = ((new_h as f32 / vision_encoder_max_size as f32).ceil()
            * vision_encoder_max_size as f32) as usize;
        (new_h, new_w)
    } else {
        let new_h = ((h as f32 / vision_encoder_max_size as f32).ceil()
            * vision_encoder_max_size as f32) as usize;
        let mut new_w = (new_h as f32 * aspect_ratio) as usize;
        new_w = ((new_w as f32 / vision_encoder_max_size as f32).ceil()
            * vision_encoder_max_size as f32) as usize;
        (new_h, new_w)
    };

    (new_h, new_w)
}

fn resize(
    image: &DynamicImage,
    size: &HashMap<String, u32>,
    resampling: FilterType,
) -> Result<DynamicImage> {
    let (h, w) = if size.contains_key("longest_edge") {
        get_resize_output_image_size(
            (image.dimensions().1 as usize, image.dimensions().0 as usize),
            size["longest_edge"] as usize,
        )
    } else if size.contains_key("height") && size.contains_key("width") {
        (size["height"] as usize, size["width"] as usize)
    } else {
        candle_core::bail!(
            "Size must be a map with `longest_edge` or with `height` and `width`."
        );
    };

    Ok(image.resize_exact(w as u32, h as u32, resampling))
    // Ok(image.resize_exact(w as u32, h as u32,  FilterType::Nearest))
}

/// Returns: frames, num_splits_h, num_splits_w
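///
/// e.g. a 728x728 image with `longest_edge = 364` yields four 364x364 tiles plus the
/// globally resized image (5 frames) and split counts (2, 2); an image already within
/// the bound is returned as a single frame with counts (0, 0).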
fn split_image(
    image: &DynamicImage,
    longest_edge: usize,
) -> Result<(Vec<DynamicImage>, usize, usize)> {
    let (width, height) = image.dimensions();
    let mut frames = Vec::new();

    if width > longest_edge as u32 || height > longest_edge as u32 {
        let num_splits_h = (height as f64 / (longest_edge as f64)).ceil() as usize;
        let num_splits_w = (width as f64 / (longest_edge as f64)).ceil() as usize;

        let optimal_height = (height as f64 / num_splits_h as f64).ceil() as u32;
        let optimal_width = (width as f64 / num_splits_w as f64).ceil() as u32;

        for r in 0..num_splits_h {
            for c in 0..num_splits_w {
                let start_x = (c as u32) * optimal_width;
                let start_y = (r as u32) * optimal_height;

                let end_x = std::cmp::min(start_x + optimal_width, width);
                let end_y = std::cmp::min(start_y + optimal_height, height);

                // Crop the image
                let cropped_image =
                    image.crop_imm(start_x, start_y, end_x - start_x, end_y - start_y);
                frames.push(cropped_image);
            }
        }

        // Resize the original image to match `longest_edge` for global efficiency
        let resized_image = resize(
            image,
            &HashMap::from([
                ("height".to_string(), longest_edge as u32),
                ("width".to_string(), longest_edge as u32),
            ]),
            FilterType::Lanczos3,
        )?;
        frames.push(resized_image);

        Ok((frames, num_splits_h, num_splits_w))
    } else {
        frames.push(image.clone());
        Ok((frames, 0, 0))
    }
}

impl ImagePreProcessor for Idefics3ImageProcessor {
    #[allow(clippy::excessive_precision)]
    const DEFAULT_MEAN: [f64; 3] = [0.48145466, 0.4578275, 0.40821073];
    #[allow(clippy::excessive_precision)]
    const DEFAULT_STD: [f64; 3] = [0.26862954, 0.26130258, 0.27577711];

    fn preprocess(
        &self,
        mut images: Vec<DynamicImage>,
        videos: Vec<Vec<DynamicImage>>,
        config: &PreProcessorConfig,
        device: &Device,
        (_bs, _max_num_images): (usize, usize),
    ) -> Result<PreprocessedImages> {
        assert!(videos.is_empty());

        let mut patch_masks = Vec::new();
        let mut pixel_values = Vec::new();

        if let Some(max_edge) = self.max_edge {
            images = mistralrs_vision::pad_to_max_edge(&images, max_edge);
        }

        for image in images.iter_mut() {
            // Convert to rgb
            if config.do_convert_rgb.is_some_and(|x| x) {
                *image = DynamicImage::ImageRgb8(image.to_rgb8());
            }

            // Resize
            if config.do_resize.is_some_and(|x| x) {
                *image = resize(
                    image,
                    config.size.as_ref().unwrap(),
                    config.resampling.to_filter()?,
                )?;
            }
        }

        let mut image_rows = Vec::new();
        let mut image_cols = Vec::new();
        let mut new_images = Vec::new();
        let max_image_size = config
            .max_image_size
            .clone()
            .unwrap_or_else(|| HashMap::from([("longest_edge".to_string(), 364)]));
        if config.do_image_splitting.unwrap_or(true) {
            // We first round the height and width of each image up to a multiple of max_image_size, disregarding the aspect ratio
            // for size=(10, max_image_size) -> rescaled_size=(max_image_size, max_image_size)
            // for size=(11, max_image_size+1) -> rescaled_size=(max_image_size, max_image_size*2)
            for image in images.iter_mut() {
                let (new_h, new_w) = resize_for_vision_encoder(
                    (image.dimensions().1 as usize, image.dimensions().0 as usize),
                    max_image_size["longest_edge"] as usize,
                );

                *image =
                    image.resize_exact(new_w as u32, new_h as u32, config.resampling.to_filter()?);

                let (split_image_array, rows, cols) =
                    split_image(image, max_image_size["longest_edge"] as usize)?;
                new_images.extend(split_image_array.into_iter());
                image_rows.push(rows);
                image_cols.push(cols);
            }
        } else {
            // Without splitting, resize each image to a max_image_size x max_image_size square
            for image in images.iter_mut() {
                new_images.push(resize(
                    image,
                    &HashMap::from([
                        ("height".to_string(), max_image_size["longest_edge"]),
                        ("width".to_string(), max_image_size["longest_edge"]),
                    ]),
                    FilterType::Lanczos3,
                )?);
            }
            image_rows = vec![0; images.len()];
            image_cols = vec![0; images.len()];
        }
        images = new_images;

        let mut max_h = 0;
        let mut max_w = 0;
        for image in &images {
            let (w, h) = image.dimensions();
            if w > max_w {
                max_w = w;
            }
            if h > max_h {
                max_h = h;
            }
        }

        for image in images.iter_mut() {
            let transforms = Transforms {
                input: &ToTensorNoNorm,
                inner_transforms: &[
                    &config
                        .do_rescale
                        .is_some_and(|x| x)
                        .then_some(())
                        .map(|_| Rescale {
                            factor: config.rescale_factor,
                        }),
                    &config
                        .do_normalize
                        .is_some_and(|x| x)
                        .then_some(())
                        .map(|_| Normalize {
                            mean: config.image_mean.unwrap_or(Self::DEFAULT_MEAN).to_vec(),
                            std: config.image_std.unwrap_or(Self::DEFAULT_STD).to_vec(),
                        }),
                ],
            };

            let mut image = image.apply(transforms, device)?;
            // Pad images, calculating attention mask.
            if config.do_pad.is_some_and(|x| x) {
                let (_c, h, w) = image.dims3()?;
                let padded = mistralrs_vision::pad(&image, max_h as usize, max_w as usize)?;
                let mask = mistralrs_vision::make_pixel_mask(&padded, h, w)?;
                patch_masks.push(mask.unsqueeze(0)?);
                image = padded;
            }

            // Get pixel values
            pixel_values.push(image.unsqueeze(0)?)
        }

        Ok(PreprocessedImages {
            pixel_values: Tensor::cat(&pixel_values, 0)?,
            pixel_attention_mask: Some(Tensor::cat(&patch_masks, 0)?),
            image_sizes: None,
            num_img_tokens: None,
            aspect_ratio_ids: None,
            aspect_ratio_mask: None,
            num_tiles: None,
            image_grid_thw: None,
            video_grid_thw: None,
            rows: Some(image_rows),
            cols: Some(image_cols),
            pixel_values_list: None,
            tgt_sizes: None,
            image_sizes_all: None,
            num_crops: None,
        })
    }
}
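
// Minimal sanity-check sketch for the pure helpers above; illustrative tests based on the
// behaviour documented in the comments, not an exhaustive suite.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn image_prompt_string_unsplit() {
        // An unsplit image only gets the global block with `image_seq_len` placeholders.
        let s = get_image_prompt_string(0, 0, 3);
        assert_eq!(s.matches(IMAGE_TOKEN).count(), 3);
        assert!(s.contains(GLOBAL_IMAGE_TOKEN));
    }

    #[test]
    fn image_prompt_string_split() {
        // A 2x2 split gets one block per tile plus the trailing global block.
        let s = get_image_prompt_string(2, 2, 1);
        assert_eq!(s.matches(IMAGE_TOKEN).count(), 2 * 2 + 1);
        assert!(s.contains("<row_2_col_2>"));
    }

    #[test]
    fn resize_for_vision_encoder_rounds_up_to_multiples() {
        // Both output sides are rounded up to multiples of the vision encoder size.
        let (h, w) = resize_for_vision_encoder((300, 500), 364);
        assert_eq!((h, w), (728, 728));
        assert_eq!(h % 364, 0);
        assert_eq!(w % 364, 0);
    }
}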