// mistralrs_core/vision_models/mllama/inputs_processor.rs

1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use std::{
4    any::Any,
5    collections::HashMap,
6    num::NonZeroUsize,
7    sync::{Arc, RwLock},
8};
9
10use candle_core::{Context, DType, Device, Result, Tensor};
11use image::{imageops::FilterType, DynamicImage};
12use itertools::Itertools;
13use mistralrs_vision::{
14    ApplyTensorTransforms, ApplyTransforms, Normalize, Rescale, TensorTransforms, ToTensorNoNorm,
15    Transforms,
16};
17use tokenizers::Tokenizer;
18use tracing::warn;
19
20use crate::{
21    device_map::DeviceMapper,
22    pipeline::{
23        text_models_inputs_processor::{
24            self, get_completion_input, get_prompt_input, PagedAttentionMeta,
25        },
26        InputProcessorOutput, InputsProcessor, InputsProcessorType, MessagesAction, Processor,
27    },
28    sequence::Sequence,
29    vision_models::{
30        image_processor::{ImagePreProcessor, PreprocessedImages},
31        preprocessor_config::{PreProcessorConfig, ToFilter},
32        ModelInputs,
33    },
34};
35
36use super::MLlamaSpecificArgs;
37
/// Placeholder token marking where an image appears in the prompt text.
const IMAGE_TOKEN: &str = "<|image|>";

/// Inputs processor for MLlama: turns sequences (text plus attached images) into model inputs.
struct MLlamaImageProcessor {
    /// Tile budget taken from the preprocessor config.
    ///
    /// `None` until `preprocess` runs for the first time. `process_inputs` only reads it
    /// after preprocessing the batch's images, so it is populated by then.
    max_image_tiles: RwLock<Option<usize>>,
}

/// Processor for MLlama; hands out `MLlamaImageProcessor` inputs processors.
pub struct MLlamaProcessor;

impl MLlamaProcessor {
    /// Create a new MLlama processor. The type carries no state.
    pub fn new() -> Self {
        Self
    }
}

