mistralrs_core/vision_models/idefics3/inputs_processor.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::{any::Any, cmp, collections::HashMap, num::NonZeroUsize, sync::Arc};

use candle_core::{Device, Result, Tensor};
use image::{imageops::FilterType, DynamicImage, GenericImageView};
use mistralrs_vision::{ApplyTransforms, Normalize, Rescale, ToTensorNoNorm, Transforms};
use tokenizers::Tokenizer;
use tracing::warn;

use crate::{
    device_map::DeviceMapper,
    pipeline::{
        text_models_inputs_processor::{
            self, get_completion_input, get_prompt_input, PagedAttentionMeta,
        },
        InputProcessorOutput, InputsProcessor, InputsProcessorType, MessagesAction, Processor,
    },
    sequence::Sequence,
    vision_models::ModelInputs,
};

use crate::vision_models::{
    image_processor::{ImagePreProcessor, PreprocessedImages},
    preprocessor_config::{PreProcessorConfig, ToFilter},
    processor_config::ProcessorConfig,
};

// 4k resolution as absolute maximum
const MAX_IMAGE_SIZE: usize = 4096;
const FAKE_IMAGE_TOKEN: &str = "<fake_token_around_image>";
const IMAGE_TOKEN: &str = "<image>";
const GLOBAL_IMAGE_TOKEN: &str = "<global-img>";

pub struct Idefics3ImageProcessor {
    max_edge: Option<u32>,
    image_seq_len: usize,
}

pub struct Idefics3Processor {
    config: ProcessorConfig,
    max_edge: Option<u32>,
}

impl Idefics3Processor {
    pub fn new(
        config: ProcessorConfig,
        _preprocessor_config: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Self {
        Self { config, max_edge }
    }
}

impl Processor for Idefics3Processor {
    fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
        // Default image_seq_len is 169.
        Arc::new(Idefics3ImageProcessor {
            max_edge: self.max_edge,
            image_seq_len: self.config.image_seq_len.unwrap_or(169),
        })
    }

    fn get_special_tokens(&self) -> &[&'static str] {
        &[FAKE_IMAGE_TOKEN, IMAGE_TOKEN, "<end_of_utterance>"]
    }

    fn template_action(&self) -> MessagesAction {
        MessagesAction::Keep
    }
}

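// Builds the expanded prompt text for one image. A split image yields one
// `{FAKE_IMAGE_TOKEN}<row_R_col_C>{IMAGE_TOKEN}...` group per tile (rows separated by
// newlines) followed by a global-image group; an unsplit image (n_rows == n_cols == 0)
// yields only the global group. For example, a 1x2 grid with image_seq_len = 1 produces:
//   <fake_token_around_image><row_1_col_1><image><fake_token_around_image><row_1_col_2><image>\n
//   \n<fake_token_around_image><global-img><image><fake_token_around_image>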
fn get_image_prompt_string(n_rows: usize, n_cols: usize, image_seq_len: usize) -> String {
    if n_rows == 0 && n_cols == 0 {
        format!(
            "{FAKE_IMAGE_TOKEN}{GLOBAL_IMAGE_TOKEN}{}{FAKE_IMAGE_TOKEN}",
            IMAGE_TOKEN.repeat(image_seq_len)
        )
    } else {
        let mut text_split_images = String::new();
        for n_h in 0..n_rows {
            for n_w in 0..n_cols {
                text_split_images.push_str(&format!(
                    "{FAKE_IMAGE_TOKEN}<row_{}_col_{}>{}",
                    n_h + 1,
                    n_w + 1,
                    IMAGE_TOKEN.repeat(image_seq_len)
                ));
            }
            text_split_images.push('\n');
        }
        format!(
            "{text_split_images}\n{FAKE_IMAGE_TOKEN}{GLOBAL_IMAGE_TOKEN}{}{FAKE_IMAGE_TOKEN}",
            IMAGE_TOKEN.repeat(image_seq_len)
        )
    }
}

impl InputsProcessor for Idefics3ImageProcessor {
    fn get_type(&self) -> InputsProcessorType {
        InputsProcessorType::Vision
    }

    fn process_inputs(
        &self,
        tokenizer: Option<Arc<Tokenizer>>,
        input_seqs: &mut [&mut Sequence],
        is_prompt: bool,
        is_xlora: bool,
        device: &Device,
        no_kv_cache: bool,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        other_config: Option<Arc<dyn Any>>,
        mut paged_attn_metadata: Option<PagedAttentionMeta>,
        prompt_chunksize: Option<NonZeroUsize>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Box<dyn Iterator<Item = anyhow::Result<InputProcessorOutput>>> {
        if is_xlora {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Cannot make inputs for X-LoRA vision model.",
            ))));
        }
        if no_kv_cache {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Vision model must have kv cache.",
            ))));
        }
        // TODO(EricLBuehler): support this? Would require some handling of image tokens.
        if prompt_chunksize.is_some() {
            warn!("`prompt_chunksize` is set. Idefics 3 does not support prompt chunking.");
        }
        let Some(tokenizer) = tokenizer else {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Idefics3ImageProcessor requires a specified tokenizer.",
            ))));
        };

        let config = other_config.expect("Need a PreProcessorConfig.");
        let config: &PreProcessorConfig = config.downcast_ref().expect("Downcast failed.");

        let has_images = input_seqs.iter().all(|seq| seq.has_images());

        let (pixel_values, pixel_attention_mask) = if has_images {
            let mut pixel_values_accum = Vec::new();
            let mut pixel_attention_mask_accum = Vec::new();
            for seq in input_seqs.iter_mut() {
                let PreprocessedImages {
                    pixel_values,
                    pixel_attention_mask,
                    image_sizes: _,
                    num_img_tokens: _,
                    aspect_ratio_ids: _,
                    aspect_ratio_mask: _,
                    num_tiles: _,
                    image_grid_thw: _,
                    video_grid_thw: _,
                    rows,
                    cols,
                    pixel_values_list: _,
                    tgt_sizes: _,
                    image_sizes_all: _,
                    num_crops: _,
                } = self
                    .preprocess(
                        seq.take_images()
                            .expect("Need to have images by this point."),
                        vec![],
                        config,
                        device,
                        (usize::MAX, usize::MAX), // Not used by this processor.
                    )
                    .expect("Preprocessing failed");
                pixel_values_accum.push(pixel_values.unsqueeze(0).unwrap());
                pixel_attention_mask_accum
                    .push(pixel_attention_mask.unwrap().unsqueeze(0).unwrap());

                if !seq.multimodal.has_changed_prompt {
                    let detok = tokenizer
                        .decode(seq.get_toks(), false)
                        .expect("Detokenization failed!");

                    let mut image_prompt_strings = Vec::new();
                    for (n_rows, n_cols) in rows.unwrap().into_iter().zip(cols.unwrap()) {
                        let image_prompt_string =
                            get_image_prompt_string(n_rows, n_cols, self.image_seq_len);
                        image_prompt_strings.push(image_prompt_string);
                    }

                    let split_sample = detok.split(IMAGE_TOKEN).collect::<Vec<_>>();
                    let mut sample = split_sample
                        .first()
                        .expect("The image token <image> should be present in the text.")
                        .to_string();
                    for (i, image_prompt_string) in image_prompt_strings.into_iter().enumerate() {
                        sample.push_str(&format!(
                            "{image_prompt_string}{}",
                            split_sample
                                .get(i + 1)
                                .expect("Incorrect chat template. Use the one provided in `chat_templates` with the `--chat-template`/`chat_template` settings.")
                        ));
                    }

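                    // `sample` now holds the original prompt with every `<image>`
                    // placeholder expanded into its full tiled image-token sequence.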
                    seq.set_initial_prompt(sample.clone());
                    let toks = tokenizer
                        .encode_fast(sample, false)
                        .expect("Tokenization failed!");

                    let ids = toks.get_ids().to_vec();
                    seq.set_toks_and_reallocate(ids, paged_attn_metadata.as_mut());
                    seq.multimodal.has_changed_prompt = true;
                }
            }

            (
                Some(Tensor::cat(&pixel_values_accum, 0).unwrap()),
                Some(Tensor::cat(&pixel_attention_mask_accum, 0).unwrap()),
            )
        } else {
            (None, None)
        };

        let text_models_inputs_processor::InnerInputProcessorOutput {
            inputs:
                text_models_inputs_processor::InputMetadata {
                    input,
                    positions,
                    context_lens,
                    position_ids,
                    paged_attn_meta,
                    flash_meta,
                },
            seq_indices,
        } = if is_prompt {
            get_prompt_input(
                input_seqs
                    .iter()
                    .map(|seq| seq.get_toks())
                    .collect::<Vec<_>>(),
                input_seqs,
                device,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                None, // TODO: evaluate if it is possible to batch this
                mapper,
            )
            .next()
            .unwrap()
            .unwrap()
        } else {
            get_completion_input(
                input_seqs
                    .iter()
                    .map(|seq| seq.get_toks())
                    .collect::<Vec<_>>(),
                input_seqs,
                device,
                no_kv_cache,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                None, // TODO: evaluate if it is possible to batch this
                mapper,
            )
            .next()
            .unwrap()
            .unwrap()
        };

        let inputs: Box<dyn Any> = Box::new(ModelInputs {
            input_ids: input,
            seqlen_offsets: positions,
            context_lens,
            position_ids,
            pixel_values,
            model_specific_args: Box::new(pixel_attention_mask),
            paged_attn_meta,
            flash_meta,
        });
        Box::new(std::iter::once(Ok(InputProcessorOutput {
            inputs,
            seq_indices,
        })))
    }
}

// Calculate output size after resizing, rescaling to max length
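// For example, with max_len = 364 a (height, width) of (300, 600) keeps its 2:1
// aspect ratio and becomes (182, 364); an odd computed side is bumped up to the
// next even number.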
fn resize_output_size_rescale_to_max_len(
    height: usize,
    width: usize,
    min_len: Option<usize>,
    max_len: Option<usize>,
) -> (usize, usize) {
    let min_len = min_len.unwrap_or(1);
    let max_len = max_len.unwrap_or_else(|| cmp::max(height, width));
    let aspect_ratio = width as f32 / height as f32;
    let (mut height, mut width) = (height, width);

    if width >= height {
        width = max_len;
        height = (width as f32 / aspect_ratio).round() as usize;
        if height % 2 != 0 {
            height += 1;
        }
    } else {
        height = max_len;
        width = (height as f32 * aspect_ratio).round() as usize;
        if width % 2 != 0 {
            width += 1;
        }
    }

    height = cmp::max(height, min_len);
    width = cmp::max(width, min_len);

    (height, width)
}

// Calculate output size after resizing, scaling below upper bound
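// For example, with max_len = 4096 a 4000x6000 (h x w) input becomes (2731, 4096),
// while an input already within the bound is returned unchanged.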
fn resize_output_size_scale_below_upper_bound(
    height: usize,
    width: usize,
    max_len: Option<usize>,
) -> (usize, usize) {
    let max_len = max_len.unwrap_or_else(|| cmp::max(height, width));
    let aspect_ratio = width as f32 / height as f32;
    let (mut height, mut width) = (height, width);

    if width >= height && width > max_len {
        width = max_len;
        height = (width as f32 / aspect_ratio).round() as usize;
    } else if height > width && height > max_len {
        height = max_len;
        width = (height as f32 * aspect_ratio).round() as usize;
    }

    height = cmp::max(height, 1);
    width = cmp::max(width, 1);

    (height, width)
}

/// Given an image size (h, w) and a maximum side length, calculate the output
/// dimensions that preserve the aspect ratio while respecting both that maximum
/// and the absolute `MAX_IMAGE_SIZE` cap.
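///
/// For example, a 5000x10000 (h, w) image with `resolution_max_side = 4096` is
/// rescaled to (2048, 4096), which already fits under the 4k cap and is returned
/// as-is.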
fn get_resize_output_image_size(
    (h, w): (usize, usize),
    resolution_max_side: usize,
) -> (usize, usize) {
    let (h, w) = resize_output_size_rescale_to_max_len(h, w, None, Some(resolution_max_side));
    resize_output_size_scale_below_upper_bound(h, w, Some(MAX_IMAGE_SIZE))
}

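// Rounds both sides up to the nearest multiple of `vision_encoder_max_size` while
// roughly preserving the aspect ratio. For example, with a max size of 364 a
// 500x700 (h x w) image becomes 728x728: 700 rounds up to 728, 728 / 1.4 gives
// 520, and 520 rounds up to 728.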
fn resize_for_vision_encoder(
    (h, w): (usize, usize),
    vision_encoder_max_size: usize,
) -> (usize, usize) {
    let aspect_ratio = w as f32 / h as f32;

    let (new_h, new_w) = if w >= h {
        let new_w = ((w as f32 / vision_encoder_max_size as f32).ceil()
            * vision_encoder_max_size as f32) as usize;
        let mut new_h = (new_w as f32 / aspect_ratio) as usize;
        new_h = ((new_h as f32 / vision_encoder_max_size as f32).ceil()
            * vision_encoder_max_size as f32) as usize;
        (new_h, new_w)
    } else {
        let new_h = ((h as f32 / vision_encoder_max_size as f32).ceil()
            * vision_encoder_max_size as f32) as usize;
        let mut new_w = (new_h as f32 * aspect_ratio) as usize;
        new_w = ((new_w as f32 / vision_encoder_max_size as f32).ceil()
            * vision_encoder_max_size as f32) as usize;
        (new_h, new_w)
    };

    (new_h, new_w)
}

fn resize(
    image: &DynamicImage,
    size: &HashMap<String, u32>,
    resampling: FilterType,
) -> Result<DynamicImage> {
    let (h, w) = if size.contains_key("longest_edge") {
        get_resize_output_image_size(
            (image.dimensions().1 as usize, image.dimensions().0 as usize),
            size["longest_edge"] as usize,
        )
    } else if size.contains_key("height") && size.contains_key("width") {
        (size["height"] as usize, size["width"] as usize)
    } else {
        candle_core::bail!("Size must contain `longest_edge`, or both `height` and `width`.");
    };

    Ok(image.resize_exact(w as u32, h as u32, resampling))
}

/// Returns: frames, num_splits_h, num_splits_w
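///
/// For example, an 800x500 (w x h) image with `longest_edge = 364` is split into
/// ceil(500/364) = 2 rows and ceil(800/364) = 3 columns: 6 crops plus the resized
/// global image, i.e. 7 frames and a return value of (frames, 2, 3).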
fn split_image(
    image: &DynamicImage,
    longest_edge: usize,
) -> Result<(Vec<DynamicImage>, usize, usize)> {
    let (width, height) = image.dimensions();
    let mut frames = Vec::new();

    if width > longest_edge as u32 || height > longest_edge as u32 {
        let num_splits_h = (height as f64 / longest_edge as f64).ceil() as usize;
        let num_splits_w = (width as f64 / longest_edge as f64).ceil() as usize;

        let optimal_height = (height as f64 / num_splits_h as f64).ceil() as u32;
        let optimal_width = (width as f64 / num_splits_w as f64).ceil() as u32;

        for r in 0..num_splits_h {
            for c in 0..num_splits_w {
                let start_x = (c as u32) * optimal_width;
                let start_y = (r as u32) * optimal_height;

                let end_x = std::cmp::min(start_x + optimal_width, width);
                let end_y = std::cmp::min(start_y + optimal_height, height);

                // Crop the image
                let cropped_image =
                    image.crop_imm(start_x, start_y, end_x - start_x, end_y - start_y);
                frames.push(cropped_image);
            }
        }

        // Resize the original image to `longest_edge` x `longest_edge` and append it
        // as the global view
        let resized_image = resize(
            image,
            &HashMap::from([
                ("height".to_string(), longest_edge as u32),
                ("width".to_string(), longest_edge as u32),
            ]),
            FilterType::Lanczos3,
        )?;
        frames.push(resized_image);

        Ok((frames, num_splits_h, num_splits_w))
    } else {
        frames.push(image.clone());
        Ok((frames, 0, 0))
    }
}

impl ImagePreProcessor for Idefics3ImageProcessor {
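    // The defaults below are the OpenAI CLIP image normalization statistics.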
    #[allow(clippy::excessive_precision)]
    const DEFAULT_MEAN: [f64; 3] = [0.48145466, 0.4578275, 0.40821073];
    #[allow(clippy::excessive_precision)]
    const DEFAULT_STD: [f64; 3] = [0.26862954, 0.26130258, 0.27577711];

    fn preprocess(
        &self,
        mut images: Vec<DynamicImage>,
        videos: Vec<Vec<DynamicImage>>,
        config: &PreProcessorConfig,
        device: &Device,
        (_bs, _max_num_images): (usize, usize),
    ) -> Result<PreprocessedImages> {
        assert!(videos.is_empty());

        let mut patch_masks = Vec::new();
        let mut pixel_values = Vec::new();

        if let Some(max_edge) = self.max_edge {
            images = mistralrs_vision::pad_to_max_edge(&images, max_edge);
        }

        for image in images.iter_mut() {
            // Convert to RGB
            if config.do_convert_rgb.is_some_and(|x| x) {
                *image = DynamicImage::ImageRgb8(image.to_rgb8());
            }

            // Resize
            if config.do_resize.is_some_and(|x| x) {
                *image = resize(
                    image,
                    config.size.as_ref().unwrap(),
                    config.resampling.to_filter()?,
                )?;
            }
        }

        let mut image_rows = Vec::new();
        let mut image_cols = Vec::new();
        let mut new_images = Vec::new();
        let max_image_size = config
            .max_image_size
            .clone()
            .unwrap_or_else(|| HashMap::from([("longest_edge".to_string(), 364)]));
        if config.do_image_splitting.unwrap_or(true) {
            // First, resize both dimensions of each image up to the nearest multiple of
            // `max_image_size`, disregarding the aspect ratio:
            // for size=(10, max_image_size) -> rescaled_size=(max_image_size, max_image_size)
            // for size=(11, max_image_size+1) -> rescaled_size=(max_image_size, max_image_size*2)
            for image in images.iter_mut() {
                let (new_h, new_w) = resize_for_vision_encoder(
                    (image.dimensions().1 as usize, image.dimensions().0 as usize),
                    max_image_size["longest_edge"] as usize,
                );

                *image =
                    image.resize_exact(new_w as u32, new_h as u32, config.resampling.to_filter()?);

                let (split_image_array, rows, cols) =
                    split_image(image, max_image_size["longest_edge"] as usize)?;
                new_images.extend(split_image_array);
                image_rows.push(rows);
                image_cols.push(cols);
            }
        } else {
            // Otherwise, square each image to max_image_size x max_image_size.
            for image in images.iter_mut() {
                new_images.push(resize(
                    image,
                    &HashMap::from([
                        ("height".to_string(), max_image_size["longest_edge"]),
                        ("width".to_string(), max_image_size["longest_edge"]),
                    ]),
                    FilterType::Lanczos3,
                )?);
            }
            image_rows = vec![0; images.len()];
            image_cols = vec![0; images.len()];
        }
        images = new_images;

        let mut max_h = 0;
        let mut max_w = 0;
        for image in &images {
            let (w, h) = image.dimensions();
            if w > max_w {
                max_w = w;
            }
            if h > max_h {
                max_h = h;
            }
        }

        for image in images.iter_mut() {
            let transforms = Transforms {
                input: &ToTensorNoNorm,
                inner_transforms: &[
                    &config.do_rescale.is_some_and(|x| x).then(|| Rescale {
                        factor: config.rescale_factor,
                    }),
                    &config.do_normalize.is_some_and(|x| x).then(|| Normalize {
                        mean: config.image_mean.unwrap_or(Self::DEFAULT_MEAN).to_vec(),
                        std: config.image_std.unwrap_or(Self::DEFAULT_STD).to_vec(),
                    }),
                ],
            };

            let mut image = image.apply(transforms, device)?;
            // Pad images, calculating the attention mask.
            if config.do_pad.is_some_and(|x| x) {
                let (_c, h, w) = image.dims3()?;
                let padded = mistralrs_vision::pad(&image, max_h as usize, max_w as usize)?;
                let mask = mistralrs_vision::make_pixel_mask(&padded, h, w)?;
                patch_masks.push(mask.unsqueeze(0)?);
                image = padded;
            }

            // Get pixel values
            pixel_values.push(image.unsqueeze(0)?);
        }

        Ok(PreprocessedImages {
            pixel_values: Tensor::cat(&pixel_values, 0)?,
            pixel_attention_mask: Some(Tensor::cat(&patch_masks, 0)?),
            image_sizes: None,
            num_img_tokens: None,
            aspect_ratio_ids: None,
            aspect_ratio_mask: None,
            num_tiles: None,
            image_grid_thw: None,
            video_grid_thw: None,
            rows: Some(image_rows),
            cols: Some(image_cols),
            pixel_values_list: None,
            tgt_sizes: None,
            image_sizes_all: None,
            num_crops: None,
        })
    }
}
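
// Minimal sanity checks for the pure geometry helpers above; the expected values
// below are worked out by hand from the formulas, so they double as usage examples.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn rescale_to_max_len_keeps_aspect_ratio() {
        // 300x600 (h x w) at max_len = 364 keeps the 2:1 ratio.
        assert_eq!(
            resize_output_size_rescale_to_max_len(300, 600, None, Some(364)),
            (182, 364)
        );
    }

    #[test]
    fn scale_below_upper_bound_caps_longest_side() {
        // 4000x6000 (h x w) capped at 4096 rescales the height proportionally.
        assert_eq!(
            resize_output_size_scale_below_upper_bound(4000, 6000, Some(4096)),
            (2731, 4096)
        );
    }

    #[test]
    fn vision_encoder_resize_rounds_up_to_multiples() {
        // Both sides round up to multiples of 364: 500x700 (h x w) -> 728x728.
        assert_eq!(resize_for_vision_encoder((500, 700), 364), (728, 728));
    }

    #[test]
    fn split_image_produces_grid_plus_global_frame() {
        // An 800x500 image with longest_edge = 364 splits into a 2x3 grid of
        // crops plus the resized global image, i.e. 7 frames in total.
        let image = DynamicImage::new_rgb8(800, 500);
        let (frames, rows, cols) = split_image(&image, 364).unwrap();
        assert_eq!((rows, cols), (2, 3));
        assert_eq!(frames.len(), 7);
    }
}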