1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use std::{
4 any::Any,
5 collections::HashMap,
6 num::NonZeroUsize,
7 sync::{Arc, RwLock},
8};
9
10use candle_core::{Context, DType, Device, Result, Tensor};
11use image::{imageops::FilterType, DynamicImage};
12use itertools::Itertools;
13use mistralrs_vision::{
14 ApplyTensorTransforms, ApplyTransforms, Normalize, Rescale, TensorTransforms, ToTensorNoNorm,
15 Transforms,
16};
17use tokenizers::Tokenizer;
18use tracing::warn;
19
20use crate::{
21 device_map::DeviceMapper,
22 pipeline::{
23 text_models_inputs_processor::{
24 self, get_completion_input, get_prompt_input, PagedAttentionMeta,
25 },
26 InputProcessorOutput, InputsProcessor, InputsProcessorType, MessagesAction, Processor,
27 },
28 sequence::Sequence,
29 vision_models::{
30 image_processor::{ImagePreProcessor, PreprocessedImages},
31 preprocessor_config::{PreProcessorConfig, ToFilter},
32 ModelInputs,
33 },
34};
35
36use super::MLlamaSpecificArgs;
37
/// Placeholder token marking where an image appears in the prompt text; each occurrence is
/// matched against one image attached to the sequence.
const IMAGE_TOKEN: &str = "<|image|>";
39
/// Inputs processor and image preprocessor for MLlama.
struct MLlamaImageProcessor {
    /// `max_image_tiles` taken from the preprocessor config. It is written by `preprocess` and
    /// read by `process_inputs` when building the cross-attention mask; it stays `None` until
    /// `preprocess` has run.
    max_image_tiles: RwLock<Option<usize>>,
}
/// Stateless processor for MLlama; it hands out a fresh inputs processor on request.
///
/// `Default` is derived so that `MLlamaProcessor::default()` and `MLlamaProcessor::new()` are
/// interchangeable (clippy `new_without_default`).
#[derive(Default)]
pub struct MLlamaProcessor;