/// Equivalent to [`MLlamaProcessor::new`].
///
/// Provided so the type satisfies `clippy::new_without_default`.
impl Default for MLlamaProcessor {
    fn default() -> Self {
        Self::new()
    }
}
53
54impl Processor for MLlamaProcessor {
55    fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
56        Arc::new(MLlamaImageProcessor {
57            max_image_tiles: RwLock::new(None),
58        })
59    }
60
61    fn get_special_tokens(&self) -> &[&'static str] {
62        &[IMAGE_TOKEN, "<|python_tag|>"]
63    }
64
65    fn template_action(&self) -> MessagesAction {
66        MessagesAction::FlattenOnlyText
67    }
68}
69
// https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/models/mllama/processing_mllama.py#L61
/// Generate the sparse cross-attention token mask for the image tokens in `input_ids`.
///
/// Returns one `(start, end)` pair per occurrence of `image_token_id`, in order:
/// - no image tokens: an empty list;
/// - exactly one image token: `(position, -1)`, where `-1` means "until the end of the sequence";
/// - otherwise each image's range ends where the next image token starts, and the last one
///   ends at `input_ids.len()`. Runs of adjacent image tokens all inherit the end of the
///   range that follows them, so consecutive images attend to the same text.
fn get_cross_attention_token_mask(input_ids: Vec<u32>, image_token_id: u32) -> Vec<(i64, i64)> {
    let starts: Vec<i64> = input_ids
        .iter()
        .enumerate()
        .filter_map(|(pos, &tok)| (tok == image_token_id).then_some(pos as i64))
        .collect();
    let seq_len = input_ids.len() as i64;

    match starts.as_slice() {
        [] => Vec::new(),
        [only] => vec![(*only, -1)],
        _ => {
            // Each range ends at the following image token; the last one runs to the end.
            let ends = starts
                .iter()
                .skip(1)
                .copied()
                .chain(std::iter::once(seq_len));
            let mut ranges: Vec<(i64, i64)> = starts.iter().copied().zip(ends).collect();

            // Walk backwards: an image immediately followed by another image token
            // (range of length 1) takes over the end of the range after it.
            let mut next_end = seq_len;
            for range in ranges.iter_mut().rev() {
                if range.1 - range.0 == 1 {
                    range.1 = next_end;
                }
                next_end = range.1;
            }
            ranges
        }
    }
}
111
112// Convert the cross attention mask indices to a cross attention mask 4D array.
113/// `cross_attention_token_mask` structure:
114/// - The outer list represents the batch dimension.
115/// - The middle list represents different images within each batch item.
116/// - The inner list contains pairs of integers [start, end] representing token ranges for each image.
117///
118/// `num_tiles`: the number of tiles for each image in each batch item.
119///
120/// NOTE: Special handling is done for cases where the end token is -1, which is interpreted as attending to the end of the sequence.
121///
122/// Out shape is (batch_size, length, max_num_images, max_num_tiles). 1 means attn is allowed, 0 means it is not
123fn convert_sparse_cross_attention_mask_to_dense(
124    cross_attn_token_mask: Vec<Vec<(i64, i64)>>,
125    num_tiles: Vec<Vec<usize>>,
126    max_num_tiles: usize,
127    length: usize,
128    dev: &Device,
129) -> candle_core::Result<Tensor> {
130    let bs = cross_attn_token_mask.len();
131    let max_num_images = cross_attn_token_mask.iter().map(|x| x.len()).max().unwrap();
132
133    let mut cross_attention_mask = Tensor::zeros(
134        (bs, length, max_num_images, max_num_tiles),
135        DType::I64,
136        &Device::Cpu,
137    )?;
138
139    for (sample_idx, (sample_masks, sample_num_tiles)) in
140        cross_attn_token_mask.into_iter().zip(num_tiles).enumerate()
141    {
142        for (mask_idx, ((start, end), mask_num_tiles)) in
143            sample_masks.into_iter().zip(sample_num_tiles).enumerate()
144        {
145            let mut end = end.min(length as i64);
146            if end == -1 {
147                end = length as i64;
148            }
149            cross_attention_mask = cross_attention_mask.slice_assign(
150                &[
151                    &sample_idx,
152                    &(start as usize..end as usize),
153                    &mask_idx,
154                    &(..mask_num_tiles),
155                ],
156                &Tensor::ones(
157                    (1, end as usize - start as usize, 1, mask_num_tiles),
158                    DType::I64,
159                    &Device::Cpu,
160                )?,
161            )?;
162        }
163    }
164
165    cross_attention_mask.to_device(dev)
166}
167
168impl InputsProcessor for MLlamaImageProcessor {
169    fn get_type(&self) -> InputsProcessorType {
170        InputsProcessorType::Vision
171    }
172    fn process_inputs(
173        &self,
174        tokenizer: Option<Arc<Tokenizer>>,
175        input_seqs: &mut [&mut Sequence],
176        is_prompt: bool,
177        is_xlora: bool,
178        device: &Device,
179        no_kv_cache: bool,
180        last_n_context_len: Option<(usize, usize)>,
181        return_raw_logits: bool,
182        other_config: Option<Arc<dyn Any>>,
183        mut paged_attn_metadata: Option<PagedAttentionMeta<'_>>,
184        prompt_chunksize: Option<NonZeroUsize>,
185        mapper: Option<&dyn DeviceMapper>,
186    ) -> Box<dyn Iterator<Item = anyhow::Result<InputProcessorOutput>>> {
187        if is_xlora {
188            return Box::new(std::iter::once(Err(anyhow::Error::msg(
189                "Cannot make inputs for X-LoRA vision model.",
190            ))));
191        }
192        if no_kv_cache {
193            return Box::new(std::iter::once(Err(anyhow::Error::msg(
194                "Vision model must have kv cache.",
195            ))));
196        }
197        // TODO(EricLBuehler): support this? Would require some handling of image tokens.
198        if prompt_chunksize.is_some() {
199            warn!("`prompt_chunksize` is set. MLlama does not support prompt batching.");
200        }
201        let Some(tokenizer) = tokenizer else {
202            return Box::new(std::iter::once(Err(anyhow::Error::msg(
203                "MLlamaInputProcessor requires a specified tokenizer.",
204            ))));
205        };
206
207        let text_models_inputs_processor::InnerInputProcessorOutput {
208            inputs:
209                text_models_inputs_processor::InputMetadata {
210                    input,
211                    positions,
212                    context_lens,
213                    position_ids,
214                    paged_attn_meta,
215                    flash_meta,
216                },
217            seq_indices,
218        } = if is_prompt {
219            get_prompt_input(
220                input_seqs
221                    .iter()
222                    .map(|seq| seq.get_toks().to_vec())
223                    .collect::<Vec<_>>(),
224                input_seqs,
225                device,
226                last_n_context_len,
227                return_raw_logits,
228                paged_attn_metadata.as_mut(),
229                None, // TODO: evaluate if it is possible to batch this
230                mapper,
231            )
232            .nth(0)
233            .unwrap()
234            .unwrap()
235        } else {
236            get_completion_input(
237                input_seqs
238                    .iter()
239                    .map(|seq| seq.get_toks().to_vec())
240                    .collect::<Vec<_>>(),
241                input_seqs,
242                device,
243                no_kv_cache,
244                last_n_context_len,
245                return_raw_logits,
246                paged_attn_metadata.as_mut(),
247                None, // TODO: evaluate if it is possible to batch this
248                mapper,
249            )
250            .nth(0)
251            .unwrap()
252            .unwrap()
253        };
254        let config = other_config.expect("Need a PreProcessorConfig config.");
255        let config: &PreProcessorConfig = config.downcast_ref().expect("Downcast failed.");
256
257        let has_images = input_seqs.iter().all(|seq| seq.has_images());
258
259        let (pixel_values, aspect_ratio_ids, aspect_ratio_mask, cross_attn_mask) = if has_images {
260            let mut pixel_values_accum = Vec::new();
261            let mut aspect_ratio_ids_accum = Vec::new();
262            let mut aspect_ratio_mask_accum = Vec::new();
263            let mut num_tiles_accum = Vec::new();
264
265            let bs = input_seqs.len();
266            let detokenized = tokenizer
267                .decode_batch(
268                    &input_seqs
269                        .iter()
270                        .map(|seq| seq.get_toks())
271                        .collect::<Vec<_>>(),
272                    false,
273                )
274                .expect("Detokenization failed!");
275            let n_images_in_text = detokenized
276                .iter()
277                .map(|text| text.matches(IMAGE_TOKEN).count())
278                .collect::<Vec<_>>();
279            let n_images_in_images = input_seqs
280                .iter()
281                .map(|seq| seq.images().map(|imgs| imgs.len()).unwrap_or(0))
282                .collect::<Vec<_>>();
283
284            if n_images_in_text != n_images_in_images {
285                return Box::new(std::iter::once(Err(anyhow::Error::msg(format!(
286                    "The number of images in each batch {n_images_in_text:?} should be the same as the number of images {n_images_in_images:?}. The model cannot support a different number of images per patch. Perhaps you forgot a `<|image|>` tag?"
287                )))));
288            }
289
290            let max_num_images = *n_images_in_images
291                .iter()
292                .max()
293                .expect("No max images per batch!");
294
295            for seq in input_seqs.iter_mut() {
296                let PreprocessedImages {
297                    pixel_values,
298                    pixel_attention_mask: _,
299                    image_sizes: _,
300                    num_img_tokens: _,
301                    aspect_ratio_ids,
302                    aspect_ratio_mask,
303                    num_tiles,
304                    image_grid_thw: _,
305                    video_grid_thw: _,
306                    rows: _,
307                    cols: _,
308                    pixel_values_list: _,
309                    tgt_sizes: _,
310                    image_sizes_all: _,
311                    num_crops: _,
312                } = self
313                    .preprocess(
314                        seq.take_images()
315                            .expect("Need to have images by this point."),
316                        vec![],
317                        config,
318                        device,
319                        (bs, max_num_images), // Don't use it here...
320                    )
321                    .expect("Preprocessing failed");
322                pixel_values_accum.push(pixel_values.unsqueeze(0).unwrap());
323                aspect_ratio_ids_accum.push(aspect_ratio_ids.unwrap().unsqueeze(0).unwrap());
324                aspect_ratio_mask_accum.push(aspect_ratio_mask.unwrap().unsqueeze(0).unwrap());
325                num_tiles_accum.push(num_tiles.unwrap());
326            }
327
328            // Create cross attn mask
329            let image_token_id = tokenizer
330                .encode_fast(IMAGE_TOKEN, false)
331                .unwrap()
332                .get_ids()
333                .to_vec();
334            let image_token_id = if image_token_id.len() == 1 {
335                image_token_id[0]
336            } else {
337                panic!("{IMAGE_TOKEN} encoding should be one token, got {image_token_id:?}");
338            };
339            let chunks = input.chunk(input.dim(0).unwrap(), 0).unwrap();
340            let cross_attention_token_mask = chunks
341                .iter()
342                .map(|token_ids| {
343                    get_cross_attention_token_mask(
344                        token_ids.squeeze(0).unwrap().to_vec1::<u32>().unwrap(),
345                        image_token_id,
346                    )
347                })
348                .collect::<Vec<_>>();
349
350            let cross_attn_mask = convert_sparse_cross_attention_mask_to_dense(
351                cross_attention_token_mask,
352                num_tiles_accum,
353                self.max_image_tiles
354                    .read()
355                    .unwrap()
356                    .expect("`max_image_tiles` must be set!"),
357                chunks
358                    .iter()
359                    .map(|input_ids| *input_ids.dims().last().unwrap())
360                    .max()
361                    .unwrap(),
362                chunks[0].device(),
363            );
364
365            let cross_attn_mask = match cross_attn_mask {
366                Ok(v) => v,
367                Err(e) => return Box::new(std::iter::once(Err(anyhow::Error::msg(e.to_string())))),
368            };
369
370            (
371                Some(Tensor::cat(&pixel_values_accum, 0).unwrap()),
372                Some(Tensor::cat(&aspect_ratio_ids_accum, 0).unwrap()),
373                Some(Tensor::cat(&aspect_ratio_mask_accum, 0).unwrap()),
374                Some(cross_attn_mask),
375            )
376        } else {
377            (None, None, None, None)
378        };
379
380        let inputs: Box<dyn Any> = Box::new(ModelInputs {
381            input_ids: input,
382            seqlen_offsets: positions,
383            context_lens,
384            position_ids,
385            pixel_values,
386            model_specific_args: Box::new(MLlamaSpecificArgs {
387                aspect_ratio_ids,
388                aspect_ratio_mask,
389                cross_attn_mask,
390            }),
391            paged_attn_meta,
392            flash_meta,
393        });
394        Box::new(std::iter::once(Ok(InputProcessorOutput {
395            inputs,
396            seq_indices,
397        })))
398    }
399}
400
/// Index of the smallest item produced by `iter`, or `None` if it is empty.
///
/// Ties keep the earliest index. An item only replaces the current best when it compares
/// strictly less, so incomparable values such as NaN never displace the current best.
fn argmin<T, I>(iter: I) -> Option<usize>
where
    T: PartialOrd,
    I: Iterator<Item = T>,
{
    let mut best: Option<(usize, T)> = None;
    for (idx, item) in iter.enumerate() {
        let is_better = match &best {
            Some((_, current)) => item < *current,
            None => true,
        };
        if is_better {
            best = Some((idx, item));
        }
    }
    best.map(|(idx, _)| idx)
}
419
420impl MLlamaImageProcessor {
421    // https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/models/mllama/image_processing_mllama.py#L53
422    fn get_all_supported_aspect_ratios(max_image_tiles: usize) -> Vec<(usize, usize)> {
423        (1..max_image_tiles + 1)
424            .flat_map(|width| {
425                (1..max_image_tiles + 1).filter_map(move |height| {
426                    if width * height <= max_image_tiles {
427                        Some((width, height))
428                    } else {
429                        None
430                    }
431                })
432            })
433            .collect::<Vec<_>>()
434    }
435
436    // https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/models/mllama/image_processing_mllama.py#L132
437    fn get_optimal_tiled_canvas(
438        image_height: u32,
439        image_width: u32,
440        max_image_tiles: usize,
441        tile_size: usize,
442    ) -> Result<(usize, usize)> {
443        let possible_tile_arrangements = Self::get_all_supported_aspect_ratios(max_image_tiles);
444        let possible_canvas_sizes: (Vec<_>, Vec<_>) = possible_tile_arrangements
445            .into_iter()
446            .map(|(h, w)| (h * tile_size, w * tile_size))
447            .unzip();
448        // Get all possible resolution heights/widths
449        let (target_heights, target_widths) = possible_canvas_sizes;
450
451        // Get scaling factors to resize the image without distortion
452        let scale_h = target_heights
453            .iter()
454            .map(|h| *h as f32 / image_height as f32)
455            .collect::<Vec<_>>();
456        let scale_w = target_widths
457            .iter()
458            .map(|w| *w as f32 / image_width as f32)
459            .collect::<Vec<_>>();
460
461        // Get the min scale between width and height
462        let scales = scale_h
463            .into_iter()
464            .zip(scale_w)
465            .map(|(scale_h, scale_w)| if scale_w > scale_h { scale_h } else { scale_w })
466            .collect::<Vec<_>>();
467
468        // Filter only scales that allow upscaling
469        let upscaling_options = scales
470            .iter()
471            .copied()
472            .filter(|scale| *scale >= 1.)
473            .collect::<Vec<_>>();
474        let selected_scale = if !upscaling_options.is_empty() {
475            upscaling_options
476                .into_iter()
477                .min_by(|x, y| x.partial_cmp(y).expect("No ordering!"))
478                .context("No min, upscale")?
479        } else {
480            // No upscaling possible, get min downscaling (max scale for scales<1)
481            let downscaling_options = scales
482                .iter()
483                .copied()
484                .filter(|scale| *scale < 1.)
485                .collect::<Vec<_>>();
486            downscaling_options
487                .into_iter()
488                .max_by(|x, y| x.partial_cmp(y).expect("No ordering!"))
489                .context("No max, downscale")?
490        };
491
492        // Get all resolutions that support this scaling factor
493        let chosen_canvas_h = target_heights
494            .iter()
495            .copied()
496            .enumerate()
497            .filter_map(|(i, h)| {
498                if scales[i] == selected_scale {
499                    Some(h)
500                } else {
501                    None
502                }
503            })
504            .collect::<Vec<_>>();
505        let chosen_canvas_w = target_widths
506            .iter()
507            .copied()
508            .enumerate()
509            .filter_map(|(i, w)| {
510                if scales[i] == selected_scale {
511                    Some(w)
512                } else {
513                    None
514                }
515            })
516            .collect::<Vec<_>>();
517
518        assert_eq!(chosen_canvas_h.len(), chosen_canvas_w.len());
519        if chosen_canvas_h.len() > 1 {
520            let optimal_idx = argmin(
521                chosen_canvas_h
522                    .iter()
523                    .zip(&chosen_canvas_w)
524                    .map(|(h, w)| *h * *w),
525            )
526            .context("No argmin")?;
527            Ok((chosen_canvas_h[optimal_idx], chosen_canvas_w[optimal_idx]))
528        } else {
529            Ok((chosen_canvas_h[0], chosen_canvas_w[0]))
530        }
531    }
532
533    // https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/models/mllama/image_processing_mllama.py#L82
534    fn get_image_size_fit_to_canvas(
535        image_height: u32,
536        image_width: u32,
537        canvas_height: usize,
538        canvas_width: usize,
539        tile_size: usize,
540    ) -> (usize, usize) {
541        let target_width = (image_width as usize).clamp(tile_size, canvas_width);
542        let target_height = (image_height as usize).clamp(tile_size, canvas_height);
543
544        let scale_h = (target_height as f32) / (image_height as f32);
545        let scale_w = (target_width as f32) / (image_width as f32);
546
547        if scale_w < scale_h {
548            (
549                target_height.min((image_height as f32 * scale_w).floor() as usize),
550                target_width,
551            )
552        } else {
553            (
554                target_height,
555                target_width.min((image_width as f32 * scale_h).floor() as usize),
556            )
557        }
558    }
559
560    // https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/models/mllama/image_processing_mllama.py#L796
561    /// Resizes an image to fit within a tiled canvas while maintaining its aspect ratio.
562    /// The optimal canvas size is calculated based on the maximum number of tiles and the tile size.
563    fn resize(
564        &self,
565        image: DynamicImage,
566        size: &HashMap<String, u32>,
567        max_image_tiles: usize,
568        filter: FilterType,
569    ) -> Result<(DynamicImage, (usize, usize))> {
570        let image_height = image.height();
571        let image_width = image.width();
572        let tile_size = size["height"] as usize;
573
574        let (canvas_height, canvas_width) =
575            Self::get_optimal_tiled_canvas(image_height, image_width, max_image_tiles, tile_size)?;
576        let num_tiles_height = canvas_height / tile_size;
577        let num_tiles_width = canvas_width / tile_size;
578
579        let (new_height, new_width) = Self::get_image_size_fit_to_canvas(
580            image_height,
581            image_width,
582            canvas_height,
583            canvas_width,
584            tile_size,
585        );
586
587        Ok((
588            image.resize_exact(new_width as u32, new_height as u32, filter),
589            (num_tiles_height, num_tiles_width),
590        ))
591    }
592
593    // https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/models/mllama/image_processing_mllama.py#L749
594    /// Pad an image to the `size` x `aspect_ratio`. For example, if size is {height: 224, width: 224} and aspect ratio is
595    /// (1, 2), the image will be padded to 224x448.
596    fn pad(
597        &self,
598        image: &Tensor,
599        size: &HashMap<String, u32>,
600        aspect_ratio: (usize, usize),
601    ) -> Result<Tensor> {
602        let (num_tiles_h, num_tiles_w) = aspect_ratio;
603        let padded_height = num_tiles_h * size["height"] as usize;
604        let padded_width = num_tiles_w * size["width"] as usize;
605
606        // Add padding on bottom and right sides
607        mistralrs_vision::pad(image, padded_height, padded_width)
608    }
609
610    // https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/models/mllama/image_processing_mllama.py#L213
611    /// Split an image into a specified number of tiles along its width and height dimensions.
612    fn split_to_tiles(
613        &self,
614        image: &Tensor,
615        num_tiles_height: usize,
616        num_tiles_width: usize,
617    ) -> Result<Tensor> {
618        let (ch, h, w) = image.dims3()?;
619        let tile_height = h / num_tiles_height;
620        let tile_width = w / num_tiles_width;
621
622        let mut image = image.reshape((
623            ch,
624            num_tiles_height,
625            tile_height,
626            num_tiles_width,
627            tile_width,
628        ))?;
629
630        // Permute to (num_tiles_height, num_tiles_width, num_channels, tile_height, tile_width)
631        image = image.permute((1, 3, 0, 2, 4))?;
632
633        // Reshape into the desired output shape (num_tiles_width * num_tiles_height, num_channels, tile_height, tile_width)
634        image
635            .reshape((
636                num_tiles_width * num_tiles_height,
637                ch,
638                tile_height,
639                tile_width,
640            ))?
641            .contiguous()
642    }
643
644    // https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/models/mllama/image_processing_mllama.py#L277
645    /// Returns
646    /// - stacked and packed images
647    /// - a list of lists containing the number of tiles for each image in each batch sample.
648    ///   Padding uses 0
649    fn pack_images(
650        &self,
651        images: Vec<Tensor>,
652        max_image_tiles: usize,
653        (_bs, max_num_images): (usize, usize),
654    ) -> Result<(Tensor, Vec<usize>)> {
655        let (_, ch, tile_h, tile_w) = images[0].dims4()?;
656
657        let mut stacked_images = Tensor::zeros(
658            (max_num_images, max_image_tiles, ch, tile_h, tile_w),
659            images[0].dtype(),
660            images[0].device(),
661        )?;
662        let mut num_sample_tiles = Vec::new();
663        for (i, image) in images.into_iter().enumerate() {
664            let num_tiles = image.dim(0)?;
665            stacked_images = stacked_images
666                .slice_assign(&[&i, &(..num_tiles), &.., &.., &..], &image.unsqueeze(0)?)?;
667            num_sample_tiles.push(num_tiles)
668        }
669        Ok((stacked_images, num_sample_tiles))
670    }
671
672    // https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/models/mllama/image_processing_mllama.py#L354
673    /// Convert aspect ratio tuples to unique ids.
674    /// Padding uses 0
675    fn convert_aspect_ratios_to_ids(
676        &self,
677        aspect_ratios: Vec<(usize, usize)>,
678        max_image_tiles: usize,
679        (_bs, max_num_images): (usize, usize),
680        device: &Device,
681    ) -> Result<Tensor> {
682        let supported_aspect_ratios = Self::get_all_supported_aspect_ratios(max_image_tiles);
683
684        let mut aspect_ratios_ids = vec![0i64; max_num_images];
685        for (i, (num_tiles_h, num_tiles_w)) in aspect_ratios.iter().enumerate() {
686            aspect_ratios_ids[i] = (supported_aspect_ratios
687                .iter()
688                .position(|(h, w)| *h == *num_tiles_h && *w == *num_tiles_w)
689                .context("Could not find aspect ratio")?
690                + 1) as i64;
691        }
692
693        Tensor::new(aspect_ratios_ids, device)
694    }
695
696    fn build_aspect_ratio_mask(
697        &self,
698        aspect_ratios: Vec<(usize, usize)>,
699        max_image_tiles: usize,
700        (_bs, max_num_images): (usize, usize),
701        device: &Device,
702    ) -> Result<Tensor> {
703        let mut aspect_ratio_mask =
704            Tensor::zeros((max_num_images, max_image_tiles), DType::I64, device)?;
705
706        // Set the first tile to 1 for all aspect ratios
707        // because in the original implementation, aspect ratios are apdded with (1,1)
708
709        aspect_ratio_mask = aspect_ratio_mask.slice_assign(
710            &[&.., &0],
711            &Tensor::ones((max_num_images, 1), DType::I64, device)?,
712        )?;
713
714        for (i, (num_tiles_h, num_tiles_w)) in aspect_ratios.iter().enumerate() {
715            aspect_ratio_mask = aspect_ratio_mask.slice_assign(
716                &[&i, &(..*num_tiles_h * *num_tiles_w)],
717                &Tensor::ones((1, *num_tiles_h * *num_tiles_w), DType::I64, device)?,
718            )?;
719        }
720
721        Ok(aspect_ratio_mask)
722    }
723}
724
725impl ImagePreProcessor for MLlamaImageProcessor {
726    const DEFAULT_MEAN: [f64; 3] = [0.5, 0.5, 0.5];
727    const DEFAULT_STD: [f64; 3] = [0.5, 0.5, 0.5];
728
729    fn preprocess(
730        &self,
731        images: Vec<DynamicImage>,
732        videos: Vec<Vec<DynamicImage>>,
733        config: &PreProcessorConfig,
734        device: &Device,
735        (bs, max_num_images): (usize, usize),
736    ) -> Result<PreprocessedImages> {
737        assert!(videos.is_empty());
738
739        let mut sample_images = Vec::new();
740        let mut sample_aspect_ratios = Vec::new();
741        let max_image_tiles = config
742            .max_image_tiles
743            .context("`do_resize=false` is not supported, need `max_image_tiles`!")?;
744        *self.max_image_tiles.write().unwrap() = Some(max_image_tiles);
745
746        for mut image in images {
747            // Convert to rgb, default to true
748            if config.do_convert_rgb.unwrap_or(true) {
749                image = DynamicImage::ImageRgb8(image.to_rgb8());
750            }
751
752            let size = config
753                .size
754                .as_ref()
755                .context("`do_resize=false` is not supported, need `size`!")?;
756
757            let (image, aspect_ratio) =
758                self.resize(image, size, max_image_tiles, config.resampling.to_filter()?)?;
759
760            // In transformers they rescale from [0, 255] to [0, 1]
761            // at the end of resize:
762            // https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/image_transforms.py#L340
763            let to_tensor_rescale = Transforms {
764                input: &ToTensorNoNorm,
765                inner_transforms: &[],
766            };
767            let mut image = image.apply(to_tensor_rescale, device)?;
768
769            image = self.pad(&image, size, aspect_ratio)?;
770
771            let transforms = TensorTransforms {
772                inner_transforms: &[
773                    &config
774                        .do_rescale
775                        .is_some_and(|x| x)
776                        .then_some(())
777                        .map(|_| Rescale {
778                            factor: config.rescale_factor,
779                        }),
780                    &config
781                        .do_normalize
782                        .is_some_and(|x| x)
783                        .then_some(())
784                        .map(|_| Normalize {
785                            mean: config.image_mean.unwrap_or(Self::DEFAULT_MEAN).to_vec(),
786                            std: config.image_std.unwrap_or(Self::DEFAULT_STD).to_vec(),
787                        }),
788                ],
789            };
790            image = <Tensor as ApplyTensorTransforms>::apply(&image, transforms, device)?;
791
792            let (num_tiles_height, num_tiles_width) = aspect_ratio;
793            image = self.split_to_tiles(&image, num_tiles_height, num_tiles_width)?;
794
795            sample_images.push(image);
796            sample_aspect_ratios.push((num_tiles_height, num_tiles_width));
797        }
798
799        let (images, num_tiles) =
800            self.pack_images(sample_images, max_image_tiles, (bs, max_num_images))?;
801
802        let aspect_ratio_ids = self.convert_aspect_ratios_to_ids(
803            sample_aspect_ratios.clone(),
804            max_image_tiles,
805            (bs, max_num_images),
806            device,
807        )?;
808        let aspect_ratio_mask = self.build_aspect_ratio_mask(
809            sample_aspect_ratios,
810            max_image_tiles,
811            (bs, max_num_images),
812            device,
813        )?;
814
815        Ok(PreprocessedImages {
816            pixel_values: images,
817            pixel_attention_mask: None,
818            image_sizes: None,
819            num_img_tokens: None,
820            aspect_ratio_ids: Some(aspect_ratio_ids),
821            aspect_ratio_mask: Some(aspect_ratio_mask),
822            num_tiles: Some(num_tiles),
823            image_grid_thw: None,
824            video_grid_thw: None,
825            rows: None,
826            cols: None,
827            pixel_values_list: None,
828            tgt_sizes: None,
829            image_sizes_all: None,
830            num_crops: None,
831        })
832    }
833}