mistralrs_core/vision_models/mllama/inputs_processor.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::{
    any::Any,
    collections::HashMap,
    num::NonZeroUsize,
    sync::{Arc, RwLock},
};

use candle_core::{Context, DType, Device, Result, Tensor};
use image::{imageops::FilterType, DynamicImage};
use itertools::Itertools;
use mistralrs_vision::{
    ApplyTensorTransforms, ApplyTransforms, Normalize, Rescale, TensorTransforms, ToTensorNoNorm,
    Transforms,
};
use tokenizers::Tokenizer;
use tracing::warn;

use crate::{
    device_map::DeviceMapper,
    pipeline::{
        text_models_inputs_processor::{
            self, get_completion_input, get_prompt_input, PagedAttentionMeta,
        },
        InputProcessorOutput, InputsProcessor, InputsProcessorType, MessagesAction, Processor,
    },
    sequence::Sequence,
    vision_models::{
        image_processor::{ImagePreProcessor, PreprocessedImages},
        preprocessor_config::{PreProcessorConfig, ToFilter},
        ModelInputs,
    },
};

use super::MLlamaSpecificArgs;

const IMAGE_TOKEN: &str = "<|image|>";

// Input processor
struct MLlamaImageProcessor {
    // `None` represents the uninitialized state; this is always set (in `preprocess`) by the time it is read.
    max_image_tiles: RwLock<Option<usize>>,
}
// Processor
pub struct MLlamaProcessor;

impl MLlamaProcessor {
    pub fn new() -> Self {
        Self
    }
}

impl Processor for MLlamaProcessor {
    fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
        Arc::new(MLlamaImageProcessor {
            max_image_tiles: RwLock::new(None),
        })
    }

    fn get_special_tokens(&self) -> &[&'static str] {
        &[IMAGE_TOKEN, "<|python_tag|>"]
    }

    fn template_action(&self) -> MessagesAction {
        MessagesAction::FlattenOnlyText
    }
}

// https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/models/mllama/processing_mllama.py#L61
/// Generate a cross-attention token mask for image tokens in the input sequence.
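///
/// Each returned pair `(start, end)` marks the token span that may attend to the corresponding
/// image; the last image attends to the rest of the sequence, and a lone image is marked with
/// `-1` (until the end). An illustrative sketch, using `9` as a stand-in image token id:
///
/// ```ignore
/// let ids = vec![9, 1, 2, 3, 4, 9, 6, 7];
/// assert_eq!(get_cross_attention_token_mask(ids, 9), vec![(0, 5), (5, 8)]);
/// // A single image unmasks until the end of the sequence:
/// assert_eq!(get_cross_attention_token_mask(vec![1, 9, 3], 9), vec![(1, -1)]);
/// ```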
fn get_cross_attention_token_mask(input_ids: Vec<u32>, image_token_id: u32) -> Vec<(i64, i64)> {
    let image_token_locations = input_ids
        .iter()
        .positions(|token| *token == image_token_id)
        .collect::<Vec<_>>();

    if image_token_locations.is_empty() {
        return vec![];
    }

    // If only one image present, unmask until end of sequence
    if image_token_locations.len() == 1 {
        return vec![(image_token_locations[0] as i64, -1)];
    }

    let mut vision_masks = image_token_locations[..image_token_locations.len() - 1]
        .iter()
        .zip(&image_token_locations[1..])
        .map(|(a, b)| (*a as i64, *b as i64))
        .collect::<Vec<_>>();

    // The last image will attend to all subsequent text
    vision_masks.push((
        *image_token_locations.last().unwrap() as i64,
        input_ids.len() as i64,
    ));

    // If there are 2 or more consecutive vision tokens, they should all attend
    // to all subsequent text present
    let mut last_mask_end = vision_masks.last().unwrap().1;
    for vision_mask in vision_masks.iter_mut().rev() {
        if vision_mask.0 == vision_mask.1 - 1 {
            vision_mask.1 = last_mask_end;
        }
        last_mask_end = vision_mask.1;
    }

    vision_masks
}

/// Convert the cross-attention mask indices into a dense 4D cross-attention mask array.
///
/// `cross_attn_token_mask` structure:
/// - The outer list represents the batch dimension.
/// - The middle list represents different images within each batch item.
/// - The inner list contains pairs of integers [start, end] representing token ranges for each image.
///
/// `num_tiles`: the number of tiles for each image in each batch item.
///
/// NOTE: Special handling is done for cases where the end token is -1, which is interpreted as attending to the end of the sequence.
///
/// The output shape is (batch_size, length, max_num_images, max_num_tiles); 1 means attention is allowed, 0 means it is not.
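///
/// An illustrative sketch of the layout (the values here are hypothetical):
///
/// ```ignore
/// // One sequence of length 6 with a single image spanning tokens [1, end) and 2 tiles,
/// // with max_num_tiles = 4, gives a (1, 6, 1, 4) tensor whose rows 1..6 contain ones in
/// // the first 2 tile slots and zeros elsewhere.
/// let mask = convert_sparse_cross_attention_mask_to_dense(
///     vec![vec![(1, -1)]],
///     vec![vec![2]],
///     4,
///     6,
///     &Device::Cpu,
/// )?;
/// assert_eq!(mask.dims(), &[1, 6, 1, 4]);
/// ```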
fn convert_sparse_cross_attention_mask_to_dense(
    cross_attn_token_mask: Vec<Vec<(i64, i64)>>,
    num_tiles: Vec<Vec<usize>>,
    max_num_tiles: usize,
    length: usize,
    dev: &Device,
) -> candle_core::Result<Tensor> {
    let bs = cross_attn_token_mask.len();
    let max_num_images = cross_attn_token_mask.iter().map(|x| x.len()).max().unwrap();

    let mut cross_attention_mask = Tensor::zeros(
        (bs, length, max_num_images, max_num_tiles),
        DType::I64,
        &Device::Cpu,
    )?;

    for (sample_idx, (sample_masks, sample_num_tiles)) in
        cross_attn_token_mask.into_iter().zip(num_tiles).enumerate()
    {
        for (mask_idx, ((start, end), mask_num_tiles)) in
            sample_masks.into_iter().zip(sample_num_tiles).enumerate()
        {
            let mut end = end.min(length as i64);
            if end == -1 {
                end = length as i64;
            }
            cross_attention_mask = cross_attention_mask.slice_assign(
                &[
                    &sample_idx,
                    &(start as usize..end as usize),
                    &mask_idx,
                    &(..mask_num_tiles),
                ],
                &Tensor::ones(
                    (1, end as usize - start as usize, 1, mask_num_tiles),
                    DType::I64,
                    &Device::Cpu,
                )?,
            )?;
        }
    }

    cross_attention_mask.to_device(dev)
}

impl InputsProcessor for MLlamaImageProcessor {
    fn get_type(&self) -> InputsProcessorType {
        InputsProcessorType::Vision
    }
    fn process_inputs(
        &self,
        tokenizer: Option<Arc<Tokenizer>>,
        input_seqs: &mut [&mut Sequence],
        is_prompt: bool,
        is_xlora: bool,
        device: &Device,
        no_kv_cache: bool,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        other_config: Option<Arc<dyn Any>>,
        mut paged_attn_metadata: Option<PagedAttentionMeta<'_>>,
        prompt_chunksize: Option<NonZeroUsize>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Box<dyn Iterator<Item = anyhow::Result<InputProcessorOutput>>> {
        if is_xlora {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Cannot make inputs for X-LoRA vision model.",
            ))));
        }
        if no_kv_cache {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Vision model must have kv cache.",
            ))));
        }
        // TODO(EricLBuehler): support this? Would require some handling of image tokens.
        if prompt_chunksize.is_some() {
            warn!("`prompt_chunksize` is set. MLlama does not support prompt batching.");
        }
        let Some(tokenizer) = tokenizer else {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "MLlamaInputProcessor requires a specified tokenizer.",
            ))));
        };

        let text_models_inputs_processor::InnerInputProcessorOutput {
            inputs:
                text_models_inputs_processor::InputMetadata {
                    input,
                    positions: _,
                    context_lens: _,
                    position_ids: _,
                    paged_attn_meta: _,
                    flash_meta: _,
                },
            seq_indices: _,
        } = if is_prompt {
            get_prompt_input(
                input_seqs
                    .iter()
                    .map(|seq| seq.get_toks().to_vec())
                    .collect::<Vec<_>>(),
                input_seqs,
                device,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                None, // TODO: evaluate if it is possible to batch this
                mapper,
            )
            .nth(0)
            .unwrap()
            .unwrap()
        } else {
            get_completion_input(
                input_seqs
                    .iter()
                    .map(|seq| seq.get_toks().to_vec())
                    .collect::<Vec<_>>(),
                input_seqs,
                device,
                no_kv_cache,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                None, // TODO: evaluate if it is possible to batch this
                mapper,
            )
            .nth(0)
            .unwrap()
            .unwrap()
        };
        let config = other_config.expect("Need a PreProcessorConfig config.");
        let config: &PreProcessorConfig = config.downcast_ref().expect("Downcast failed.");

        let has_images = input_seqs.iter().all(|seq| seq.has_images());

        let (pixel_values, aspect_ratio_ids, aspect_ratio_mask, cross_attn_mask) = if has_images {
            let mut pixel_values_accum = Vec::new();
            let mut aspect_ratio_ids_accum = Vec::new();
            let mut aspect_ratio_mask_accum = Vec::new();
            let mut num_tiles_accum = Vec::new();

            let bs = input_seqs.len();
            let detokenized = tokenizer
                .decode_batch(
                    &input_seqs
                        .iter()
                        .map(|seq| seq.get_toks())
                        .collect::<Vec<_>>(),
                    false,
                )
                .expect("Detokenization failed!");
            let n_images_in_text = detokenized
                .iter()
                .map(|text| text.matches(IMAGE_TOKEN).count())
                .collect::<Vec<_>>();
            let n_images_in_images = input_seqs
                .iter()
                .map(|seq| seq.images().map(|imgs| imgs.len()).unwrap_or(0))
                .collect::<Vec<_>>();

            if n_images_in_text != n_images_in_images {
                return Box::new(std::iter::once(Err(anyhow::Error::msg(format!(
                    "The number of images in each batch {n_images_in_text:?} should be the same as the number of images {n_images_in_images:?}. The model cannot support a different number of images per patch. Perhaps you forgot a `<|image|>` tag?"
                )))));
            }

            let max_num_images = *n_images_in_images
                .iter()
                .max()
                .expect("No max images per batch!");

            for seq in input_seqs.iter_mut() {
                let PreprocessedImages {
                    pixel_values,
                    pixel_attention_mask: _,
                    image_sizes: _,
                    num_img_tokens: _,
                    aspect_ratio_ids,
                    aspect_ratio_mask,
                    num_tiles,
                    image_grid_thw: _,
                    video_grid_thw: _,
                    rows: _,
                    cols: _,
                    pixel_values_list: _,
                    tgt_sizes: _,
                    image_sizes_all: _,
                    num_crops: _,
                } = self
                    .preprocess(
                        seq.take_images()
                            .expect("Need to have images by this point."),
                        vec![],
                        config,
                        device,
                        (bs, max_num_images), // Don't use it here...
                    )
                    .expect("Preprocessing failed");
                pixel_values_accum.push(pixel_values.unsqueeze(0).unwrap());
                aspect_ratio_ids_accum.push(aspect_ratio_ids.unwrap().unsqueeze(0).unwrap());
                aspect_ratio_mask_accum.push(aspect_ratio_mask.unwrap().unsqueeze(0).unwrap());
                num_tiles_accum.push(num_tiles.unwrap());
            }

            // Create cross attn mask
            let image_token_id = tokenizer
                .encode_fast(IMAGE_TOKEN, false)
                .unwrap()
                .get_ids()
                .to_vec();
            let image_token_id = if image_token_id.len() == 1 {
                image_token_id[0]
            } else {
                panic!("{IMAGE_TOKEN} encoding should be one token, got {image_token_id:?}");
            };
            let chunks = input.chunk(input.dim(0).unwrap(), 0).unwrap();
            let cross_attention_token_mask = chunks
                .iter()
                .map(|token_ids| {
                    get_cross_attention_token_mask(
                        token_ids.squeeze(0).unwrap().to_vec1::<u32>().unwrap(),
                        image_token_id,
                    )
                })
                .collect::<Vec<_>>();

            let cross_attn_mask = convert_sparse_cross_attention_mask_to_dense(
                cross_attention_token_mask,
                num_tiles_accum,
                self.max_image_tiles
                    .read()
                    .unwrap()
                    .expect("`max_image_tiles` must be set!"),
                chunks
                    .iter()
                    .map(|input_ids| *input_ids.dims().last().unwrap())
                    .max()
                    .unwrap(),
                chunks[0].device(),
            );

            let cross_attn_mask = match cross_attn_mask {
                Ok(v) => v,
                Err(e) => return Box::new(std::iter::once(Err(anyhow::Error::msg(e.to_string())))),
            };

            (
                Some(Tensor::cat(&pixel_values_accum, 0).unwrap()),
                Some(Tensor::cat(&aspect_ratio_ids_accum, 0).unwrap()),
                Some(Tensor::cat(&aspect_ratio_mask_accum, 0).unwrap()),
                Some(cross_attn_mask),
            )
        } else {
            (None, None, None, None)
        };

        let text_models_inputs_processor::InnerInputProcessorOutput {
            inputs:
                text_models_inputs_processor::InputMetadata {
                    input,
                    positions,
                    context_lens,
                    position_ids,
                    paged_attn_meta,
                    flash_meta,
                },
            seq_indices,
        } = if is_prompt {
            get_prompt_input(
                input_seqs
                    .iter()
                    .map(|seq| seq.get_toks().to_vec())
                    .collect::<Vec<_>>(),
                input_seqs,
                device,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                None, // TODO: evaluate if it is possible to batch this
                mapper,
            )
            .nth(0)
            .unwrap()
            .unwrap()
        } else {
            get_completion_input(
                input_seqs
                    .iter()
                    .map(|seq| seq.get_toks().to_vec())
                    .collect::<Vec<_>>(),
                input_seqs,
                device,
                no_kv_cache,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                None, // TODO: evaluate if it is possible to batch this
                mapper,
            )
            .nth(0)
            .unwrap()
            .unwrap()
        };

        let inputs: Box<dyn Any> = Box::new(ModelInputs {
            input_ids: input,
            seqlen_offsets: positions,
            context_lens,
            position_ids,
            pixel_values,
            model_specific_args: Box::new(MLlamaSpecificArgs {
                aspect_ratio_ids,
                aspect_ratio_mask,
                cross_attn_mask,
            }),
            paged_attn_meta,
            flash_meta,
        });
        Box::new(std::iter::once(Ok(InputProcessorOutput {
            inputs,
            seq_indices,
        })))
    }
}

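/// Index of the minimum value in `iter`, or `None` for an empty iterator.
/// A small illustrative sketch of the expected behavior:
///
/// ```ignore
/// assert_eq!(argmin([3.0f32, 1.0, 2.0].into_iter()), Some(1));
/// assert_eq!(argmin(std::iter::empty::<f32>()), None);
/// ```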
fn argmin<T, I>(iter: I) -> Option<usize>
where
    T: PartialOrd,
    I: Iterator<Item = T>,
{
    iter.enumerate()
        .fold(None, |min, (idx, item)| match min {
            None => Some((idx, item)),
            Some((min_idx, min_item)) => {
                if item < min_item {
                    Some((idx, item))
                } else {
                    Some((min_idx, min_item))
                }
            }
        })
        .map(|(min_idx, _)| min_idx)
}

impl MLlamaImageProcessor {
    // https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/models/mllama/image_processing_mllama.py#L53
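    /// All tile-grid arrangements whose total tile count does not exceed `max_image_tiles`.
    /// An illustrative sketch of the expected output:
    ///
    /// ```ignore
    /// assert_eq!(
    ///     MLlamaImageProcessor::get_all_supported_aspect_ratios(4),
    ///     vec![(1, 1), (1, 2), (1, 3), (1, 4), (2, 1), (2, 2), (3, 1), (4, 1)]
    /// );
    /// ```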
    fn get_all_supported_aspect_ratios(max_image_tiles: usize) -> Vec<(usize, usize)> {
        (1..max_image_tiles + 1)
            .flat_map(|width| {
                (1..max_image_tiles + 1).filter_map(move |height| {
                    if width * height <= max_image_tiles {
                        Some((width, height))
                    } else {
                        None
                    }
                })
            })
            .collect::<Vec<_>>()
    }

    // https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/models/mllama/image_processing_mllama.py#L132
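    /// Choose the tiled canvas (in multiples of `tile_size`) that the image can be scaled into
    /// without distortion, preferring upscaling and otherwise the mildest downscaling, and
    /// breaking ties by the smallest canvas area. An illustrative sketch:
    ///
    /// ```ignore
    /// // A 100x200 (height x width) image with at most 4 tiles of 224px fits a single tile.
    /// let canvas = MLlamaImageProcessor::get_optimal_tiled_canvas(100, 200, 4, 224)?;
    /// assert_eq!(canvas, (224, 224));
    /// ```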
    fn get_optimal_tiled_canvas(
        image_height: u32,
        image_width: u32,
        max_image_tiles: usize,
        tile_size: usize,
    ) -> Result<(usize, usize)> {
        let possible_tile_arrangements = Self::get_all_supported_aspect_ratios(max_image_tiles);
        let possible_canvas_sizes: (Vec<_>, Vec<_>) = possible_tile_arrangements
            .into_iter()
            .map(|(h, w)| (h * tile_size, w * tile_size))
            .unzip();
        // Get all possible resolution heights/widths
        let (target_heights, target_widths) = possible_canvas_sizes;

        // Get scaling factors to resize the image without distortion
        let scale_h = target_heights
            .iter()
            .map(|h| *h as f32 / image_height as f32)
            .collect::<Vec<_>>();
        let scale_w = target_widths
            .iter()
            .map(|w| *w as f32 / image_width as f32)
            .collect::<Vec<_>>();

        // Get the min scale between width and height
        let scales = scale_h
            .into_iter()
            .zip(scale_w)
            .map(|(scale_h, scale_w)| if scale_w > scale_h { scale_h } else { scale_w })
            .collect::<Vec<_>>();

        // Filter only scales that allow upscaling
        let upscaling_options = scales
            .iter()
            .copied()
            .filter(|scale| *scale >= 1.)
            .collect::<Vec<_>>();
        let selected_scale = if !upscaling_options.is_empty() {
            upscaling_options
                .into_iter()
                .min_by(|x, y| x.partial_cmp(y).expect("No ordering!"))
                .context("No min, upscale")?
        } else {
            // No upscaling possible, get min downscaling (max scale for scales<1)
            let downscaling_options = scales
                .iter()
                .copied()
                .filter(|scale| *scale < 1.)
                .collect::<Vec<_>>();
            downscaling_options
                .into_iter()
                .max_by(|x, y| x.partial_cmp(y).expect("No ordering!"))
                .context("No max, downscale")?
        };

        // Get all resolutions that support this scaling factor
        let chosen_canvas_h = target_heights
            .iter()
            .copied()
            .enumerate()
            .filter_map(|(i, h)| {
                if scales[i] == selected_scale {
                    Some(h)
                } else {
                    None
                }
            })
            .collect::<Vec<_>>();
        let chosen_canvas_w = target_widths
            .iter()
            .copied()
            .enumerate()
            .filter_map(|(i, w)| {
                if scales[i] == selected_scale {
                    Some(w)
                } else {
                    None
                }
            })
            .collect::<Vec<_>>();

        assert_eq!(chosen_canvas_h.len(), chosen_canvas_w.len());
        if chosen_canvas_h.len() > 1 {
            let optimal_idx = argmin(
                chosen_canvas_h
                    .iter()
                    .zip(&chosen_canvas_w)
                    .map(|(h, w)| *h * *w),
            )
            .context("No argmin")?;
            Ok((chosen_canvas_h[optimal_idx], chosen_canvas_w[optimal_idx]))
        } else {
            Ok((chosen_canvas_h[0], chosen_canvas_w[0]))
        }
    }

    // https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/models/mllama/image_processing_mllama.py#L82
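    /// Compute the size the image should be resized to so that it fits the canvas while keeping
    /// its aspect ratio, never dropping below a single tile. An illustrative sketch:
    ///
    /// ```ignore
    /// // A 100x200 (height x width) image on a 224x224 canvas is limited by the width factor
    /// // (224 / 200 = 1.12), so the height becomes floor(100 * 1.12) = 112.
    /// let new_size = MLlamaImageProcessor::get_image_size_fit_to_canvas(100, 200, 224, 224, 224);
    /// assert_eq!(new_size, (112, 224));
    /// ```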
    fn get_image_size_fit_to_canvas(
        image_height: u32,
        image_width: u32,
        canvas_height: usize,
        canvas_width: usize,
        tile_size: usize,
    ) -> (usize, usize) {
        let target_width = (image_width as usize).clamp(tile_size, canvas_width);
        let target_height = (image_height as usize).clamp(tile_size, canvas_height);

        let scale_h = (target_height as f32) / (image_height as f32);
        let scale_w = (target_width as f32) / (image_width as f32);

        if scale_w < scale_h {
            (
                target_height.min((image_height as f32 * scale_w).floor() as usize),
                target_width,
            )
        } else {
            (
                target_height,
                target_width.min((image_width as f32 * scale_h).floor() as usize),
            )
        }
    }

    // https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/models/mllama/image_processing_mllama.py#L796
    /// Resizes an image to fit within a tiled canvas while maintaining its aspect ratio.
    /// The optimal canvas size is calculated based on the maximum number of tiles and the tile size.
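    ///
    /// Returns the resized image together with the chosen tile grid `(num_tiles_height,
    /// num_tiles_width)`. An illustrative sketch, assuming `size` maps both "height" and
    /// "width" to 224 (`processor` and `image` are hypothetical bindings):
    ///
    /// ```ignore
    /// // A 300x800 (height x width) image with at most 4 tiles selects a 1x3 tile canvas.
    /// let (resized, grid) = processor.resize(image, &size, 4, FilterType::CatmullRom)?;
    /// assert_eq!(grid, (1, 3));
    /// ```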
    fn resize(
        &self,
        image: DynamicImage,
        size: &HashMap<String, u32>,
        max_image_tiles: usize,
        filter: FilterType,
    ) -> Result<(DynamicImage, (usize, usize))> {
        let image_height = image.height();
        let image_width = image.width();
        let tile_size = size["height"] as usize;

        let (canvas_height, canvas_width) =
            Self::get_optimal_tiled_canvas(image_height, image_width, max_image_tiles, tile_size)?;
        let num_tiles_height = canvas_height / tile_size;
        let num_tiles_width = canvas_width / tile_size;

        let (new_height, new_width) = Self::get_image_size_fit_to_canvas(
            image_height,
            image_width,
            canvas_height,
            canvas_width,
            tile_size,
        );

        Ok((
            image.resize_exact(new_width as u32, new_height as u32, filter),
            (num_tiles_height, num_tiles_width),
        ))
    }

    // https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/models/mllama/image_processing_mllama.py#L749
    /// Pad an image to the `size` x `aspect_ratio`. For example, if size is {height: 224, width: 224} and aspect ratio is
    /// (1, 2), the image will be padded to 224x448.
    fn pad(
        &self,
        image: &Tensor,
        size: &HashMap<String, u32>,
        aspect_ratio: (usize, usize),
    ) -> Result<Tensor> {
        let (num_tiles_h, num_tiles_w) = aspect_ratio;
        let padded_height = num_tiles_h * size["height"] as usize;
        let padded_width = num_tiles_w * size["width"] as usize;

        // Add padding on bottom and right sides
        mistralrs_vision::pad(image, padded_height, padded_width)
    }

    // https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/models/mllama/image_processing_mllama.py#L213
    /// Split an image into a specified number of tiles along its width and height dimensions.
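    ///
    /// An illustrative sketch of the shape transformation (`processor` and `image` are
    /// hypothetical bindings):
    ///
    /// ```ignore
    /// // A (3, 448, 448) CHW image split into a 2x2 grid becomes
    /// // (num_tiles, channels, tile_height, tile_width) = (4, 3, 224, 224).
    /// let tiles = processor.split_to_tiles(&image, 2, 2)?;
    /// assert_eq!(tiles.dims(), &[4, 3, 224, 224]);
    /// ```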
    fn split_to_tiles(
        &self,
        image: &Tensor,
        num_tiles_height: usize,
        num_tiles_width: usize,
    ) -> Result<Tensor> {
        let (ch, h, w) = image.dims3()?;
        let tile_height = h / num_tiles_height;
        let tile_width = w / num_tiles_width;

        let mut image = image.reshape((
            ch,
            num_tiles_height,
            tile_height,
            num_tiles_width,
            tile_width,
        ))?;

        // Permute to (num_tiles_height, num_tiles_width, num_channels, tile_height, tile_width)
        image = image.permute((1, 3, 0, 2, 4))?;

        // Reshape into the desired output shape (num_tiles_width * num_tiles_height, num_channels, tile_height, tile_width)
        image
            .reshape((
                num_tiles_width * num_tiles_height,
                ch,
                tile_height,
                tile_width,
            ))?
            .contiguous()
    }

    // https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/models/mllama/image_processing_mllama.py#L277
    /// Returns
    /// - the stacked and packed images for one sample
    /// - the number of tiles for each image in the sample.
    ///   Padding uses 0
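    ///
    /// An illustrative sketch of the shapes (`processor` and `images` are hypothetical bindings):
    ///
    /// ```ignore
    /// // Two images with 2 and 4 tiles of 3x224x224, packed with max_image_tiles = 4 and
    /// // max_num_images = 2, give a (2, 4, 3, 224, 224) tensor; unused tile slots stay zero.
    /// let (stacked, num_tiles) = processor.pack_images(images, 4, (1, 2))?;
    /// assert_eq!(stacked.dims(), &[2, 4, 3, 224, 224]);
    /// assert_eq!(num_tiles, vec![2, 4]);
    /// ```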
    fn pack_images(
        &self,
        images: Vec<Tensor>,
        max_image_tiles: usize,
        (_bs, max_num_images): (usize, usize),
    ) -> Result<(Tensor, Vec<usize>)> {
        let (_, ch, tile_h, tile_w) = images[0].dims4()?;

        let mut stacked_images = Tensor::zeros(
            (max_num_images, max_image_tiles, ch, tile_h, tile_w),
            images[0].dtype(),
            images[0].device(),
        )?;
        let mut num_sample_tiles = Vec::new();
        for (i, image) in images.into_iter().enumerate() {
            let num_tiles = image.dim(0)?;
            stacked_images = stacked_images
                .slice_assign(&[&i, &(..num_tiles), &.., &.., &..], &image.unsqueeze(0)?)?;
            num_sample_tiles.push(num_tiles)
        }
        Ok((stacked_images, num_sample_tiles))
    }

    // https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/models/mllama/image_processing_mllama.py#L354
    /// Convert aspect ratio tuples to unique ids.
    /// Padding uses 0
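    ///
    /// Ids are 1-based positions in [`Self::get_all_supported_aspect_ratios`]; unused image
    /// slots keep id 0. An illustrative sketch (`processor` is a hypothetical binding):
    ///
    /// ```ignore
    /// // With max_image_tiles = 4 the supported ratios start (1, 1), (1, 2), (1, 3), ...,
    /// // so one (1, 3) image with max_num_images = 2 maps to ids [3, 0].
    /// let ids = processor.convert_aspect_ratios_to_ids(vec![(1, 3)], 4, (1, 2), &Device::Cpu)?;
    /// assert_eq!(ids.to_vec1::<i64>()?, vec![3, 0]);
    /// ```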
    fn convert_aspect_ratios_to_ids(
        &self,
        aspect_ratios: Vec<(usize, usize)>,
        max_image_tiles: usize,
        (_bs, max_num_images): (usize, usize),
        device: &Device,
    ) -> Result<Tensor> {
        let supported_aspect_ratios = Self::get_all_supported_aspect_ratios(max_image_tiles);

        let mut aspect_ratios_ids = vec![0i64; max_num_images];
        for (i, (num_tiles_h, num_tiles_w)) in aspect_ratios.iter().enumerate() {
            aspect_ratios_ids[i] = (supported_aspect_ratios
                .iter()
                .position(|(h, w)| *h == *num_tiles_h && *w == *num_tiles_w)
                .context("Could not find aspect ratio")?
                + 1) as i64;
        }

        Tensor::new(aspect_ratios_ids, device)
    }

    fn build_aspect_ratio_mask(
        &self,
        aspect_ratios: Vec<(usize, usize)>,
        max_image_tiles: usize,
        (_bs, max_num_images): (usize, usize),
        device: &Device,
    ) -> Result<Tensor> {
        let mut aspect_ratio_mask =
            Tensor::zeros((max_num_images, max_image_tiles), DType::I64, device)?;

        // Set the first tile to 1 for all aspect ratios
        // because in the original implementation, aspect ratios are padded with (1,1)

        aspect_ratio_mask = aspect_ratio_mask.slice_assign(
            &[&.., &0],
            &Tensor::ones((max_num_images, 1), DType::I64, device)?,
        )?;

        for (i, (num_tiles_h, num_tiles_w)) in aspect_ratios.iter().enumerate() {
            aspect_ratio_mask = aspect_ratio_mask.slice_assign(
                &[&i, &(..*num_tiles_h * *num_tiles_w)],
                &Tensor::ones((1, *num_tiles_h * *num_tiles_w), DType::I64, device)?,
            )?;
        }

        Ok(aspect_ratio_mask)
    }
}

impl ImagePreProcessor for MLlamaImageProcessor {
    const DEFAULT_MEAN: [f64; 3] = [0.5, 0.5, 0.5];
    const DEFAULT_STD: [f64; 3] = [0.5, 0.5, 0.5];

    fn preprocess(
        &self,
        images: Vec<DynamicImage>,
        videos: Vec<Vec<DynamicImage>>,
        config: &PreProcessorConfig,
        device: &Device,
        (bs, max_num_images): (usize, usize),
    ) -> Result<PreprocessedImages> {
        assert!(videos.is_empty());

        let mut sample_images = Vec::new();
        let mut sample_aspect_ratios = Vec::new();
        let max_image_tiles = config
            .max_image_tiles
            .context("`do_resize=false` is not supported, need `max_image_tiles`!")?;
        *self.max_image_tiles.write().unwrap() = Some(max_image_tiles);

        for mut image in images {
            // Convert to rgb, default to true
            if config.do_convert_rgb.unwrap_or(true) {
                image = DynamicImage::ImageRgb8(image.to_rgb8());
            }

            let size = config
                .size
                .as_ref()
                .context("`do_resize=false` is not supported, need `size`!")?;

            let (image, aspect_ratio) =
                self.resize(image, size, max_image_tiles, config.resampling.to_filter()?)?;

            // In transformers they rescale from [0, 255] to [0, 1]
            // at the end of resize:
            // https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/image_transforms.py#L340
            let to_tensor_rescale = Transforms {
                input: &ToTensorNoNorm,
                inner_transforms: &[],
            };
            let mut image = image.apply(to_tensor_rescale, device)?;

            image = self.pad(&image, size, aspect_ratio)?;

            let transforms = TensorTransforms {
                inner_transforms: &[
                    &config
                        .do_rescale
                        .is_some_and(|x| x)
                        .then_some(())
                        .map(|_| Rescale {
                            factor: config.rescale_factor,
                        }),
                    &config
                        .do_normalize
                        .is_some_and(|x| x)
                        .then_some(())
                        .map(|_| Normalize {
                            mean: config.image_mean.unwrap_or(Self::DEFAULT_MEAN).to_vec(),
                            std: config.image_std.unwrap_or(Self::DEFAULT_STD).to_vec(),
                        }),
                ],
            };
            image = <Tensor as ApplyTensorTransforms>::apply(&image, transforms, device)?;

            let (num_tiles_height, num_tiles_width) = aspect_ratio;
            image = self.split_to_tiles(&image, num_tiles_height, num_tiles_width)?;

            sample_images.push(image);
            sample_aspect_ratios.push((num_tiles_height, num_tiles_width));
        }

        let (images, num_tiles) =
            self.pack_images(sample_images, max_image_tiles, (bs, max_num_images))?;

        let aspect_ratio_ids = self.convert_aspect_ratios_to_ids(
            sample_aspect_ratios.clone(),
            max_image_tiles,
            (bs, max_num_images),
            device,
        )?;
        let aspect_ratio_mask = self.build_aspect_ratio_mask(
            sample_aspect_ratios,
            max_image_tiles,
            (bs, max_num_images),
            device,
        )?;

        Ok(PreprocessedImages {
            pixel_values: images,
            pixel_attention_mask: None,
            image_sizes: None,
            num_img_tokens: None,
            aspect_ratio_ids: Some(aspect_ratio_ids),
            aspect_ratio_mask: Some(aspect_ratio_mask),
            num_tiles: Some(num_tiles),
            image_grid_thw: None,
            video_grid_thw: None,
            rows: None,
            cols: None,
            pixel_values_list: None,
            tgt_sizes: None,
            image_sizes_all: None,
            num_crops: None,
        })
    }
}