mistralrs_core/vision_models/mllama/inputs_processor.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::{
    any::Any,
    collections::HashMap,
    sync::{Arc, RwLock},
};

use candle_core::{Context, DType, Device, Result, Tensor};
use image::{imageops::FilterType, DynamicImage};
use itertools::Itertools;
use mistralrs_vision::{
    ApplyTensorTransforms, ApplyTransforms, Normalize, Rescale, TensorTransforms, ToTensorNoNorm,
    Transforms,
};
use tokenizers::Tokenizer;

use crate::{
    device_map::DeviceMapper,
    pipeline::{
        text_models_inputs_processor::{
            self, get_completion_input, get_prompt_input, PagedAttentionMeta,
        },
        InputProcessorOutput, InputsProcessor, InputsProcessorType, MessagesAction, Processor,
    },
    sequence::Sequence,
    vision_models::{
        image_processor::{ImagePreProcessor, PreprocessedImages},
        preprocessor_config::{PreProcessorConfig, ToFilter},
        ModelInputs,
    },
};

use super::MLlamaSpecificArgs;

const IMAGE_TOKEN: &str = "<|image|>";

// Input processor
struct MLlamaImageProcessor {
    // `None` represents an uninitialized value; it should always be set by the time it is read.
    max_image_tiles: RwLock<Option<usize>>,
}
// Processor
pub struct MLlamaProcessor;

impl MLlamaProcessor {
    pub fn new() -> Self {
        Self
    }
}

impl Processor for MLlamaProcessor {
    fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
        Arc::new(MLlamaImageProcessor {
            max_image_tiles: RwLock::new(None),
        })
    }

    fn get_special_tokens(&self) -> &[&'static str] {
        &[IMAGE_TOKEN, "<|python_tag|>"]
    }

    fn template_action(&self) -> MessagesAction {
        MessagesAction::FlattenOnlyText
    }
}

// https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/models/mllama/processing_mllama.py#L61
/// Generate a cross-attention token mask for image tokens in the input sequence.
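///
/// Each returned `(start, end)` pair is the token span that may attend to one image, where
/// `end == -1` means "until the end of the sequence". For example:
/// - a single image token at position 3 yields `[(3, -1)]`;
/// - image tokens at positions 2 and 5 in a sequence of length 8 yield `[(2, 5), (5, 8)]`;
/// - consecutive image tokens at positions 0 and 1 in a sequence of length 6 yield
///   `[(0, 6), (1, 6)]`, since consecutive vision tokens share the subsequent text span.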
fn get_cross_attention_token_mask(input_ids: Vec<u32>, image_token_id: u32) -> Vec<(i64, i64)> {
    let image_token_locations = input_ids
        .iter()
        .positions(|token| *token == image_token_id)
        .collect::<Vec<_>>();

    if image_token_locations.is_empty() {
        return vec![];
    }

    // If only one image present, unmask until end of sequence
    if image_token_locations.len() == 1 {
        return vec![(image_token_locations[0] as i64, -1)];
    }

    let mut vision_masks = image_token_locations[..image_token_locations.len() - 1]
        .iter()
        .zip(&image_token_locations[1..])
        .map(|(a, b)| (*a as i64, *b as i64))
        .collect::<Vec<_>>();

    // The last image will attend to all subsequent text
    vision_masks.push((
        *image_token_locations.last().unwrap() as i64,
        input_ids.len() as i64,
    ));

    // If there are 2 or more consecutive vision tokens, they should all attend
    // to all subsequent text present
    let mut last_mask_end = vision_masks.last().unwrap().1;
    for vision_mask in vision_masks.iter_mut().rev() {
        if vision_mask.0 == vision_mask.1 - 1 {
            vision_mask.1 = last_mask_end;
        }
        last_mask_end = vision_mask.1;
    }

    vision_masks
}

/// Convert the cross attention mask indices to a cross attention mask 4D array.
/// `cross_attention_token_mask` structure:
/// - The outer list represents the batch dimension.
/// - The middle list represents different images within each batch item.
/// - The inner list contains pairs of integers [start, end] representing token ranges for each image.
///
/// `num_tiles`: the number of tiles for each image in each batch item.
///
/// NOTE: Special handling is done for cases where the end token is -1, which is interpreted as attending to the end of the sequence.
///
/// Out shape is (batch_size, length, max_num_images, max_num_tiles). 1 means attn is allowed, 0 means it is not
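///
/// For example, `cross_attn_token_mask = vec![vec![(2, -1)]]`, `num_tiles = vec![vec![2]]`,
/// `max_num_tiles = 4` and `length = 5` produce a `(1, 5, 1, 4)` mask that is 1 at sequence
/// positions 2..5 for tiles 0..2 and 0 everywhere else.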
fn convert_sparse_cross_attention_mask_to_dense(
    cross_attn_token_mask: Vec<Vec<(i64, i64)>>,
    num_tiles: Vec<Vec<usize>>,
    max_num_tiles: usize,
    length: usize,
    dev: &Device,
) -> candle_core::Result<Tensor> {
    let bs = cross_attn_token_mask.len();
    let max_num_images = cross_attn_token_mask.iter().map(|x| x.len()).max().unwrap();

    let mut cross_attention_mask = Tensor::zeros(
        (bs, length, max_num_images, max_num_tiles),
        DType::I64,
        &Device::Cpu,
    )?;

    for (sample_idx, (sample_masks, sample_num_tiles)) in
        cross_attn_token_mask.into_iter().zip(num_tiles).enumerate()
    {
        for (mask_idx, ((start, end), mask_num_tiles)) in
            sample_masks.into_iter().zip(sample_num_tiles).enumerate()
        {
            let mut end = end.min(length as i64);
            if end == -1 {
                end = length as i64;
            }
            cross_attention_mask = cross_attention_mask.slice_assign(
                &[
                    sample_idx..sample_idx + 1,
                    start as usize..end as usize,
                    mask_idx..mask_idx + 1,
                    0..mask_num_tiles,
                ],
                &Tensor::ones(
                    (1, end as usize - start as usize, 1, mask_num_tiles),
                    DType::I64,
                    &Device::Cpu,
                )?,
            )?;
        }
    }

    cross_attention_mask.to_device(dev)
}

impl InputsProcessor for MLlamaImageProcessor {
    fn get_type(&self) -> InputsProcessorType {
        InputsProcessorType::Vision
    }
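    /// Build `ModelInputs` for MLlama: run the standard text input processing and, when images
    /// are present, preprocess each sequence's images (tiling and packing) and construct the
    /// aspect ratio ids/mask plus the dense cross-attention mask passed via `MLlamaSpecificArgs`.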
    fn process_inputs(
        &self,
        tokenizer: Option<Arc<Tokenizer>>,
        input_seqs: &mut [&mut Sequence],
        is_prompt: bool,
        is_xlora: bool,
        device: &Device,
        no_kv_cache: bool,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        other_config: Option<Arc<dyn Any>>,
        mut paged_attn_metadata: Option<PagedAttentionMeta>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> anyhow::Result<InputProcessorOutput> {
        if is_xlora {
            return Err(anyhow::Error::msg(
                "Cannot make inputs for X-LoRA vision model.",
            ));
        }
        if no_kv_cache {
            return Err(anyhow::Error::msg("Vision model must have kv cache."));
        }
        let Some(tokenizer) = tokenizer else {
            return Err(anyhow::Error::msg(
                "MLlamaInputProcessor requires a specified tokenizer.",
            ));
        };

        let text_models_inputs_processor::InnerInputProcessorOutput {
            inputs:
                text_models_inputs_processor::InputMetadata {
                    input,
                    positions: _,
                    context_lens: _,
                    position_ids: _,
                    paged_attn_meta: _,
                    flash_meta: _,
                },
            seq_indices: _,
        } = if is_prompt {
            get_prompt_input(
                input_seqs
                    .iter()
                    .map(|seq| seq.get_toks())
                    .collect::<Vec<_>>(),
                input_seqs,
                device,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                mapper,
            )
            .unwrap()
        } else {
            get_completion_input(
                input_seqs
                    .iter()
                    .map(|seq| seq.get_toks())
                    .collect::<Vec<_>>(),
                input_seqs,
                device,
                no_kv_cache,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                mapper,
            )
            .unwrap()
        };
        let config = other_config.expect("Need a PreProcessorConfig config.");
        let config: &PreProcessorConfig = config.downcast_ref().expect("Downcast failed.");

        let has_images = input_seqs.iter().all(|seq| seq.has_images());

        let (pixel_values, aspect_ratio_ids, aspect_ratio_mask, cross_attn_mask) = if has_images {
            let mut pixel_values_accum = Vec::new();
            let mut aspect_ratio_ids_accum = Vec::new();
            let mut aspect_ratio_mask_accum = Vec::new();
            let mut num_tiles_accum = Vec::new();

            let bs = input_seqs.len();
            let detokenized = tokenizer
                .decode_batch(
                    &input_seqs
                        .iter()
                        .map(|seq| seq.get_toks())
                        .collect::<Vec<_>>(),
                    false,
                )
                .expect("Detokenization failed!");
            let n_images_in_text = detokenized
                .iter()
                .map(|text| text.matches(IMAGE_TOKEN).count())
                .collect::<Vec<_>>();
            let n_images_in_images = input_seqs
                .iter()
                .map(|seq| seq.images().map(|imgs| imgs.len()).unwrap_or(0))
                .collect::<Vec<_>>();

            if n_images_in_text != n_images_in_images {
                return Err(anyhow::Error::msg(format!(
                    "The number of `<|image|>` tokens in each prompt {n_images_in_text:?} should match the number of images provided {n_images_in_images:?}. The model does not support a different number of images per prompt. Perhaps you forgot a `<|image|>` tag?"
                )));
            }

            let max_num_images = *n_images_in_images
                .iter()
                .max()
                .expect("No max images per batch!");

            for seq in input_seqs.iter_mut() {
                let PreprocessedImages {
                    pixel_values,
                    pixel_attention_mask: _,
                    image_sizes: _,
                    num_img_tokens: _,
                    aspect_ratio_ids,
                    aspect_ratio_mask,
                    num_tiles,
                    image_grid_thw: _,
                    video_grid_thw: _,
                    rows: _,
                    cols: _,
                    pixel_values_list: _,
                    tgt_sizes: _,
                    image_sizes_all: _,
                    num_crops: _,
                } = self
                    .preprocess(
                        seq.take_images()
                            .expect("Need to have images by this point."),
                        vec![],
                        config,
                        device,
                        (bs, max_num_images), // Don't use it here...
                    )
                    .expect("Preprocessing failed");
                pixel_values_accum.push(pixel_values.unsqueeze(0).unwrap());
                aspect_ratio_ids_accum.push(aspect_ratio_ids.unwrap().unsqueeze(0).unwrap());
                aspect_ratio_mask_accum.push(aspect_ratio_mask.unwrap().unsqueeze(0).unwrap());
                num_tiles_accum.push(num_tiles.unwrap());

                seq.multimodal.has_changed_prompt = true;
            }

            // Create cross attn mask
            let image_token_id = tokenizer
                .encode_fast(IMAGE_TOKEN, false)
                .unwrap()
                .get_ids()
                .to_vec();
            let image_token_id = if image_token_id.len() == 1 {
                image_token_id[0]
            } else {
                panic!("{IMAGE_TOKEN} encoding should be one token, got {image_token_id:?}");
            };
            let chunks = input.chunk(input.dim(0).unwrap(), 0).unwrap();
            let cross_attention_token_mask = chunks
                .iter()
                .map(|token_ids| {
                    get_cross_attention_token_mask(
                        token_ids.squeeze(0).unwrap().to_vec1::<u32>().unwrap(),
                        image_token_id,
                    )
                })
                .collect::<Vec<_>>();

            let cross_attn_mask = convert_sparse_cross_attention_mask_to_dense(
                cross_attention_token_mask,
                num_tiles_accum,
                self.max_image_tiles
                    .read()
                    .unwrap()
                    .expect("`max_image_tiles` must be set!"),
                chunks
                    .iter()
                    .map(|input_ids| *input_ids.dims().last().unwrap())
                    .max()
                    .unwrap(),
                chunks[0].device(),
            );

            let cross_attn_mask = match cross_attn_mask {
                Ok(v) => v,
                Err(e) => return Err(anyhow::Error::msg(e.to_string())),
            };

            (
                Some(Tensor::cat(&pixel_values_accum, 0).unwrap()),
                Some(Tensor::cat(&aspect_ratio_ids_accum, 0).unwrap()),
                Some(Tensor::cat(&aspect_ratio_mask_accum, 0).unwrap()),
                Some(cross_attn_mask),
            )
        } else {
            (None, None, None, None)
        };

        let text_models_inputs_processor::InnerInputProcessorOutput {
            inputs:
                text_models_inputs_processor::InputMetadata {
                    input,
                    positions,
                    context_lens,
                    position_ids,
                    paged_attn_meta,
                    flash_meta,
                },
            seq_indices,
        } = if is_prompt {
            get_prompt_input(
                input_seqs
                    .iter()
                    .map(|seq| seq.get_toks())
                    .collect::<Vec<_>>(),
                input_seqs,
                device,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                mapper,
            )
            .unwrap()
        } else {
            get_completion_input(
                input_seqs
                    .iter()
                    .map(|seq| seq.get_toks())
                    .collect::<Vec<_>>(),
                input_seqs,
                device,
                no_kv_cache,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                mapper,
            )
            .unwrap()
        };

        let inputs: Box<dyn Any> = Box::new(ModelInputs {
            input_ids: input,
            seqlen_offsets: positions,
            context_lens,
            position_ids,
            pixel_values,
            model_specific_args: Box::new(MLlamaSpecificArgs {
                aspect_ratio_ids,
                aspect_ratio_mask,
                cross_attn_mask,
            }),
            paged_attn_meta,
            flash_meta,
        });
        Ok(InputProcessorOutput {
            inputs,
            seq_indices,
        })
    }
}

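/// Return the index of the smallest element in the iterator, or `None` if it is empty.
/// Ties keep the earliest index.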
fn argmin<T, I>(iter: I) -> Option<usize>
where
    T: PartialOrd,
    I: Iterator<Item = T>,
{
    iter.enumerate()
        .fold(None, |min, (idx, item)| match min {
            None => Some((idx, item)),
            Some((min_idx, min_item)) => {
                if item < min_item {
                    Some((idx, item))
                } else {
                    Some((min_idx, min_item))
                }
            }
        })
        .map(|(min_idx, _)| min_idx)
}

impl MLlamaImageProcessor {
    // https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/models/mllama/image_processing_mllama.py#L53
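    /// Enumerate all tile-grid arrangements that use at most `max_image_tiles` tiles.
    /// For example, `max_image_tiles = 4` yields
    /// `[(1, 1), (1, 2), (1, 3), (1, 4), (2, 1), (2, 2), (3, 1), (4, 1)]`.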
    fn get_all_supported_aspect_ratios(max_image_tiles: usize) -> Vec<(usize, usize)> {
        (1..max_image_tiles + 1)
            .flat_map(|width| {
                (1..max_image_tiles + 1).filter_map(move |height| {
                    if width * height <= max_image_tiles {
                        Some((width, height))
                    } else {
                        None
                    }
                })
            })
            .collect::<Vec<_>>()
    }

    // https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/models/mllama/image_processing_mllama.py#L132
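    /// Choose the pixel canvas the image will be fitted onto, out of all canvases realizable with
    /// at most `max_image_tiles` tiles of side `tile_size`. Upscaling is preferred (smallest
    /// scale >= 1); otherwise the mildest downscale (largest scale < 1) is used, and ties between
    /// canvases are broken by the smallest area. For example, a 224x224 image with
    /// `tile_size = 560` and `max_image_tiles = 4` selects a 560x560 canvas.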
    fn get_optimal_tiled_canvas(
        image_height: u32,
        image_width: u32,
        max_image_tiles: usize,
        tile_size: usize,
    ) -> Result<(usize, usize)> {
        let possible_tile_arrangements = Self::get_all_supported_aspect_ratios(max_image_tiles);
        let possible_canvas_sizes: (Vec<_>, Vec<_>) = possible_tile_arrangements
            .into_iter()
            .map(|(h, w)| (h * tile_size, w * tile_size))
            .unzip();
        // Get all possible resolution heights/widths
        let (target_heights, target_widths) = possible_canvas_sizes;

        // Get scaling factors to resize the image without distortion
        let scale_h = target_heights
            .iter()
            .map(|h| *h as f32 / image_height as f32)
            .collect::<Vec<_>>();
        let scale_w = target_widths
            .iter()
            .map(|w| *w as f32 / image_width as f32)
            .collect::<Vec<_>>();

        // Get the min scale between width and height
        let scales = scale_h
            .into_iter()
            .zip(scale_w)
            .map(|(scale_h, scale_w)| if scale_w > scale_h { scale_h } else { scale_w })
            .collect::<Vec<_>>();

        // Filter only scales that allow upscaling
        let upscaling_options = scales
            .iter()
            .copied()
            .filter(|scale| *scale >= 1.)
            .collect::<Vec<_>>();
        let selected_scale = if !upscaling_options.is_empty() {
            upscaling_options
                .into_iter()
                .min_by(|x, y| x.partial_cmp(y).expect("No ordering!"))
                .context("No min, upscale")?
        } else {
            // No upscaling possible, get min downscaling (max scale for scales<1)
            let downscaling_options = scales
                .iter()
                .copied()
                .filter(|scale| *scale < 1.)
                .collect::<Vec<_>>();
            downscaling_options
                .into_iter()
                .max_by(|x, y| x.partial_cmp(y).expect("No ordering!"))
                .context("No max, downscale")?
        };

        // Get all resolutions that support this scaling factor
        let chosen_canvas_h = target_heights
            .iter()
            .copied()
            .enumerate()
            .filter_map(|(i, h)| {
                if scales[i] == selected_scale {
                    Some(h)
                } else {
                    None
                }
            })
            .collect::<Vec<_>>();
        let chosen_canvas_w = target_widths
            .iter()
            .copied()
            .enumerate()
            .filter_map(|(i, w)| {
                if scales[i] == selected_scale {
                    Some(w)
                } else {
                    None
                }
            })
            .collect::<Vec<_>>();

        assert_eq!(chosen_canvas_h.len(), chosen_canvas_w.len());
        if chosen_canvas_h.len() > 1 {
            let optimal_idx = argmin(
                chosen_canvas_h
                    .iter()
                    .zip(&chosen_canvas_w)
                    .map(|(h, w)| *h * *w),
            )
            .context("No argmin")?;
            Ok((chosen_canvas_h[optimal_idx], chosen_canvas_w[optimal_idx]))
        } else {
            Ok((chosen_canvas_h[0], chosen_canvas_w[0]))
        }
    }

    // https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/models/mllama/image_processing_mllama.py#L82
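    /// Compute the (height, width) to resize the image to so that it fits the chosen canvas while
    /// keeping its aspect ratio; both target dimensions are clamped to `tile_size..=canvas`.
    /// For example, a 250x500 (height x width) image on a 1120x560 canvas with `tile_size = 560`
    /// yields `(280, 560)`.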
    fn get_image_size_fit_to_canvas(
        image_height: u32,
        image_width: u32,
        canvas_height: usize,
        canvas_width: usize,
        tile_size: usize,
    ) -> (usize, usize) {
        let target_width = (image_width as usize).clamp(tile_size, canvas_width);
        let target_height = (image_height as usize).clamp(tile_size, canvas_height);

        let scale_h = (target_height as f32) / (image_height as f32);
        let scale_w = (target_width as f32) / (image_width as f32);

        if scale_w < scale_h {
            (
                target_height.min((image_height as f32 * scale_w).floor() as usize),
                target_width,
            )
        } else {
            (
                target_height,
                target_width.min((image_width as f32 * scale_h).floor() as usize),
            )
        }
    }

    // https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/models/mllama/image_processing_mllama.py#L796
    /// Resizes an image to fit within a tiled canvas while maintaining its aspect ratio.
    /// The optimal canvas size is calculated based on the maximum number of tiles and the tile size.
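    /// For example, with `size = {height: 560, width: 560}` and `max_image_tiles = 4`, a 224x224
    /// image is resized to 560x560 and reported as a `(1, 1)` tile arrangement.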
    fn resize(
        &self,
        image: DynamicImage,
        size: &HashMap<String, u32>,
        max_image_tiles: usize,
        filter: FilterType,
    ) -> Result<(DynamicImage, (usize, usize))> {
        let image_height = image.height();
        let image_width = image.width();
        let tile_size = size["height"] as usize;

        let (canvas_height, canvas_width) =
            Self::get_optimal_tiled_canvas(image_height, image_width, max_image_tiles, tile_size)?;
        let num_tiles_height = canvas_height / tile_size;
        let num_tiles_width = canvas_width / tile_size;

        let (new_height, new_width) = Self::get_image_size_fit_to_canvas(
            image_height,
            image_width,
            canvas_height,
            canvas_width,
            tile_size,
        );

        Ok((
            image.resize_exact(new_width as u32, new_height as u32, filter),
            (num_tiles_height, num_tiles_width),
        ))
    }

    // https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/models/mllama/image_processing_mllama.py#L749
    /// Pad an image to the `size` x `aspect_ratio`. For example, if size is {height: 224, width: 224} and aspect ratio is
    /// (1, 2), the image will be padded to 224x448.
    fn pad(
        &self,
        image: &Tensor,
        size: &HashMap<String, u32>,
        aspect_ratio: (usize, usize),
    ) -> Result<Tensor> {
        let (num_tiles_h, num_tiles_w) = aspect_ratio;
        let padded_height = num_tiles_h * size["height"] as usize;
        let padded_width = num_tiles_w * size["width"] as usize;

        // Add padding on bottom and right sides
        mistralrs_vision::pad(image, padded_height, padded_width)
    }

    // https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/models/mllama/image_processing_mllama.py#L213
    /// Split an image into a specified number of tiles along its width and height dimensions.
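    ///
    /// For example, a `(3, 448, 448)` image split into a 2 x 2 grid becomes a `(4, 3, 224, 224)`
    /// tensor of tiles.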
    fn split_to_tiles(
        &self,
        image: &Tensor,
        num_tiles_height: usize,
        num_tiles_width: usize,
    ) -> Result<Tensor> {
        let (ch, h, w) = image.dims3()?;
        let tile_height = h / num_tiles_height;
        let tile_width = w / num_tiles_width;

        let mut image = image.reshape((
            ch,
            num_tiles_height,
            tile_height,
            num_tiles_width,
            tile_width,
        ))?;

        // Permute to (num_tiles_height, num_tiles_width, num_channels, tile_height, tile_width)
        image = image.permute((1, 3, 0, 2, 4))?;

        // Reshape into the desired output shape (num_tiles_width * num_tiles_height, num_channels, tile_height, tile_width)
        image
            .reshape((
                num_tiles_width * num_tiles_height,
                ch,
                tile_height,
                tile_width,
            ))?
            .contiguous()
    }

    // https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/models/mllama/image_processing_mllama.py#L277
    /// Returns
    /// - stacked and packed images
    /// - a list containing the number of tiles for each image in the sample.
    ///   Padding uses 0
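    ///
    /// For example, packing two images with 2 and 4 tiles, with `max_image_tiles = 4` and
    /// `max_num_images = 2`, yields a `(2, 4, channels, tile_h, tile_w)` tensor whose unused tile
    /// slots stay zero, together with `num_sample_tiles = [2, 4]`.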
    fn pack_images(
        &self,
        images: Vec<Tensor>,
        max_image_tiles: usize,
        (_bs, max_num_images): (usize, usize),
    ) -> Result<(Tensor, Vec<usize>)> {
        let (_, ch, tile_h, tile_w) = images[0].dims4()?;

        let mut stacked_images = Tensor::zeros(
            (max_num_images, max_image_tiles, ch, tile_h, tile_w),
            images[0].dtype(),
            images[0].device(),
        )?;
        let mut num_sample_tiles = Vec::new();
        for (i, image) in images.into_iter().enumerate() {
            let num_tiles = image.dim(0)?;
            stacked_images = stacked_images.slice_assign(
                &[i..i + 1, 0..num_tiles, 0..ch, 0..tile_h, 0..tile_w],
                &image.unsqueeze(0)?,
            )?;
            num_sample_tiles.push(num_tiles)
        }
        Ok((stacked_images, num_sample_tiles))
    }

    // https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/models/mllama/image_processing_mllama.py#L354
    /// Convert aspect ratio tuples to unique ids.
    /// Padding uses 0
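    ///
    /// The id is the 1-based index of the aspect ratio in `get_all_supported_aspect_ratios`.
    /// For example, with `max_image_tiles = 4`, `(1, 1)` maps to id 1 and `(2, 2)` maps to id 6.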
    fn convert_aspect_ratios_to_ids(
        &self,
        aspect_ratios: Vec<(usize, usize)>,
        max_image_tiles: usize,
        (_bs, max_num_images): (usize, usize),
        device: &Device,
    ) -> Result<Tensor> {
        let supported_aspect_ratios = Self::get_all_supported_aspect_ratios(max_image_tiles);

        let mut aspect_ratios_ids = vec![0i64; max_num_images];
        for (i, (num_tiles_h, num_tiles_w)) in aspect_ratios.iter().enumerate() {
            aspect_ratios_ids[i] = (supported_aspect_ratios
                .iter()
                .position(|(h, w)| *h == *num_tiles_h && *w == *num_tiles_w)
                .context("Could not find aspect ratio")?
                + 1) as i64;
        }

        Tensor::new(aspect_ratios_ids, device)
    }

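    /// Build a 0/1 mask of shape `(max_num_images, max_image_tiles)` marking which tile slots are
    /// used by each image: the first `num_tiles_h * num_tiles_w` slots of image `i` are set to 1,
    /// and the first slot is always 1 (mirroring the upstream padding of aspect ratios with (1, 1)).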
    fn build_aspect_ratio_mask(
        &self,
        aspect_ratios: Vec<(usize, usize)>,
        max_image_tiles: usize,
        (_bs, max_num_images): (usize, usize),
        device: &Device,
    ) -> Result<Tensor> {
        let mut aspect_ratio_mask =
            Tensor::zeros((max_num_images, max_image_tiles), DType::I64, device)?;

        // Set the first tile to 1 for all aspect ratios,
        // because in the original implementation aspect ratios are padded with (1, 1)

        aspect_ratio_mask = aspect_ratio_mask.slice_assign(
            &[0..max_num_images, 0..1],
            &Tensor::ones((max_num_images, 1), DType::I64, device)?,
        )?;

        for (i, (num_tiles_h, num_tiles_w)) in aspect_ratios.iter().enumerate() {
            aspect_ratio_mask = aspect_ratio_mask.slice_assign(
                &[i..i + 1, 0..*num_tiles_h * *num_tiles_w],
                &Tensor::ones((1, *num_tiles_h * *num_tiles_w), DType::I64, device)?,
            )?;
        }

        Ok(aspect_ratio_mask)
    }
}

impl ImagePreProcessor for MLlamaImageProcessor {
    const DEFAULT_MEAN: [f64; 3] = [0.5, 0.5, 0.5];
    const DEFAULT_STD: [f64; 3] = [0.5, 0.5, 0.5];

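    /// Per-image pipeline: resize onto the optimal tiled canvas, convert to a tensor, pad to the
    /// canvas, optionally rescale and normalize, then split into tiles; finally the tiles are
    /// packed per sample and the aspect ratio ids/mask are derived.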
    fn preprocess(
        &self,
        images: Vec<DynamicImage>,
        videos: Vec<Vec<DynamicImage>>,
        config: &PreProcessorConfig,
        device: &Device,
        (bs, max_num_images): (usize, usize),
    ) -> Result<PreprocessedImages> {
        assert!(videos.is_empty());

        let mut sample_images = Vec::new();
        let mut sample_aspect_ratios = Vec::new();
        let max_image_tiles = config
            .max_image_tiles
            .context("`do_resize=false` is not supported, need `max_image_tiles`!")?;
        *self.max_image_tiles.write().unwrap() = Some(max_image_tiles);

        for mut image in images {
            // Convert to rgb, default to true
            if config.do_convert_rgb.unwrap_or(true) {
                image = DynamicImage::ImageRgb8(image.to_rgb8());
            }

            let size = config
                .size
                .as_ref()
                .context("`do_resize=false` is not supported, need `size`!")?;

            let (image, aspect_ratio) =
                self.resize(image, size, max_image_tiles, config.resampling.to_filter()?)?;

            // In transformers they rescale from [0, 255] to [0, 1]
            // at the end of resize:
            // https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/image_transforms.py#L340
            let to_tensor_rescale = Transforms {
                input: &ToTensorNoNorm,
                inner_transforms: &[],
            };
            let mut image = image.apply(to_tensor_rescale, device)?;

            image = self.pad(&image, size, aspect_ratio)?;

            let transforms = TensorTransforms {
                inner_transforms: &[
                    &config
                        .do_rescale
                        .is_some_and(|x| x)
                        .then_some(())
                        .map(|_| Rescale {
                            factor: config.rescale_factor,
                        }),
                    &config
                        .do_normalize
                        .is_some_and(|x| x)
                        .then_some(())
                        .map(|_| Normalize {
                            mean: config.image_mean.unwrap_or(Self::DEFAULT_MEAN).to_vec(),
                            std: config.image_std.unwrap_or(Self::DEFAULT_STD).to_vec(),
                        }),
                ],
            };
            image = <Tensor as ApplyTensorTransforms>::apply(&image, transforms, device)?;

            let (num_tiles_height, num_tiles_width) = aspect_ratio;
            image = self.split_to_tiles(&image, num_tiles_height, num_tiles_width)?;

            sample_images.push(image);
            sample_aspect_ratios.push((num_tiles_height, num_tiles_width));
        }

        let (images, num_tiles) =
            self.pack_images(sample_images, max_image_tiles, (bs, max_num_images))?;

        let aspect_ratio_ids = self.convert_aspect_ratios_to_ids(
            sample_aspect_ratios.clone(),
            max_image_tiles,
            (bs, max_num_images),
            device,
        )?;
        let aspect_ratio_mask = self.build_aspect_ratio_mask(
            sample_aspect_ratios,
            max_image_tiles,
            (bs, max_num_images),
            device,
        )?;

        Ok(PreprocessedImages {
            pixel_values: images,
            pixel_attention_mask: None,
            image_sizes: None,
            num_img_tokens: None,
            aspect_ratio_ids: Some(aspect_ratio_ids),
            aspect_ratio_mask: Some(aspect_ratio_mask),
            num_tiles: Some(num_tiles),
            image_grid_thw: None,
            video_grid_thw: None,
            rows: None,
            cols: None,
            pixel_values_list: None,
            tgt_sizes: None,
            image_sizes_all: None,
            num_crops: None,
        })
    }
}