impl MLlamaProcessor {
    /// Create a new MLlama processor.
    pub fn new() -> Self {
        Self
    }
}
53
54impl Processor for MLlamaProcessor {
55 fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
56 Arc::new(MLlamaImageProcessor {
57 max_image_tiles: RwLock::new(None),
58 })
59 }
60
61 fn get_special_tokens(&self) -> &[&'static str] {
62 &[IMAGE_TOKEN, "<|python_tag|>"]
63 }
64
65 fn template_action(&self) -> MessagesAction {
66 MessagesAction::FlattenOnlyText
67 }
68}
69
/// Compute the token span that each image in one sequence may cross-attend to.
///
/// Every occurrence of `image_token_id` in `input_ids` marks one image. The result holds one
/// `(start, end)` pair per image token, in order:
/// - no image tokens: an empty vector;
/// - exactly one image token at position `p`: `[(p, -1)]`, where `-1` means "until the end of
///   the sequence";
/// - otherwise each image spans from its own token up to the next image token, and the last
///   image spans up to `input_ids.len()`. An image token immediately followed by another image
///   token is extended to the end of the span that follows it, so consecutive images share the
///   same end.
fn get_cross_attention_token_mask(input_ids: Vec<u32>, image_token_id: u32) -> Vec<(i64, i64)> {
    let positions: Vec<i64> = input_ids
        .iter()
        .enumerate()
        .filter(|&(_, &token)| token == image_token_id)
        .map(|(idx, _)| idx as i64)
        .collect();

    let last_position = match positions.as_slice() {
        [] => return Vec::new(),
        [single] => return vec![(*single, -1)],
        [.., last] => *last,
    };

    let seq_len = input_ids.len() as i64;
    let mut spans: Vec<(i64, i64)> = positions
        .windows(2)
        .map(|pair| (pair[0], pair[1]))
        .chain(std::iter::once((last_position, seq_len)))
        .collect();

    // Walk backwards. A span covering only its own image token inherits the end of the span
    // after it.
    let mut next_end = seq_len;
    for span in spans.iter_mut().rev() {
        if span.0 + 1 == span.1 {
            span.1 = next_end;
        }
        next_end = span.1;
    }
    spans
}
111
112fn convert_sparse_cross_attention_mask_to_dense(
124 cross_attn_token_mask: Vec<Vec<(i64, i64)>>,
125 num_tiles: Vec<Vec<usize>>,
126 max_num_tiles: usize,
127 length: usize,
128 dev: &Device,
129) -> candle_core::Result<Tensor> {
130 let bs = cross_attn_token_mask.len();
131 let max_num_images = cross_attn_token_mask.iter().map(|x| x.len()).max().unwrap();
132
133 let mut cross_attention_mask = Tensor::zeros(
134 (bs, length, max_num_images, max_num_tiles),
135 DType::I64,
136 &Device::Cpu,
137 )?;
138
139 for (sample_idx, (sample_masks, sample_num_tiles)) in
140 cross_attn_token_mask.into_iter().zip(num_tiles).enumerate()
141 {
142 for (mask_idx, ((start, end), mask_num_tiles)) in
143 sample_masks.into_iter().zip(sample_num_tiles).enumerate()
144 {
145 let mut end = end.min(length as i64);
146 if end == -1 {
147 end = length as i64;
148 }
149 cross_attention_mask = cross_attention_mask.slice_assign(
150 &[
151 &sample_idx,
152 &(start as usize..end as usize),
153 &mask_idx,
154 &(..mask_num_tiles),
155 ],
156 &Tensor::ones(
157 (1, end as usize - start as usize, 1, mask_num_tiles),
158 DType::I64,
159 &Device::Cpu,
160 )?,
161 )?;
162 }
163 }
164
165 cross_attention_mask.to_device(dev)
166}
167
/// Builds MLlama model inputs from a batch of sequences.
///
/// When every sequence in the batch has images (`Sequence::has_images`), the images are
/// preprocessed into pixel values, aspect-ratio ids, and aspect-ratio masks. A dense
/// cross-attention mask is also built from the positions of `<|image|>` tokens. Otherwise all
/// vision inputs are `None`.
impl InputsProcessor for MLlamaImageProcessor {
    fn get_type(&self) -> InputsProcessorType {
        InputsProcessorType::Vision
    }
    /// Produce one `InputProcessorOutput` covering the whole batch.
    ///
    /// The returned iterator yields a single error instead when:
    /// - X-LoRA is requested;
    /// - the KV cache is disabled;
    /// - no tokenizer is given;
    /// - the `<|image|>` counts in the text differ from the attached image counts;
    /// - the cross-attention mask cannot be built.
    ///
    /// # Panics
    /// Panics if `other_config` is missing or is not a `PreProcessorConfig`. It also panics on
    /// several internal failures that are `unwrap`/`expect`ed: input building, detokenization,
    /// image preprocessing, tokenizing `<|image|>` to anything other than one id, and tensor
    /// ops. An unset `max_image_tiles` panics too.
    fn process_inputs(
        &self,
        tokenizer: Option<Arc<Tokenizer>>,
        input_seqs: &mut [&mut Sequence],
        is_prompt: bool,
        is_xlora: bool,
        device: &Device,
        no_kv_cache: bool,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        other_config: Option<Arc<dyn Any>>,
        mut paged_attn_metadata: Option<PagedAttentionMeta>,
        prompt_chunksize: Option<NonZeroUsize>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Box<dyn Iterator<Item = anyhow::Result<InputProcessorOutput>>> {
        if is_xlora {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Cannot make inputs for X-LoRA vision model.",
            ))));
        }
        if no_kv_cache {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Vision model must have kv cache.",
            ))));
        }
        // Prompt chunking is not supported here; the whole prompt is processed at once.
        if prompt_chunksize.is_some() {
            warn!("`prompt_chunksize` is set. MLlama does not support prompt batching.");
        }
        let Some(tokenizer) = tokenizer else {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "MLlamaInputProcessor requires a specified tokenizer.",
            ))));
        };

        // First pass: only the batched token tensor `input` is kept. It is used below to find
        // `<|image|>` positions for the cross-attention mask. The full input metadata is built
        // again after image preprocessing.
        // NOTE(review): both passes receive `paged_attn_metadata.as_mut()`; confirm that
        // calling the input builders twice has no unwanted side effects.
        let text_models_inputs_processor::InnerInputProcessorOutput {
            inputs:
                text_models_inputs_processor::InputMetadata {
                    input,
                    positions: _,
                    context_lens: _,
                    position_ids: _,
                    paged_attn_meta: _,
                    flash_meta: _,
                },
            seq_indices: _,
        } = if is_prompt {
            get_prompt_input(
                input_seqs
                    .iter()
                    .map(|seq| seq.get_toks())
                    .collect::<Vec<_>>(),
                input_seqs,
                device,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                None,
                mapper,
            )
            .nth(0)
            .unwrap()
            .unwrap()
        } else {
            get_completion_input(
                input_seqs
                    .iter()
                    .map(|seq| seq.get_toks())
                    .collect::<Vec<_>>(),
                input_seqs,
                device,
                no_kv_cache,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                None,
                mapper,
            )
            .nth(0)
            .unwrap()
            .unwrap()
        };
        let config = other_config.expect("Need a PreProcessorConfig config.");
        let config: &PreProcessorConfig = config.downcast_ref().expect("Downcast failed.");

        // Vision inputs are built only if *every* sequence has images. A batch mixing sequences
        // with and without images takes the text-only path below.
        let has_images = input_seqs.iter().all(|seq| seq.has_images());

        let (pixel_values, aspect_ratio_ids, aspect_ratio_mask, cross_attn_mask) = if has_images {
            let mut pixel_values_accum = Vec::new();
            let mut aspect_ratio_ids_accum = Vec::new();
            let mut aspect_ratio_mask_accum = Vec::new();
            let mut num_tiles_accum = Vec::new();

            let bs = input_seqs.len();
            // Count `<|image|>` occurrences in each sequence's text. Special tokens are kept by
            // passing `false` (do not skip special tokens) to `decode_batch`.
            let detokenized = tokenizer
                .decode_batch(
                    &input_seqs
                        .iter()
                        .map(|seq| seq.get_toks())
                        .collect::<Vec<_>>(),
                    false,
                )
                .expect("Detokenization failed!");
            let n_images_in_text = detokenized
                .iter()
                .map(|text| text.matches(IMAGE_TOKEN).count())
                .collect::<Vec<_>>();
            let n_images_in_images = input_seqs
                .iter()
                .map(|seq| seq.images().map(|imgs| imgs.len()).unwrap_or(0))
                .collect::<Vec<_>>();

            // Each sequence must have exactly one attached image per `<|image|>` token.
            if n_images_in_text != n_images_in_images {
                return Box::new(std::iter::once(Err(anyhow::Error::msg(format!(
                    "The number of images in each batch {n_images_in_text:?} should be the same as the number of images {n_images_in_images:?}. The model cannot support a different number of images per patch. Perhaps you forgot a `<|image|>` tag?"
                )))));
            }

            // Every sequence's images are packed to this many image slots.
            let max_num_images = *n_images_in_images
                .iter()
                .max()
                .expect("No max images per batch!");

            for seq in input_seqs.iter_mut() {
                // `take_images` presumably moves the images out of the sequence; confirm in
                // `Sequence`.
                let PreprocessedImages {
                    pixel_values,
                    pixel_attention_mask: _,
                    image_sizes: _,
                    num_img_tokens: _,
                    aspect_ratio_ids,
                    aspect_ratio_mask,
                    num_tiles,
                    image_grid_thw: _,
                    video_grid_thw: _,
                    rows: _,
                    cols: _,
                    pixel_values_list: _,
                    tgt_sizes: _,
                    image_sizes_all: _,
                    num_crops: _,
                } = self
                    .preprocess(
                        seq.take_images()
                            .expect("Need to have images by this point."),
                        vec![],
                        config,
                        device,
                        (bs, max_num_images),
                    )
                    .expect("Preprocessing failed");
                // Add a leading batch dimension so the per-sequence results can be concatenated.
                pixel_values_accum.push(pixel_values.unsqueeze(0).unwrap());
                aspect_ratio_ids_accum.push(aspect_ratio_ids.unwrap().unsqueeze(0).unwrap());
                aspect_ratio_mask_accum.push(aspect_ratio_mask.unwrap().unsqueeze(0).unwrap());
                num_tiles_accum.push(num_tiles.unwrap());

                seq.multimodal.has_changed_prompt = true;
            }

            // The image placeholder must tokenize to exactly one id.
            let image_token_id = tokenizer
                .encode_fast(IMAGE_TOKEN, false)
                .unwrap()
                .get_ids()
                .to_vec();
            let image_token_id = if image_token_id.len() == 1 {
                image_token_id[0]
            } else {
                panic!("{IMAGE_TOKEN} encoding should be one token, got {image_token_id:?}");
            };
            // Split the batched `input` along dim 0 into one row per sequence.
            let chunks = input.chunk(input.dim(0).unwrap(), 0).unwrap();
            let cross_attention_token_mask = chunks
                .iter()
                .map(|token_ids| {
                    get_cross_attention_token_mask(
                        token_ids.squeeze(0).unwrap().to_vec1::<u32>().unwrap(),
                        image_token_id,
                    )
                })
                .collect::<Vec<_>>();

            // The mask length is the largest last-dimension size among the rows.
            let cross_attn_mask = convert_sparse_cross_attention_mask_to_dense(
                cross_attention_token_mask,
                num_tiles_accum,
                self.max_image_tiles
                    .read()
                    .unwrap()
                    .expect("`max_image_tiles` must be set!"),
                chunks
                    .iter()
                    .map(|input_ids| *input_ids.dims().last().unwrap())
                    .max()
                    .unwrap(),
                chunks[0].device(),
            );

            let cross_attn_mask = match cross_attn_mask {
                Ok(v) => v,
                Err(e) => return Box::new(std::iter::once(Err(anyhow::Error::msg(e.to_string())))),
            };

            (
                Some(Tensor::cat(&pixel_values_accum, 0).unwrap()),
                Some(Tensor::cat(&aspect_ratio_ids_accum, 0).unwrap()),
                Some(Tensor::cat(&aspect_ratio_mask_accum, 0).unwrap()),
                Some(cross_attn_mask),
            )
        } else {
            (None, None, None, None)
        };

        // Second pass: build the full input metadata passed to the model.
        let text_models_inputs_processor::InnerInputProcessorOutput {
            inputs:
                text_models_inputs_processor::InputMetadata {
                    input,
                    positions,
                    context_lens,
                    position_ids,
                    paged_attn_meta,
                    flash_meta,
                },
            seq_indices,
        } = if is_prompt {
            get_prompt_input(
                input_seqs
                    .iter()
                    .map(|seq| seq.get_toks())
                    .collect::<Vec<_>>(),
                input_seqs,
                device,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                None,
                mapper,
            )
            .nth(0)
            .unwrap()
            .unwrap()
        } else {
            get_completion_input(
                input_seqs
                    .iter()
                    .map(|seq| seq.get_toks())
                    .collect::<Vec<_>>(),
                input_seqs,
                device,
                no_kv_cache,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                None,
                mapper,
            )
            .nth(0)
            .unwrap()
            .unwrap()
        };

        let inputs: Box<dyn Any> = Box::new(ModelInputs {
            input_ids: input,
            seqlen_offsets: positions,
            context_lens,
            position_ids,
            pixel_values,
            model_specific_args: Box::new(MLlamaSpecificArgs {
                aspect_ratio_ids,
                aspect_ratio_mask,
                cross_attn_mask,
            }),
            paged_attn_meta,
            flash_meta,
        });
        Box::new(std::iter::once(Ok(InputProcessorOutput {
            inputs,
            seq_indices,
        })))
    }
}
450
/// Index of the smallest item yielded by `iter`, or `None` if it yields nothing.
///
/// Ties keep the earliest index. An item replaces the current best only when it compares
/// strictly less, so incomparable values such as NaN never displace an earlier best.
fn argmin<T, I>(iter: I) -> Option<usize>
where
    T: PartialOrd,
    I: Iterator<Item = T>,
{
    let mut best: Option<(usize, T)> = None;
    for (idx, item) in iter.enumerate() {
        let is_new_best = match &best {
            None => true,
            Some((_, current)) => item < *current,
        };
        if is_new_best {
            best = Some((idx, item));
        }
    }
    best.map(|(idx, _)| idx)
}
469
470impl MLlamaImageProcessor {
471 fn get_all_supported_aspect_ratios(max_image_tiles: usize) -> Vec<(usize, usize)> {
473 (1..max_image_tiles + 1)
474 .flat_map(|width| {
475 (1..max_image_tiles + 1).filter_map(move |height| {
476 if width * height <= max_image_tiles {
477 Some((width, height))
478 } else {
479 None
480 }
481 })
482 })
483 .collect::<Vec<_>>()
484 }
485
486 fn get_optimal_tiled_canvas(
488 image_height: u32,
489 image_width: u32,
490 max_image_tiles: usize,
491 tile_size: usize,
492 ) -> Result<(usize, usize)> {
493 let possible_tile_arrangements = Self::get_all_supported_aspect_ratios(max_image_tiles);
494 let possible_canvas_sizes: (Vec<_>, Vec<_>) = possible_tile_arrangements
495 .into_iter()
496 .map(|(h, w)| (h * tile_size, w * tile_size))
497 .unzip();
498 let (target_heights, target_widths) = possible_canvas_sizes;
500
501 let scale_h = target_heights
503 .iter()
504 .map(|h| *h as f32 / image_height as f32)
505 .collect::<Vec<_>>();
506 let scale_w = target_widths
507 .iter()
508 .map(|w| *w as f32 / image_width as f32)
509 .collect::<Vec<_>>();
510
511 let scales = scale_h
513 .into_iter()
514 .zip(scale_w)
515 .map(|(scale_h, scale_w)| if scale_w > scale_h { scale_h } else { scale_w })
516 .collect::<Vec<_>>();
517
518 let upscaling_options = scales
520 .iter()
521 .copied()
522 .filter(|scale| *scale >= 1.)
523 .collect::<Vec<_>>();
524 let selected_scale = if !upscaling_options.is_empty() {
525 upscaling_options
526 .into_iter()
527 .min_by(|x, y| x.partial_cmp(y).expect("No ordering!"))
528 .context("No min, upscale")?
529 } else {
530 let downscaling_options = scales
532 .iter()
533 .copied()
534 .filter(|scale| *scale < 1.)
535 .collect::<Vec<_>>();
536 downscaling_options
537 .into_iter()
538 .max_by(|x, y| x.partial_cmp(y).expect("No ordering!"))
539 .context("No max, downscale")?
540 };
541
542 let chosen_canvas_h = target_heights
544 .iter()
545 .copied()
546 .enumerate()
547 .filter_map(|(i, h)| {
548 if scales[i] == selected_scale {
549 Some(h)
550 } else {
551 None
552 }
553 })
554 .collect::<Vec<_>>();
555 let chosen_canvas_w = target_widths
556 .iter()
557 .copied()
558 .enumerate()
559 .filter_map(|(i, w)| {
560 if scales[i] == selected_scale {
561 Some(w)
562 } else {
563 None
564 }
565 })
566 .collect::<Vec<_>>();
567
568 assert_eq!(chosen_canvas_h.len(), chosen_canvas_w.len());
569 if chosen_canvas_h.len() > 1 {
570 let optimal_idx = argmin(
571 chosen_canvas_h
572 .iter()
573 .zip(&chosen_canvas_w)
574 .map(|(h, w)| *h * *w),
575 )
576 .context("No argmin")?;
577 Ok((chosen_canvas_h[optimal_idx], chosen_canvas_w[optimal_idx]))
578 } else {
579 Ok((chosen_canvas_h[0], chosen_canvas_w[0]))
580 }
581 }
582
583 fn get_image_size_fit_to_canvas(
585 image_height: u32,
586 image_width: u32,
587 canvas_height: usize,
588 canvas_width: usize,
589 tile_size: usize,
590 ) -> (usize, usize) {
591 let target_width = (image_width as usize).clamp(tile_size, canvas_width);
592 let target_height = (image_height as usize).clamp(tile_size, canvas_height);
593
594 let scale_h = (target_height as f32) / (image_height as f32);
595 let scale_w = (target_width as f32) / (image_width as f32);
596
597 if scale_w < scale_h {
598 (
599 target_height.min((image_height as f32 * scale_w).floor() as usize),
600 target_width,
601 )
602 } else {
603 (
604 target_height,
605 target_width.min((image_width as f32 * scale_h).floor() as usize),
606 )
607 }
608 }
609
610 fn resize(
614 &self,
615 image: DynamicImage,
616 size: &HashMap<String, u32>,
617 max_image_tiles: usize,
618 filter: FilterType,
619 ) -> Result<(DynamicImage, (usize, usize))> {
620 let image_height = image.height();
621 let image_width = image.width();
622 let tile_size = size["height"] as usize;
623
624 let (canvas_height, canvas_width) =
625 Self::get_optimal_tiled_canvas(image_height, image_width, max_image_tiles, tile_size)?;
626 let num_tiles_height = canvas_height / tile_size;
627 let num_tiles_width = canvas_width / tile_size;
628
629 let (new_height, new_width) = Self::get_image_size_fit_to_canvas(
630 image_height,
631 image_width,
632 canvas_height,
633 canvas_width,
634 tile_size,
635 );
636
637 Ok((
638 image.resize_exact(new_width as u32, new_height as u32, filter),
639 (num_tiles_height, num_tiles_width),
640 ))
641 }
642
643 fn pad(
647 &self,
648 image: &Tensor,
649 size: &HashMap<String, u32>,
650 aspect_ratio: (usize, usize),
651 ) -> Result<Tensor> {
652 let (num_tiles_h, num_tiles_w) = aspect_ratio;
653 let padded_height = num_tiles_h * size["height"] as usize;
654 let padded_width = num_tiles_w * size["width"] as usize;
655
656 mistralrs_vision::pad(image, padded_height, padded_width)
658 }
659
660 fn split_to_tiles(
663 &self,
664 image: &Tensor,
665 num_tiles_height: usize,
666 num_tiles_width: usize,
667 ) -> Result<Tensor> {
668 let (ch, h, w) = image.dims3()?;
669 let tile_height = h / num_tiles_height;
670 let tile_width = w / num_tiles_width;
671
672 let mut image = image.reshape((
673 ch,
674 num_tiles_height,
675 tile_height,
676 num_tiles_width,
677 tile_width,
678 ))?;
679
680 image = image.permute((1, 3, 0, 2, 4))?;
682
683 image
685 .reshape((
686 num_tiles_width * num_tiles_height,
687 ch,
688 tile_height,
689 tile_width,
690 ))?
691 .contiguous()
692 }
693
694 fn pack_images(
700 &self,
701 images: Vec<Tensor>,
702 max_image_tiles: usize,
703 (_bs, max_num_images): (usize, usize),
704 ) -> Result<(Tensor, Vec<usize>)> {
705 let (_, ch, tile_h, tile_w) = images[0].dims4()?;
706
707 let mut stacked_images = Tensor::zeros(
708 (max_num_images, max_image_tiles, ch, tile_h, tile_w),
709 images[0].dtype(),
710 images[0].device(),
711 )?;
712 let mut num_sample_tiles = Vec::new();
713 for (i, image) in images.into_iter().enumerate() {
714 let num_tiles = image.dim(0)?;
715 stacked_images = stacked_images
716 .slice_assign(&[&i, &(..num_tiles), &.., &.., &..], &image.unsqueeze(0)?)?;
717 num_sample_tiles.push(num_tiles)
718 }
719 Ok((stacked_images, num_sample_tiles))
720 }
721
722 fn convert_aspect_ratios_to_ids(
726 &self,
727 aspect_ratios: Vec<(usize, usize)>,
728 max_image_tiles: usize,
729 (_bs, max_num_images): (usize, usize),
730 device: &Device,
731 ) -> Result<Tensor> {
732 let supported_aspect_ratios = Self::get_all_supported_aspect_ratios(max_image_tiles);
733
734 let mut aspect_ratios_ids = vec![0i64; max_num_images];
735 for (i, (num_tiles_h, num_tiles_w)) in aspect_ratios.iter().enumerate() {
736 aspect_ratios_ids[i] = (supported_aspect_ratios
737 .iter()
738 .position(|(h, w)| *h == *num_tiles_h && *w == *num_tiles_w)
739 .context("Could not find aspect ratio")?
740 + 1) as i64;
741 }
742
743 Tensor::new(aspect_ratios_ids, device)
744 }
745
746 fn build_aspect_ratio_mask(
747 &self,
748 aspect_ratios: Vec<(usize, usize)>,
749 max_image_tiles: usize,
750 (_bs, max_num_images): (usize, usize),
751 device: &Device,
752 ) -> Result<Tensor> {
753 let mut aspect_ratio_mask =
754 Tensor::zeros((max_num_images, max_image_tiles), DType::I64, device)?;
755
756 aspect_ratio_mask = aspect_ratio_mask.slice_assign(
760 &[&.., &0],
761 &Tensor::ones((max_num_images, 1), DType::I64, device)?,
762 )?;
763
764 for (i, (num_tiles_h, num_tiles_w)) in aspect_ratios.iter().enumerate() {
765 aspect_ratio_mask = aspect_ratio_mask.slice_assign(
766 &[&i, &(..*num_tiles_h * *num_tiles_w)],
767 &Tensor::ones((1, *num_tiles_h * *num_tiles_w), DType::I64, device)?,
768 )?;
769 }
770
771 Ok(aspect_ratio_mask)
772 }
773}
774
775impl ImagePreProcessor for MLlamaImageProcessor {
776 const DEFAULT_MEAN: [f64; 3] = [0.5, 0.5, 0.5];
777 const DEFAULT_STD: [f64; 3] = [0.5, 0.5, 0.5];
778
779 fn preprocess(
780 &self,
781 images: Vec<DynamicImage>,
782 videos: Vec<Vec<DynamicImage>>,
783 config: &PreProcessorConfig,
784 device: &Device,
785 (bs, max_num_images): (usize, usize),
786 ) -> Result<PreprocessedImages> {
787 assert!(videos.is_empty());
788
789 let mut sample_images = Vec::new();
790 let mut sample_aspect_ratios = Vec::new();
791 let max_image_tiles = config
792 .max_image_tiles
793 .context("`do_resize=false` is not supported, need `max_image_tiles`!")?;
794 *self.max_image_tiles.write().unwrap() = Some(max_image_tiles);
795
796 for mut image in images {
797 if config.do_convert_rgb.unwrap_or(true) {
799 image = DynamicImage::ImageRgb8(image.to_rgb8());
800 }
801
802 let size = config
803 .size
804 .as_ref()
805 .context("`do_resize=false` is not supported, need `size`!")?;
806
807 let (image, aspect_ratio) =
808 self.resize(image, size, max_image_tiles, config.resampling.to_filter()?)?;
809
810 let to_tensor_rescale = Transforms {
814 input: &ToTensorNoNorm,
815 inner_transforms: &[],
816 };
817 let mut image = image.apply(to_tensor_rescale, device)?;
818
819 image = self.pad(&image, size, aspect_ratio)?;
820
821 let transforms = TensorTransforms {
822 inner_transforms: &[
823 &config
824 .do_rescale
825 .is_some_and(|x| x)
826 .then_some(())
827 .map(|_| Rescale {
828 factor: config.rescale_factor,
829 }),
830 &config
831 .do_normalize
832 .is_some_and(|x| x)
833 .then_some(())
834 .map(|_| Normalize {
835 mean: config.image_mean.unwrap_or(Self::DEFAULT_MEAN).to_vec(),
836 std: config.image_std.unwrap_or(Self::DEFAULT_STD).to_vec(),
837 }),
838 ],
839 };
840 image = <Tensor as ApplyTensorTransforms>::apply(&image, transforms, device)?;
841
842 let (num_tiles_height, num_tiles_width) = aspect_ratio;
843 image = self.split_to_tiles(&image, num_tiles_height, num_tiles_width)?;
844
845 sample_images.push(image);
846 sample_aspect_ratios.push((num_tiles_height, num_tiles_width));
847 }
848
849 let (images, num_tiles) =
850 self.pack_images(sample_images, max_image_tiles, (bs, max_num_images))?;
851
852 let aspect_ratio_ids = self.convert_aspect_ratios_to_ids(
853 sample_aspect_ratios.clone(),
854 max_image_tiles,
855 (bs, max_num_images),
856 device,
857 )?;
858 let aspect_ratio_mask = self.build_aspect_ratio_mask(
859 sample_aspect_ratios,
860 max_image_tiles,
861 (bs, max_num_images),
862 device,
863 )?;
864
865 Ok(PreprocessedImages {
866 pixel_values: images,
867 pixel_attention_mask: None,
868 image_sizes: None,
869 num_img_tokens: None,
870 aspect_ratio_ids: Some(aspect_ratio_ids),
871 aspect_ratio_mask: Some(aspect_ratio_mask),
872 num_tiles: Some(num_tiles),
873 image_grid_thw: None,
874 video_grid_thw: None,
875 rows: None,
876 cols: None,
877 pixel_values_list: None,
878 tgt_sizes: None,
879 image_sizes_all: None,
880 num_crops: None,
881 })
882 }
883}