mistralrs_core/vision_models/llama4/inputs_processor.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::{
    any::Any,
    collections::{HashMap, HashSet},
    num::NonZeroUsize,
    sync::Arc,
};

use candle_core::{Context, Device, IndexOp, Result, Tensor, D};
use image::DynamicImage;
use itertools::Itertools;
use mistralrs_vision::{
    ApplyTensorTransforms, ApplyTransforms, Normalize, Rescale, TensorTransforms, ToTensorNoNorm,
    Transforms,
};
use ordered_float::NotNan;
use tokenizers::Tokenizer;
use tracing::warn;

use crate::{
    device_map::DeviceMapper,
    pipeline::{
        text_models_inputs_processor::{
            self, get_completion_input, get_prompt_input, PagedAttentionMeta,
        },
        InputProcessorOutput, InputsProcessor, InputsProcessorType, MessagesAction, Processor,
    },
    sequence::Sequence,
    vision_models::{
        image_processor::{ImagePreProcessor, PreprocessedImages},
        preprocessor_config::PreProcessorConfig,
        processor_config::ProcessorConfig,
        ModelInputs,
    },
};

use super::Llama4ModelSpecificArgs;

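// Special tokens used to expand each `<|image|>` placeholder into an explicit sequence of
// image start/end, per-patch, and tile-separator tokens.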
pub(crate) const IMAGE_TOKEN: &str = "<|image|>";
const IMAGE_START: &str = "<|image_start|>";
const IMAGE_END: &str = "<|image_end|>";
const PATCH: &str = "<|patch|>";
const TILE_X_SEP: &str = "<|tile_x_separator|>";
const TILE_Y_SEP: &str = "<|tile_y_separator|>";

// Input processor
pub struct Llama4ImageProcessor {
    pub patch_size: usize,
    pub downsample_ratio: usize,
}

impl Llama4ImageProcessor {
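    /// With the defaults (`patch_size = 14`, `pixel_shuffle_ratio = 0.5`), the downsample
    /// ratio is `round(1 / 0.5^2) = 4`, i.e. four vision patches per text token position.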
    pub fn new(patch_size: Option<usize>, pixel_shuffle_ratio: Option<f32>) -> Self {
        Self {
            patch_size: patch_size.unwrap_or(14),
            downsample_ratio: (1. / pixel_shuffle_ratio.unwrap_or(0.5).powi(2)).round() as usize,
        }
    }
}

// Processor
pub struct Llama4Processor {
    patch_size: usize,
    downsample_ratio: usize,
}

impl Llama4Processor {
    pub fn new(cfg: &ProcessorConfig) -> Self {
        Self {
            patch_size: cfg.patch_size.unwrap_or(14),
            downsample_ratio: (1. / cfg.pixel_shuffle_ratio.unwrap_or(0.5).powi(2)).round()
                as usize,
        }
    }
}

impl Processor for Llama4Processor {
    fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
        Arc::new(Llama4ImageProcessor {
            patch_size: self.patch_size,
            downsample_ratio: self.downsample_ratio,
        })
    }

    fn get_special_tokens(&self) -> &[&'static str] {
        &[
            IMAGE_START,
            IMAGE_END,
            PATCH,
            TILE_X_SEP,
            TILE_Y_SEP,
            IMAGE_TOKEN,
        ]
    }

    fn template_action(&self) -> MessagesAction {
        MessagesAction::FlattenOnlyText
    }
}

impl Llama4ImageProcessor {
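    /// Build the text expansion for one image: `<|image_start|>`, then (when the image spans
    /// more than one tile) a `<|patch|>` run per tile with `<|tile_x_separator|>` between
    /// columns and `<|tile_y_separator|>` after each row, then `<|image|>` followed by the
    /// global tile's `<|patch|>` run and `<|image_end|>`.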
    fn prompt_split_image(&self, aspect_ratio: &Tensor, num_patches_per_chunk: usize) -> String {
        let mut img_string = IMAGE_START.to_string();
        let aspect_ratio = aspect_ratio.to_vec1::<u32>().unwrap();
        let (ratio_h, ratio_w) = (aspect_ratio[0] as usize, aspect_ratio[1] as usize);
        if ratio_h * ratio_w > 1 {
            for _yy in 0..ratio_h {
                for xx in 0..ratio_w {
                    img_string.push_str(&PATCH.repeat(num_patches_per_chunk));
                    if xx < ratio_w - 1 {
                        img_string.push_str(TILE_X_SEP);
                    }
                }
                img_string.push_str(TILE_Y_SEP);
            }
        }
        img_string.push_str(IMAGE_TOKEN);
        img_string.push_str(&PATCH.repeat(num_patches_per_chunk));
        img_string.push_str(IMAGE_END);
        img_string
    }
}

impl InputsProcessor for Llama4ImageProcessor {
    fn get_type(&self) -> InputsProcessorType {
        InputsProcessorType::Vision
    }
    fn process_inputs(
        &self,
        tokenizer: Option<Arc<Tokenizer>>,
        input_seqs: &mut [&mut Sequence],
        is_prompt: bool,
        is_xlora: bool,
        device: &Device,
        no_kv_cache: bool,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        other_config: Option<Arc<dyn Any>>,
        mut paged_attn_metadata: Option<PagedAttentionMeta>,
        prompt_chunksize: Option<NonZeroUsize>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Box<dyn Iterator<Item = anyhow::Result<InputProcessorOutput>>> {
        if is_xlora {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Cannot make inputs for X-LoRA vision model.",
            ))));
        }
        if no_kv_cache {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Vision model must have kv cache.",
            ))));
        }
        // TODO(EricLBuehler): support this? Would require some handling of image tokens.
        if prompt_chunksize.is_some() {
            warn!("`prompt_chunksize` is set. Llama4 does not support prompt batching.");
        }
        let Some(tokenizer) = tokenizer else {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Llama4InputProcessor requires a specified tokenizer.",
            ))));
        };

        let config = other_config.expect("Need a PreProcessorConfig.");
        let config: &PreProcessorConfig = config.downcast_ref().expect("Downcast failed.");

        let has_images = input_seqs.iter().all(|seq| seq.has_images());

        let pixel_values = if has_images {
            let mut pixel_values_accum = Vec::new();
            let mut aspect_ratios_accum = Vec::new();

            let bs = input_seqs.len();
            let detokenized = tokenizer
                .decode_batch(
                    &input_seqs
                        .iter()
                        .map(|seq| seq.get_toks())
                        .collect::<Vec<_>>(),
                    false,
                )
                .expect("Detokenization failed!");
            let n_images_in_text = detokenized
                .iter()
                .map(|text| text.matches(IMAGE_TOKEN).count())
                .collect::<Vec<_>>();
            let n_images_in_images = input_seqs
                .iter()
                .map(|seq| seq.images().map(|imgs| imgs.len()).unwrap_or(0))
                .collect::<Vec<_>>();

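            // Every `<|image|>` tag in the decoded prompt must correspond to exactly one
            // supplied image for that sequence.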
            if n_images_in_text != n_images_in_images {
                return Box::new(std::iter::once(Err(anyhow::Error::msg(format!(
                    "The number of `<|image|>` tokens per sequence ({n_images_in_text:?}) must match the number of images supplied ({n_images_in_images:?}). Perhaps you forgot an `<|image|>` tag?"
                )))));
            }

            let max_num_images = *n_images_in_images
                .iter()
                .max()
                .expect("No max images per batch!");

            for seq in input_seqs.iter_mut() {
                let PreprocessedImages {
                    pixel_values,
                    pixel_attention_mask: _,
                    image_sizes: _,
                    num_img_tokens: _,
                    aspect_ratio_ids,
                    aspect_ratio_mask: _,
                    num_tiles: _,
                    image_grid_thw: _,
                    video_grid_thw: _,
                    rows: _,
                    cols: _,
                    pixel_values_list: _,
                    tgt_sizes: _,
                    image_sizes_all: _,
                    num_crops: _,
                } = self
                    .preprocess(
                        seq.take_images()
                            .expect("Need to have images by this point."),
                        vec![],
                        config,
                        device,
                        (bs, max_num_images), // Don't use it here...
                    )
                    .expect("Preprocessing failed");
                // Intentionally don't unsqueeze here as the BS is already included. Just stack now.
                pixel_values_accum.push(pixel_values);
                aspect_ratios_accum.push(aspect_ratio_ids.unwrap());
            }

            let pixel_values = Tensor::cat(&pixel_values_accum, 0).unwrap();
            let aspect_ratios = Tensor::cat(&aspect_ratios_accum, 0).unwrap();

            let (image_h, image_w) = (
                pixel_values.dim(D::Minus2).unwrap(),
                pixel_values.dim(D::Minus1).unwrap(),
            );
            let num_patches_per_chunk =
                (image_h / self.patch_size) * (image_w / self.patch_size) / self.downsample_ratio;

            let placeholder_counts = input_seqs
                .iter()
                .map(|seq| seq.get_initial_prompt().match_indices(IMAGE_TOKEN).count())
                .collect::<Vec<_>>();

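            // Expand each `<|image|>` placeholder in the original prompt into the full
            // tile/patch token sequence for its image, then re-tokenize the rewritten prompt.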
            let mut image_index = 0;
            for (seq, placeholder_count) in input_seqs.iter_mut().zip(placeholder_counts) {
                if placeholder_count == 0 {
                    continue;
                }
                let prompt_splits: std::str::Split<'_, &str> =
                    seq.get_initial_prompt().split(IMAGE_TOKEN);
                let mut new_prompt = Vec::new();
                for (local_image_index, split_part) in prompt_splits.enumerate() {
                    new_prompt.push(split_part.to_string());
                    if local_image_index < placeholder_count {
                        let tokens_for_this_image = self.prompt_split_image(
                            &aspect_ratios.i(image_index).unwrap(),
                            num_patches_per_chunk,
                        );
                        image_index += 1;
                        new_prompt.push(tokens_for_this_image);
                    }
                }
                let prompt = new_prompt.join("");

                if !seq.multimodal.has_changed_prompt {
                    seq.set_initial_prompt(prompt.clone());
                    let toks = tokenizer
                        .encode_fast(prompt, false)
                        .expect("Tokenization failed!");

                    let ids = toks.get_ids().to_vec();
                    seq.set_toks_and_reallocate(ids, paged_attn_metadata.as_mut());
                    seq.multimodal.has_changed_prompt = true;
                }
            }

            Some(pixel_values)
        } else {
            None
        };

        let text_models_inputs_processor::InnerInputProcessorOutput {
            inputs:
                text_models_inputs_processor::InputMetadata {
                    input,
                    positions,
                    context_lens,
                    position_ids,
                    paged_attn_meta,
                    flash_meta,
                },
            seq_indices,
        } = if is_prompt {
            get_prompt_input(
                input_seqs
                    .iter()
                    .map(|seq| seq.get_toks())
                    .collect::<Vec<_>>(),
                input_seqs,
                device,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                None, // TODO: evaluate if it is possible to batch this
                mapper,
            )
            .nth(0)
            .unwrap()
            .unwrap()
        } else {
            get_completion_input(
                input_seqs
                    .iter()
                    .map(|seq| seq.get_toks())
                    .collect::<Vec<_>>(),
                input_seqs,
                device,
                no_kv_cache,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                None, // TODO: evaluate if it is possible to batch this
                mapper,
            )
            .nth(0)
            .unwrap()
            .unwrap()
        };

        let inputs: Box<dyn Any> = Box::new(ModelInputs {
            input_ids: input,
            seqlen_offsets: positions,
            context_lens,
            position_ids,
            pixel_values,
            model_specific_args: Box::new(Llama4ModelSpecificArgs),
            paged_attn_meta,
            flash_meta,
        });
        Box::new(std::iter::once(Ok(InputProcessorOutput {
            inputs,
            seq_indices,
        })))
    }
}

impl Llama4ImageProcessor {
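    /// Collect all factors of `dividend` (both members of each divisor pair).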
    fn get_factors(dividend: u32) -> HashSet<u32> {
        let mut factors_set = HashSet::new();

        let sqrt = (dividend as f64).sqrt() as u32;
        for i in 1..=sqrt {
            if dividend % i == 0 {
                factors_set.insert(i);
                factors_set.insert(dividend / i);
            }
        }

        factors_set
    }

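    /// Enumerate the tile grids supported by the vision tower: for every chunk count up to
    /// `max_num_chunks`, each (rows, cols) factorization is scaled by the square `size`
    /// edge length to produce a candidate (height, width) resolution.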
    fn find_supported_resolutions(
        &self,
        max_num_chunks: usize,
        size: &HashMap<String, u32>,
    ) -> Result<Vec<(u32, u32)>> {
        let height = size["height"];
        let width = size["width"];
        if height != width {
            candle_core::bail!("Expected config size height==width ({height}!={width})");
        }

        let patch_size = height;

        let mut asp_map = HashMap::new();
        for chunk_size in (1..=max_num_chunks).rev() {
            let factors = Self::get_factors(chunk_size as u32);
            let asp_ratios = factors
                .into_iter()
                .sorted()
                .map(|factor| (factor, chunk_size as u32 / factor));
            for (h, w) in asp_ratios {
                let ratio_float = h as f32 / w as f32;
                asp_map
                    .entry(NotNan::new(ratio_float).context("f32 is NaN")?)
                    .or_insert_with(Vec::new)
                    .push((h, w));
            }
        }

        // Get the resolutions multiplied by the patch size
        let possible_resolutions = asp_map
            .into_values()
            .flatten()
            .map(|(height, width)| (height * patch_size, width * patch_size))
            .collect::<Vec<_>>();

        Ok(possible_resolutions)
    }

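    /// Group images by (height, width) so identically sized images can be processed as one
    /// stacked tensor; the returned index maps each original image to its (shape, position)
    /// within the group.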
    #[allow(clippy::type_complexity)]
    fn group_images_by_shape(
        &self,
        images: &[Tensor],
    ) -> Result<(
        HashMap<(usize, usize), Tensor>,
        HashMap<usize, ((usize, usize), usize)>,
    )> {
        let mut grouped_images = HashMap::new();
        let mut grouped_images_index = HashMap::new();
        for (i, image) in images.iter().enumerate() {
            let (_c, h, w) = image.dims3()?;
            let shape = (h, w);
            grouped_images
                .entry(shape)
                .or_insert_with(Vec::new)
                .push(image.clone());
            grouped_images_index.insert(i, (shape, grouped_images[&shape].len() - 1));
        }
        // Stack images with the same shape
        let mut grouped_images_stack = HashMap::new();
        for (shape, images) in grouped_images {
            grouped_images_stack.insert(shape, Tensor::stack(&images, 0)?);
        }

        Ok((grouped_images_stack, grouped_images_index))
    }

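    /// Choose the target canvas from `possible_resolutions`: prefer a resolution the image
    /// can be scaled up to (the smallest such scale, or the largest when
    /// `resize_to_max_canvas`), otherwise the least-downscaling option; ties are broken by
    /// minimum area to reduce padding.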
    fn get_best_fit(
        &self,
        (original_height, original_width): (u32, u32),
        possible_resolutions: Vec<(u32, u32)>,
        resize_to_max_canvas: bool,
    ) -> Result<(u32, u32)> {
        // All possible resolution heights and widths
        let (target_heights, target_widths): (Vec<u32>, Vec<u32>) =
            possible_resolutions.iter().copied().unzip();

        // Scaling factors to resize image without distortion
        let scale_w = target_widths
            .iter()
            .map(|tw| *tw as f32 / original_width as f32);
        let scale_h = target_heights
            .iter()
            .map(|th| *th as f32 / original_height as f32);

        // Min scale between w and h (limiting size -> no distortion)
        let scales = scale_w.zip(scale_h).map(|(w, h)| if h > w { w } else { h });

        // Filter only scales that allow upscaling
        let upscaling_options = scales
            .clone()
            .filter(|s| *s >= 1.)
            .map(|x| NotNan::new(x).unwrap())
            .collect::<Vec<_>>();
        let downscaling_options = scales
            .clone()
            .filter(|s| *s < 1.)
            .map(|x| NotNan::new(x).unwrap())
            .collect::<Vec<_>>();
        let selected_scale = if !upscaling_options.is_empty() {
            if resize_to_max_canvas {
                upscaling_options.into_iter().max().unwrap().into_inner()
            } else {
                upscaling_options.into_iter().min().unwrap().into_inner()
            }
        } else {
            // No upscaling; get min downscaling (max scale for scales < 1)
            downscaling_options.into_iter().max().unwrap().into_inner()
        };

        // All resolutions that support this scaling factor
        // Ex. can upscale 224x224, 224x448, 224x672 without distortion
        // If there are multiple resolutions, get the one with minimum area to reduce padding.
        // Sort by increasing area and take the first.
        let chosen_canvas = possible_resolutions
            .into_iter()
            .zip(scales)
            .filter_map(|(possible, scale)| {
                if scale == selected_scale {
                    Some(possible)
                } else {
                    None
                }
            })
            .sorted_by_key(|(h, w)| h * w)
            .take(1)
            .collect::<Vec<_>>()[0];

        Ok(chosen_canvas)
    }

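    /// Largest size that fits inside `target_size` while preserving the aspect ratio of
    /// `image_size`; the limiting dimension determines the scale, so no distortion occurs.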
    fn get_max_res_without_distortion(
        &self,
        image_size: (u32, u32),
        target_size: (u32, u32),
    ) -> (u32, u32) {
        let (original_height, original_width) = image_size;
        let (target_height, target_width) = target_size;

        let scale_w = target_width as f64 / original_width as f64;
        let scale_h = target_height as f64 / original_height as f64;

        if scale_w < scale_h {
            let new_width = target_width;
            // Calculate new height and ensure it doesn't exceed target_height
            let new_height = std::cmp::min(
                (original_height as f64 * scale_w).floor() as u32,
                target_height,
            );
            (new_height, new_width)
        } else {
            let new_height = target_height;
            // Calculate new width and ensure it doesn't exceed target_width
            let new_width = std::cmp::min(
                (original_width as f64 * scale_h).floor() as u32,
                target_width,
            );
            (new_height, new_width)
        }
    }

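    /// Split a padded (bs, c, h, w) batch into equal tiles, returning shape
    /// (bs, num_tiles_h * num_tiles_w, c, h / num_tiles_h, w / num_tiles_w).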
    fn split_to_tiles(
        &self,
        images: &Tensor,
        num_tiles_h: usize,
        num_tiles_w: usize,
    ) -> Result<Tensor> {
        let (bs, c, h, w) = images.dims4()?;
        let mut images = images.reshape((
            bs,
            c,
            num_tiles_h,
            h / num_tiles_h,
            num_tiles_w,
            w / num_tiles_w,
        ))?;
        images = images.permute((0, 2, 4, 1, 3, 5))?.contiguous()?;
        images.reshape((
            bs,
            num_tiles_h * num_tiles_w,
            c,
            h / num_tiles_h,
            w / num_tiles_w,
        ))
    }

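    /// Map the grouped, processed tensors back to one tensor per original input image using
    /// the recorded (shape, position-within-group) index.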
    fn reorder_images(
        &self,
        processed_images: HashMap<(usize, usize), Tensor>,
        grouped_images_index: HashMap<usize, ((usize, usize), usize)>,
    ) -> Result<Vec<Tensor>> {
        // `HashMap::values()` iterates in arbitrary order, so sort by the original image
        // index to keep the outputs aligned with the input images.
        grouped_images_index
            .iter()
            .sorted_by_key(|(i, _)| *i)
            .map(|(_, (k, v))| processed_images[k].i(*v))
            .collect::<Result<Vec<Tensor>>>()
    }
}

impl ImagePreProcessor for Llama4ImageProcessor {
    const DEFAULT_MEAN: [f64; 3] = [0.5, 0.5, 0.5];
    const DEFAULT_STD: [f64; 3] = [0.5, 0.5, 0.5];

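    /// Full Llama 4 image pipeline: for each group of same-shaped images, pick a best-fit
    /// canvas, resize without distortion, pad, rescale/normalize, split into tiles, and
    /// append a downsized global tile when more than one tile is produced. Returns the
    /// stacked pixel values and the per-image (ratio_h, ratio_w) aspect-ratio ids consumed
    /// by `prompt_split_image`.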
    fn preprocess(
        &self,
        images_d: Vec<DynamicImage>,
        videos: Vec<Vec<DynamicImage>>,
        config: &PreProcessorConfig,
        device: &Device,
        (_bs, _max_num_images): (usize, usize),
    ) -> Result<PreprocessedImages> {
        assert!(videos.is_empty());

        let max_patches = config.max_patches.unwrap_or(16);
        let size = config.size.clone().unwrap_or(HashMap::from_iter([
            ("height".to_string(), 336),
            ("width".to_string(), 336),
        ]));
        let resize_to_max_canvas = config.resize_to_max_canvas.unwrap_or(false);
        let do_rescale = config.do_rescale.unwrap_or(true);
        let do_normalize = config.do_normalize.unwrap_or(true);

        let possible_resolutions = self.find_supported_resolutions(max_patches, &size)?;

        let mut images = Vec::new();
        for mut image in images_d {
            // Convert to rgb, default to true
            if config.do_convert_rgb.unwrap_or(true) {
                image = DynamicImage::ImageRgb8(image.to_rgb8());
            }

            let to_tensor_rescale = Transforms {
                input: &ToTensorNoNorm,
                inner_transforms: &[],
            };
            let image = image.apply(to_tensor_rescale, device)?;
            images.push(image);
        }

        let (grouped_images, grouped_images_index) = self.group_images_by_shape(&images)?;

        let mut grouped_processed_images = HashMap::new();
        let mut grouped_aspect_ratios = HashMap::new();
        for (shape, stacked_images) in grouped_images {
            let image_size = (
                stacked_images.dim(D::Minus2)? as u32,
                stacked_images.dim(D::Minus1)? as u32,
            );
            let target_size = self.get_best_fit(
                image_size,
                possible_resolutions.clone(),
                resize_to_max_canvas,
            )?;
            // If target_size requires upscaling, we might want to limit the upscaling to max_upscaling_size
            let max_upscaling_size = if resize_to_max_canvas {
                None
            } else {
                Some(size["height"])
            };
            let target_size_without_distortion =
                if let Some(max_upscaling_size) = max_upscaling_size {
                    let nt_h = image_size.0.max(max_upscaling_size).min(target_size.0);
                    let nt_w = image_size.1.max(max_upscaling_size).min(target_size.1);
                    (nt_h, nt_w)
                } else {
                    candle_core::bail!("Currently `resize_to_max_canvas == false` is assumed!");
                };

            // Resize to target_size while preserving aspect ratio
            let new_size_without_distortion =
                self.get_max_res_without_distortion(image_size, target_size_without_distortion);
            let mut processed_images = stacked_images.interpolate2d(
                new_size_without_distortion.0.max(1) as usize,
                new_size_without_distortion.1.max(1) as usize,
            )?;

            // Pad to target_size to be able to split into tiles
            processed_images = {
                let (target_h, target_w) = target_size;
                let (h, w) = (
                    processed_images.dim(D::Minus2)?,
                    processed_images.dim(D::Minus1)?,
                );
                let paste_x_r = target_w as usize - w;
                let paste_y_r = target_h as usize - h;
                processed_images
                    .pad_with_zeros(D::Minus2, 0, paste_y_r)?
                    .pad_with_zeros(D::Minus1, 0, paste_x_r)?
            };

            let rescale_and_norm_transforms = TensorTransforms {
                inner_transforms: &[
                    &do_rescale.then_some(Rescale {
                        factor: config.rescale_factor,
                    }),
                    &do_normalize.then_some(Normalize {
                        mean: config.image_mean.unwrap_or(Self::DEFAULT_MEAN).to_vec(),
                        std: config.image_std.unwrap_or(Self::DEFAULT_STD).to_vec(),
                    }),
                ],
            };
            processed_images = <Tensor as ApplyTensorTransforms>::apply(
                &processed_images,
                rescale_and_norm_transforms,
                device,
            )?;

            let (ratio_h, ratio_w) = (
                target_size.0 / size["height"],
                target_size.1 / size["width"],
            );
            // Split into tiles
            processed_images =
                self.split_to_tiles(&processed_images, ratio_h as usize, ratio_w as usize)?;
            grouped_processed_images.insert(shape, processed_images.clone());
            grouped_aspect_ratios.insert(
                shape,
                Tensor::new(vec![vec![ratio_h, ratio_w]; stacked_images.dim(0)?], device)?,
            );

            // Add a global tile to the processed tiles if there is more than one tile
            if ratio_h * ratio_w > 1 {
                let mut global_tiles = stacked_images
                    .interpolate2d(size["height"] as usize, size["width"] as usize)?;
                global_tiles = <Tensor as ApplyTensorTransforms>::apply(
                    &global_tiles,
                    rescale_and_norm_transforms,
                    device,
                )?;
                grouped_processed_images.insert(
                    shape,
                    Tensor::cat(&[processed_images, global_tiles.unsqueeze(1)?], 1)?,
                );
            }
        }

        let processed_images =
            self.reorder_images(grouped_processed_images, grouped_images_index.clone())?;
        let aspect_ratios_list =
            self.reorder_images(grouped_aspect_ratios, grouped_images_index.clone())?;

        let processed_images = Tensor::cat(&processed_images, 0)?;
        let aspect_ratios = Tensor::stack(&aspect_ratios_list, 0)?;

        Ok(PreprocessedImages {
            pixel_values: processed_images,
            pixel_attention_mask: None,
            image_sizes: None,
            num_img_tokens: None,
            aspect_ratio_ids: Some(aspect_ratios),
            aspect_ratio_mask: None,
            num_tiles: None,
            image_grid_thw: None,
            video_grid_thw: None,
            rows: None,
            cols: None,
            pixel_values_list: None,
            tgt_sizes: None,
            image_sizes_all: None,
            num_crops: None,
        })
    }
}