mistralrs_core/vision_models/llama4/inputs_processor.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::{
    any::Any,
    collections::{HashMap, HashSet},
    num::NonZeroUsize,
    sync::Arc,
};

use candle_core::{Context, Device, IndexOp, Result, Tensor, D};
use image::DynamicImage;
use itertools::Itertools;
use mistralrs_vision::{
    ApplyTensorTransforms, ApplyTransforms, Normalize, Rescale, TensorTransforms, ToTensorNoNorm,
    Transforms,
};
use ordered_float::NotNan;
use tokenizers::Tokenizer;
use tracing::warn;

use crate::{
    device_map::DeviceMapper,
    pipeline::{
        text_models_inputs_processor::{
            self, get_completion_input, get_prompt_input, PagedAttentionMeta,
        },
        InputProcessorOutput, InputsProcessor, InputsProcessorType, MessagesAction, Processor,
    },
    sequence::Sequence,
    vision_models::{
        image_processor::{ImagePreProcessor, PreprocessedImages},
        preprocessor_config::PreProcessorConfig,
        processor_config::ProcessorConfig,
        ModelInputs,
    },
};

use super::Llama4ModelSpecificArgs;

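// Special tokens for Llama 4 vision prompts: the `<|image|>` placeholder plus the
// start/end, patch, and tile-separator tokens used when that placeholder is expanded.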
pub(crate) const IMAGE_TOKEN: &str = "<|image|>";
const IMAGE_START: &str = "<|image_start|>";
const IMAGE_END: &str = "<|image_end|>";
const PATCH: &str = "<|patch|>";
const TILE_X_SEP: &str = "<|tile_x_separator|>";
const TILE_Y_SEP: &str = "<|tile_y_separator|>";

// Input processor
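/// Image inputs processor for Llama 4: converts images into pixel-value tensors
/// and expands `<|image|>` prompt placeholders into patch/tile token sequences.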
pub struct Llama4ImageProcessor {
    pub patch_size: usize,
    pub downsample_ratio: usize,
}

impl Llama4ImageProcessor {
    pub fn new(patch_size: Option<usize>, pixel_shuffle_ratio: Option<f32>) -> Self {
        Self {
            patch_size: patch_size.unwrap_or(14),
            // E.g. the default pixel_shuffle_ratio of 0.5 gives round(1 / 0.25) = 4.
            downsample_ratio: (1. / pixel_shuffle_ratio.unwrap_or(0.5).powi(2)).round() as usize,
        }
    }
}

// Processor
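/// Llama 4 processor: holds the patch size and pixel-shuffle downsample ratio from
/// the `ProcessorConfig` and builds the corresponding [`Llama4ImageProcessor`].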
pub struct Llama4Processor {
    patch_size: usize,
    downsample_ratio: usize,
}

impl Llama4Processor {
    pub fn new(cfg: &ProcessorConfig) -> Self {
        Self {
            patch_size: cfg.patch_size.unwrap_or(14),
            downsample_ratio: (1. / cfg.pixel_shuffle_ratio.unwrap_or(0.5).powi(2)).round()
                as usize,
        }
    }
}

impl Processor for Llama4Processor {
    fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
        Arc::new(Llama4ImageProcessor {
            patch_size: self.patch_size,
            downsample_ratio: self.downsample_ratio,
        })
    }

    fn get_special_tokens(&self) -> &[&'static str] {
        &[
            IMAGE_START,
            IMAGE_END,
            PATCH,
            TILE_X_SEP,
            TILE_Y_SEP,
            IMAGE_TOKEN,
        ]
    }

    fn template_action(&self) -> MessagesAction {
        MessagesAction::FlattenOnlyText
    }
}

impl Llama4ImageProcessor {
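    /// Builds the token string that replaces one `<|image|>` placeholder:
    /// `<|image_start|>`, then (for multi-tile images) `num_patches_per_chunk`
    /// `<|patch|>` tokens per tile with `<|tile_x_separator|>`/`<|tile_y_separator|>`
    /// between tiles, then `<|image|>` followed by the global-image patches and
    /// `<|image_end|>`.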
    fn prompt_split_image(&self, aspect_ratio: &Tensor, num_patches_per_chunk: usize) -> String {
        let mut img_string = IMAGE_START.to_string();
        let aspect_ratio = aspect_ratio.to_vec1::<u32>().unwrap();
        let (ratio_h, ratio_w) = (aspect_ratio[0] as usize, aspect_ratio[1] as usize);
        if ratio_h * ratio_w > 1 {
            for _yy in 0..ratio_h {
                for xx in 0..ratio_w {
                    img_string.push_str(&PATCH.repeat(num_patches_per_chunk));
                    if xx < ratio_w - 1 {
                        img_string.push_str(TILE_X_SEP);
                    }
                }
                img_string.push_str(TILE_Y_SEP);
            }
        }
        img_string.push_str(IMAGE_TOKEN);
        img_string.push_str(&PATCH.repeat(num_patches_per_chunk));
        img_string.push_str(IMAGE_END);
        img_string
    }
}

impl InputsProcessor for Llama4ImageProcessor {
    fn get_type(&self) -> InputsProcessorType {
        InputsProcessorType::Vision
    }
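    /// Builds the model inputs for a batch of sequences. When images are present,
    /// preprocesses them into pixel values, rewrites each prompt by expanding its
    /// `<|image|>` placeholders, and re-tokenizes the sequences before producing
    /// the usual text-model input metadata.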
    fn process_inputs(
        &self,
        tokenizer: Option<Arc<Tokenizer>>,
        input_seqs: &mut [&mut Sequence],
        is_prompt: bool,
        is_xlora: bool,
        device: &Device,
        no_kv_cache: bool,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        other_config: Option<Arc<dyn Any>>,
        mut paged_attn_metadata: Option<PagedAttentionMeta<'_>>,
        prompt_chunksize: Option<NonZeroUsize>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Box<dyn Iterator<Item = anyhow::Result<InputProcessorOutput>>> {
        if is_xlora {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Cannot make inputs for X-LoRA vision model.",
            ))));
        }
        if no_kv_cache {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Vision model must have kv cache.",
            ))));
        }
        // TODO(EricLBuehler): support this? Would require some handling of image tokens.
        if prompt_chunksize.is_some() {
            warn!("`prompt_chunksize` is set. Llama4 does not support prompt batching.");
        }
        let Some(tokenizer) = tokenizer else {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Llama4InputProcessor requires a specified tokenizer.",
            ))));
        };

        let config = other_config.expect("Need a PreProcessorConfig config.");
        let config: &PreProcessorConfig = config.downcast_ref().expect("Downcast failed.");

        let has_images = input_seqs.iter().all(|seq| seq.has_images());

        let pixel_values = if has_images {
            let mut pixel_values_accum = Vec::new();
            let mut aspect_ratios_accum = Vec::new();

            let bs = input_seqs.len();
            let detokenized = tokenizer
                .decode_batch(
                    &input_seqs
                        .iter()
                        .map(|seq| seq.get_toks())
                        .collect::<Vec<_>>(),
                    false,
                )
                .expect("Detokenization failed!");
            let n_images_in_text = detokenized
                .iter()
                .map(|text| text.matches(IMAGE_TOKEN).count())
                .collect::<Vec<_>>();
            let n_images_in_images = input_seqs
                .iter()
                .map(|seq| seq.images().map(|imgs| imgs.len()).unwrap_or(0))
                .collect::<Vec<_>>();

            if n_images_in_text != n_images_in_images {
                return Box::new(std::iter::once(Err(anyhow::Error::msg(format!(
                    "The number of `<|image|>` tokens per prompt ({n_images_in_text:?}) should match the number of images provided ({n_images_in_images:?}). The model cannot handle a mismatched number of images per prompt. Perhaps you forgot an `<|image|>` tag?"
                )))));
            }

            let max_num_images = *n_images_in_images
                .iter()
                .max()
                .expect("No max images per batch!");

            for seq in input_seqs.iter_mut() {
                let PreprocessedImages {
                    pixel_values,
                    pixel_attention_mask: _,
                    image_sizes: _,
                    num_img_tokens: _,
                    aspect_ratio_ids,
                    aspect_ratio_mask: _,
                    num_tiles: _,
                    image_grid_thw: _,
                    video_grid_thw: _,
                    rows: _,
                    cols: _,
                    pixel_values_list: _,
                    tgt_sizes: _,
                    image_sizes_all: _,
                    num_crops: _,
                } = self
                    .preprocess(
                        seq.take_images()
                            .expect("Need to have images by this point."),
                        vec![],
                        config,
                        device,
                        (bs, max_num_images), // Unused by this image processor's `preprocess`.
                    )
                    .expect("Preprocessing failed");
                // Intentionally don't unsqueeze here as the batch dim is already included; just concatenate below.
                pixel_values_accum.push(pixel_values);
                aspect_ratios_accum.push(aspect_ratio_ids.unwrap());
            }

            let pixel_values = Tensor::cat(&pixel_values_accum, 0).unwrap();
            let aspect_ratios = Tensor::cat(&aspect_ratios_accum, 0).unwrap();

            let (image_h, image_w) = (
                pixel_values.dim(D::Minus2).unwrap(),
                pixel_values.dim(D::Minus1).unwrap(),
            );
            let num_patches_per_chunk =
                (image_h / self.patch_size) * (image_w / self.patch_size) / self.downsample_ratio;

            let placeholder_counts = input_seqs
                .iter()
                .map(|seq| seq.get_initial_prompt().match_indices(IMAGE_TOKEN).count())
                .collect::<Vec<_>>();

            let mut image_index = 0;
            for (seq, placeholder_count) in input_seqs.iter_mut().zip(placeholder_counts) {
                if placeholder_count == 0 {
                    continue;
                }
                let prompt_splits: std::str::Split<'_, &str> =
                    seq.get_initial_prompt().split(IMAGE_TOKEN);
                let mut new_prompt = Vec::new();
                for (local_image_index, split_part) in prompt_splits.enumerate() {
                    new_prompt.push(split_part.to_string());
                    if local_image_index < placeholder_count {
                        let tokens_for_this_image = self.prompt_split_image(
                            &aspect_ratios.i(image_index).unwrap(),
                            num_patches_per_chunk,
                        );
                        image_index += 1;
                        new_prompt.push(tokens_for_this_image);
                    }
                }
                let prompt = new_prompt.join("");

                seq.set_initial_prompt(prompt.clone());
                let toks = tokenizer
                    .encode_fast(prompt, false)
                    .expect("Tokenization failed!");

                let ids = toks.get_ids().to_vec();
                seq.set_toks_and_reallocate(ids, paged_attn_metadata.as_mut());
            }

            Some(pixel_values)
        } else {
            None
        };

        let text_models_inputs_processor::InnerInputProcessorOutput {
            inputs:
                text_models_inputs_processor::InputMetadata {
                    input,
                    positions,
                    context_lens,
                    position_ids,
                    paged_attn_meta,
                    flash_meta,
                },
            seq_indices,
        } = if is_prompt {
            get_prompt_input(
                input_seqs
                    .iter()
                    .map(|seq| seq.get_toks().to_vec())
                    .collect::<Vec<_>>(),
                input_seqs,
                device,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                None, // TODO: evaluate if it is possible to batch this
                mapper,
            )
            .nth(0)
            .unwrap()
            .unwrap()
        } else {
            get_completion_input(
                input_seqs
                    .iter()
                    .map(|seq| seq.get_toks().to_vec())
                    .collect::<Vec<_>>(),
                input_seqs,
                device,
                no_kv_cache,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                None, // TODO: evaluate if it is possible to batch this
                mapper,
            )
            .nth(0)
            .unwrap()
            .unwrap()
        };

        let inputs: Box<dyn Any> = Box::new(ModelInputs {
            input_ids: input,
            seqlen_offsets: positions,
            context_lens,
            position_ids,
            pixel_values,
            model_specific_args: Box::new(Llama4ModelSpecificArgs),
            paged_attn_meta,
            flash_meta,
        });
        Box::new(std::iter::once(Ok(InputProcessorOutput {
            inputs,
            seq_indices,
        })))
    }
}

impl Llama4ImageProcessor {
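    /// Returns all positive divisors of `dividend`,
    /// e.g. `get_factors(12)` yields `{1, 2, 3, 4, 6, 12}`.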
    fn get_factors(dividend: u32) -> HashSet<u32> {
        let mut factors_set = HashSet::new();

        let sqrt = (dividend as f64).sqrt() as u32;
        for i in 1..=sqrt {
            if dividend % i == 0 {
                factors_set.insert(i);
                factors_set.insert(dividend / i);
            }
        }

        factors_set
    }

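    /// Enumerates candidate canvas resolutions: for each possible chunk count,
    /// every factor pair (h, w) becomes a tile grid, which is then scaled by the
    /// configured square tile size to give a (height, width) resolution.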
    fn find_supported_resolutions(
        &self,
        max_num_chunks: usize,
        size: &HashMap<String, u32>,
    ) -> Result<Vec<(u32, u32)>> {
        let height = size["height"];
        let width = size["width"];
        if height != width {
            candle_core::bail!("Expected config size height==width ({height}!={width})");
        }

        let patch_size = height;

        let mut asp_map = HashMap::new();
        for chunk_size in (0..max_num_chunks).rev() {
            let factors = Self::get_factors(chunk_size as u32);
            let asp_ratios = factors
                .into_iter()
                .sorted()
                .map(|factor| (factor, chunk_size as u32 / factor));
            for (h, w) in asp_ratios {
                let ratio_float = h as f32 / w as f32;
                asp_map
                    .entry(NotNan::new(ratio_float).context("f32 is NaN")?)
                    .or_insert_with(Vec::new)
                    .push((h, w));
            }
        }

        // Get the resolutions multiplied by the patch size
        let possible_resolutions = asp_map
            .into_values()
            .flatten()
            .map(|(height, width)| (height * patch_size, width * patch_size))
            .collect::<Vec<_>>();

        Ok(possible_resolutions)
    }

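    /// Groups images by (height, width) so same-shaped images can be processed as
    /// one batched tensor. Returns the stacked tensor per shape plus, for each
    /// original image index, its shape key and position within that stack.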
    #[allow(clippy::type_complexity)]
    fn group_images_by_shape(
        &self,
        images: &[Tensor],
    ) -> Result<(
        HashMap<(usize, usize), Tensor>,
        HashMap<usize, ((usize, usize), usize)>,
    )> {
        let mut grouped_images = HashMap::new();
        let mut grouped_images_index = HashMap::new();
        for (i, image) in images.iter().enumerate() {
            let (_c, h, w) = image.dims3()?;
            let shape = (h, w);
            grouped_images
                .entry(shape)
                .or_insert_with(Vec::new)
                .push(image.clone());
            grouped_images_index.insert(i, (shape, grouped_images[&shape].len() - 1));
        }
        // Stack images with the same shape
        let mut grouped_images_stack = HashMap::new();
        for (shape, images) in grouped_images {
            grouped_images_stack.insert(shape, Tensor::stack(&images, 0)?);
        }

        Ok((grouped_images_stack, grouped_images_index))
    }

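    /// Chooses the target canvas for an image: computes the aspect-preserving scale
    /// for each candidate resolution, prefers the smallest upscale (or the largest
    /// if `resize_to_max_canvas` is set), falls back to the mildest downscale, and
    /// among matching resolutions picks the smallest area to minimize padding.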
    fn get_best_fit(
        &self,
        (original_height, original_width): (u32, u32),
        possible_resolutions: Vec<(u32, u32)>,
        resize_to_max_canvas: bool,
    ) -> Result<(u32, u32)> {
        // All possible resolution heights/widths
        let (target_heights, target_widths): (Vec<u32>, Vec<u32>) =
            possible_resolutions.iter().copied().unzip();

        // Scaling factors to resize image without distortion
        let scale_w = target_widths
            .iter()
            .map(|tw| *tw as f32 / original_width as f32);
        let scale_h = target_heights
            .iter()
            .map(|th| *th as f32 / original_height as f32);

        // Min scale between w and h (limiting size -> no distortion)
        let scales = scale_w.zip(scale_h).map(|(w, h)| if h > w { w } else { h });

        // Filter only scales that allow upscaling
        let upscaling_options = scales
            .clone()
            .filter(|s| *s >= 1.)
            .map(|x| NotNan::new(x).unwrap())
            .collect::<Vec<_>>();
        let downscaling_options = scales
            .clone()
            .filter(|s| *s < 1.)
            .map(|x| NotNan::new(x).unwrap())
            .collect::<Vec<_>>();
        let selected_scale = if !upscaling_options.is_empty() {
            if resize_to_max_canvas {
                upscaling_options.into_iter().max().unwrap().into_inner()
            } else {
                upscaling_options.into_iter().min().unwrap().into_inner()
            }
        } else {
            // No upscaling; get min downscaling (max scale for scales < 1)
            downscaling_options.into_iter().max().unwrap().into_inner()
        };

        // All resolutions that support this scaling factor
        // Ex. can upscale 224x224, 224x448, 224x672 without distortion
        // If there are multiple resolutions, get the one with minimum area to reduce padding.
        // Sort by increasing areas and take 1.
        let chosen_canvas = possible_resolutions
            .into_iter()
            .zip(scales)
            .filter_map(|(possible, scale)| {
                if scale == selected_scale {
                    Some(possible)
                } else {
                    None
                }
            })
            .sorted_by_key(|(h, w)| h * w)
            .take(1)
            .collect::<Vec<_>>()[0];

        Ok(chosen_canvas)
    }

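    /// Scales `image_size` to fit inside `target_size` while preserving the aspect
    /// ratio, returning the largest (height, width) that neither distorts the image
    /// nor exceeds the target.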
    fn get_max_res_without_distortion(
        &self,
        image_size: (u32, u32),
        target_size: (u32, u32),
    ) -> (u32, u32) {
        let (original_height, original_width) = image_size;
        let (target_height, target_width) = target_size;

        let scale_w = target_width as f64 / original_width as f64;
        let scale_h = target_height as f64 / original_height as f64;

        if scale_w < scale_h {
            let new_width = target_width;
            // Calculate new height and ensure it doesn't exceed target_height
            let new_height = std::cmp::min(
                (original_height as f64 * scale_w).floor() as u32,
                target_height,
            );
            (new_height, new_width)
        } else {
            let new_height = target_height;
            // Calculate new width and ensure it doesn't exceed target_width
            let new_width = std::cmp::min(
                (original_width as f64 * scale_h).floor() as u32,
                target_width,
            );
            (new_height, new_width)
        }
    }

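    /// Splits a (batch, channels, height, width) image tensor into a
    /// (batch, num_tiles_h * num_tiles_w, channels, tile_height, tile_width)
    /// tensor of non-overlapping tiles.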
    fn split_to_tiles(
        &self,
        images: &Tensor,
        num_tiles_h: usize,
        num_tiles_w: usize,
    ) -> Result<Tensor> {
        let (bs, c, h, w) = images.dims4()?;
        let mut images = images.reshape((
            bs,
            c,
            num_tiles_h,
            h / num_tiles_h,
            num_tiles_w,
            w / num_tiles_w,
        ))?;
        images = images.permute((0, 2, 4, 1, 3, 5))?.contiguous()?;
        images.reshape((
            bs,
            num_tiles_h * num_tiles_w,
            c,
            h / num_tiles_h,
            w / num_tiles_w,
        ))
    }

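    /// Maps grouped, processed tensors back to one tensor per image, in the
    /// original input order recorded in `grouped_images_index`.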
    fn reorder_images(
        &self,
        processed_images: HashMap<(usize, usize), Tensor>,
        grouped_images_index: HashMap<usize, ((usize, usize), usize)>,
    ) -> Result<Vec<Tensor>> {
        // Iterate in original image order; `HashMap::values()` would yield an
        // arbitrary order and scramble images relative to their placeholders.
        (0..grouped_images_index.len())
            .map(|i| {
                let (shape, idx) = grouped_images_index[&i];
                processed_images[&shape].i(idx)
            })
            .collect::<Result<Vec<Tensor>>>()
    }
}

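// Image preprocessing: choose a best-fit canvas per image shape, resize without
// distortion, pad, rescale/normalize, split into tiles, and add a global tile
// when an image spans more than one tile.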
impl ImagePreProcessor for Llama4ImageProcessor {
    const DEFAULT_MEAN: [f64; 3] = [0.5, 0.5, 0.5];
    const DEFAULT_STD: [f64; 3] = [0.5, 0.5, 0.5];

    fn preprocess(
        &self,
        images_d: Vec<DynamicImage>,
        videos: Vec<Vec<DynamicImage>>,
        config: &PreProcessorConfig,
        device: &Device,
        (_bs, _max_num_images): (usize, usize),
    ) -> Result<PreprocessedImages> {
        assert!(videos.is_empty());

        let max_patches = config.max_patches.unwrap_or(16);
        let size = config.size.clone().unwrap_or(HashMap::from_iter([
            ("height".to_string(), 336),
            ("width".to_string(), 336),
        ]));
        let resize_to_max_canvas = config.resize_to_max_canvas.unwrap_or(false);
        let do_rescale = config.do_rescale.unwrap_or(true);
        let do_normalize = config.do_normalize.unwrap_or(true);

        let possible_resolutions = self.find_supported_resolutions(max_patches, &size)?;

        let mut images = Vec::new();
        for mut image in images_d {
            // Convert to rgb, default to true
            if config.do_convert_rgb.unwrap_or(true) {
                image = DynamicImage::ImageRgb8(image.to_rgb8());
            }

            let to_tensor_rescale = Transforms {
                input: &ToTensorNoNorm,
                inner_transforms: &[],
            };
            let image = image.apply(to_tensor_rescale, device)?;
            images.push(image);
        }

        let (grouped_images, grouped_images_index) = self.group_images_by_shape(&images)?;

        let mut grouped_processed_images = HashMap::new();
        let mut grouped_aspect_ratios = HashMap::new();
        for (shape, stacked_images) in grouped_images {
            let image_size = (
                stacked_images.dim(D::Minus2)? as u32,
                stacked_images.dim(D::Minus1)? as u32,
            );
            let target_size = self.get_best_fit(
                image_size,
                possible_resolutions.clone(),
                resize_to_max_canvas,
            )?;
            // If target_size requires upscaling, we might want to limit the upscaling to max_upscaling_size
            let max_upscaling_size = if resize_to_max_canvas {
                None
            } else {
                Some(size["height"])
            };
            let target_size_without_distortion =
                if let Some(max_upscaling_size) = max_upscaling_size {
                    let nt_h = image_size.0.max(max_upscaling_size).min(target_size.0);
                    let nt_w = image_size.1.max(max_upscaling_size).min(target_size.1);
                    (nt_h, nt_w)
                } else {
                    candle_core::bail!("Currently `resize_to_max_canvas == false` is assumed!");
                };

            // Resize to target_size while preserving aspect ratio
            let new_size_without_distortion =
                self.get_max_res_without_distortion(image_size, target_size_without_distortion);
            let mut processed_images = stacked_images.interpolate2d(
                new_size_without_distortion.0.max(1) as usize,
                new_size_without_distortion.1.max(1) as usize,
            )?;

            // Pad to target_size to be able to split into tiles
            processed_images = {
                let (target_h, target_w) = target_size;
                let (h, w) = (
                    processed_images.dim(D::Minus2)?,
                    processed_images.dim(D::Minus1)?,
                );
                let paste_x_r = target_w as usize - w;
                let paste_y_r = target_h as usize - h;
                processed_images
                    .pad_with_zeros(D::Minus2, 0, paste_y_r)?
                    .pad_with_zeros(D::Minus1, 0, paste_x_r)?
            };

            let rescale_and_norm_transforms = TensorTransforms {
                inner_transforms: &[
                    &do_rescale.then_some(Rescale {
                        factor: config.rescale_factor,
                    }),
                    &do_normalize.then_some(Normalize {
                        mean: config.image_mean.unwrap_or(Self::DEFAULT_MEAN).to_vec(),
                        std: config.image_std.unwrap_or(Self::DEFAULT_STD).to_vec(),
                    }),
                ],
            };
            processed_images = <Tensor as ApplyTensorTransforms>::apply(
                &processed_images,
                rescale_and_norm_transforms,
                device,
            )?;

            let (ratio_h, ratio_w) = (
                target_size.0 / size["height"],
                target_size.1 / size["width"],
            );
            // Split into tiles
            processed_images =
                self.split_to_tiles(&processed_images, ratio_h as usize, ratio_w as usize)?;
            grouped_processed_images.insert(shape, processed_images.clone());
            grouped_aspect_ratios.insert(
                shape,
                Tensor::new(vec![vec![ratio_h, ratio_w]; stacked_images.dim(0)?], device)?,
            );

            // Add a global tile to the processed tiles if there is more than one tile
            if ratio_h * ratio_w > 1 {
                let mut global_tiles = stacked_images
                    .interpolate2d(size["height"] as usize, size["width"] as usize)?;
                global_tiles = <Tensor as ApplyTensorTransforms>::apply(
                    &global_tiles,
                    rescale_and_norm_transforms,
                    device,
                )?;
                grouped_processed_images.insert(
                    shape,
                    Tensor::cat(&[processed_images, global_tiles.unsqueeze(1)?], 1)?,
                );
            }
        }

        let processed_images =
            self.reorder_images(grouped_processed_images, grouped_images_index.clone())?;
        let aspect_ratios_list =
            self.reorder_images(grouped_aspect_ratios, grouped_images_index.clone())?;

        let processed_images = Tensor::cat(&processed_images, 0)?;
        let aspect_ratios = Tensor::stack(&aspect_ratios_list, 0)?;

        Ok(PreprocessedImages {
            pixel_values: processed_images,
            pixel_attention_mask: None,
            image_sizes: None,
            num_img_tokens: None,
            aspect_ratio_ids: Some(aspect_ratios),
            aspect_ratio_mask: None,
            num_tiles: None,
            image_grid_thw: None,
            video_grid_thw: None,
            rows: None,
            cols: None,
            pixel_values_list: None,
            tgt_sizes: None,
            image_sizes_all: None,
            num_crops: None,
        })
    }
}