#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::{
    any::Any,
    collections::HashMap,
    num::NonZeroUsize,
    sync::{Arc, RwLock},
};

use candle_core::{Context, DType, Device, Result, Tensor};
use image::{imageops::FilterType, DynamicImage};
use itertools::Itertools;
use mistralrs_vision::{
    ApplyTensorTransforms, ApplyTransforms, Normalize, Rescale, TensorTransforms, ToTensorNoNorm,
    Transforms,
};
use tokenizers::Tokenizer;
use tracing::warn;

use crate::{
    device_map::DeviceMapper,
    pipeline::{
        text_models_inputs_processor::{
            self, get_completion_input, get_prompt_input, PagedAttentionMeta,
        },
        InputProcessorOutput, InputsProcessor, InputsProcessorType, MessagesAction, Processor,
    },
    sequence::Sequence,
    vision_models::{
        image_processor::{ImagePreProcessor, PreprocessedImages},
        preprocessor_config::{PreProcessorConfig, ToFilter},
        ModelInputs,
    },
};

use super::MLlamaSpecificArgs;

const IMAGE_TOKEN: &str = "<|image|>";

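/// Image processor for MLlama. `max_image_tiles` is cached from the preprocessor
/// config during `preprocess` so that `process_inputs` can later build the dense
/// cross-attention mask with the same tile budget.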
struct MLlamaImageProcessor {
    max_image_tiles: RwLock<Option<usize>>,
}
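
/// Chat-template-level processor for MLlama. It hands out an
/// [`MLlamaImageProcessor`] for input processing, registers `<|image|>` and
/// `<|python_tag|>` as special tokens, and flattens only text messages when
/// applying the chat template.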
pub struct MLlamaProcessor;

impl MLlamaProcessor {
    pub fn new() -> Self {
        Self
    }
}

impl Processor for MLlamaProcessor {
    fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
        Arc::new(MLlamaImageProcessor {
            max_image_tiles: RwLock::new(None),
        })
    }

    fn get_special_tokens(&self) -> &[&'static str] {
        &[IMAGE_TOKEN, "<|python_tag|>"]
    }

    fn template_action(&self) -> MessagesAction {
        MessagesAction::FlattenOnlyText
    }
}

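/// For each `<|image|>` token in `input_ids`, compute the `(start, end)` span of
/// tokens that should cross-attend to that image. A single image token yields
/// `(position, -1)`, meaning "attend until the end of the sequence". With several
/// images, each span runs from its image token to the next one (the last runs to
/// the end of the sequence), and runs of adjacent image tokens are widened to
/// share the span of the text that follows them.
///
/// For example, image tokens at positions 2 and 5 of a 10-token prompt give
/// `[(2, 5), (5, 10)]`; adjacent image tokens at positions 0 and 1 of a 6-token
/// prompt give `[(0, 6), (1, 6)]`.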
fn get_cross_attention_token_mask(input_ids: Vec<u32>, image_token_id: u32) -> Vec<(i64, i64)> {
    let image_token_locations = input_ids
        .iter()
        .positions(|token| *token == image_token_id)
        .collect::<Vec<_>>();

    if image_token_locations.is_empty() {
        return vec![];
    }

    if image_token_locations.len() == 1 {
        return vec![(image_token_locations[0] as i64, -1)];
    }

    let mut vision_masks = image_token_locations[..image_token_locations.len() - 1]
        .iter()
        .zip(&image_token_locations[1..])
        .map(|(a, b)| (*a as i64, *b as i64))
        .collect::<Vec<_>>();

    vision_masks.push((
        *image_token_locations.last().unwrap() as i64,
        input_ids.len() as i64,
    ));

    let mut last_mask_end = vision_masks.last().unwrap().1;
    for vision_mask in vision_masks.iter_mut().rev() {
        if vision_mask.0 == vision_mask.1 - 1 {
            vision_mask.1 = last_mask_end;
        }
        last_mask_end = vision_mask.1;
    }

    vision_masks
}

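/// Expand the sparse per-image `(start, end)` spans into a dense 0/1 mask of
/// shape `(batch, length, max_num_images, max_num_tiles)`. For each span, rows
/// `start..end` are set to 1 over that image's first `num_tiles` tiles; an `end`
/// of `-1` is treated as `length`. The mask is built on the CPU and moved to
/// `dev` at the end.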
fn convert_sparse_cross_attention_mask_to_dense(
    cross_attn_token_mask: Vec<Vec<(i64, i64)>>,
    num_tiles: Vec<Vec<usize>>,
    max_num_tiles: usize,
    length: usize,
    dev: &Device,
) -> candle_core::Result<Tensor> {
    let bs = cross_attn_token_mask.len();
    let max_num_images = cross_attn_token_mask.iter().map(|x| x.len()).max().unwrap();

    let mut cross_attention_mask = Tensor::zeros(
        (bs, length, max_num_images, max_num_tiles),
        DType::I64,
        &Device::Cpu,
    )?;

    for (sample_idx, (sample_masks, sample_num_tiles)) in
        cross_attn_token_mask.into_iter().zip(num_tiles).enumerate()
    {
        for (mask_idx, ((start, end), mask_num_tiles)) in
            sample_masks.into_iter().zip(sample_num_tiles).enumerate()
        {
            let mut end = end.min(length as i64);
            if end == -1 {
                end = length as i64;
            }
            cross_attention_mask = cross_attention_mask.slice_assign(
                &[
                    &sample_idx,
                    &(start as usize..end as usize),
                    &mask_idx,
                    &(..mask_num_tiles),
                ],
                &Tensor::ones(
                    (1, end as usize - start as usize, 1, mask_num_tiles),
                    DType::I64,
                    &Device::Cpu,
                )?,
            )?;
        }
    }

    cross_attention_mask.to_device(dev)
}

impl InputsProcessor for MLlamaImageProcessor {
    fn get_type(&self) -> InputsProcessorType {
        InputsProcessorType::Vision
    }
    fn process_inputs(
        &self,
        tokenizer: Option<Arc<Tokenizer>>,
        input_seqs: &mut [&mut Sequence],
        is_prompt: bool,
        is_xlora: bool,
        device: &Device,
        no_kv_cache: bool,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        other_config: Option<Arc<dyn Any>>,
        mut paged_attn_metadata: Option<PagedAttentionMeta<'_>>,
        prompt_chunksize: Option<NonZeroUsize>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Box<dyn Iterator<Item = anyhow::Result<InputProcessorOutput>>> {
        if is_xlora {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Cannot make inputs for X-LoRA vision model.",
            ))));
        }
        if no_kv_cache {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Vision model must have kv cache.",
            ))));
        }
        if prompt_chunksize.is_some() {
            warn!("`prompt_chunksize` is set. MLlama does not support prompt chunking.");
        }
        let Some(tokenizer) = tokenizer else {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "MLlamaInputProcessor requires a specified tokenizer.",
            ))));
        };

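        // First tokenization pass: only the padded token matrix `input` is kept.
        // It is used below to locate `<|image|>` positions when building the
        // cross-attention mask; the final metadata is recomputed in a second pass.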
        let text_models_inputs_processor::InnerInputProcessorOutput {
            inputs:
                text_models_inputs_processor::InputMetadata {
                    input,
                    positions: _,
                    context_lens: _,
                    position_ids: _,
                    paged_attn_meta: _,
                    flash_meta: _,
                },
            seq_indices: _,
        } = if is_prompt {
            get_prompt_input(
                input_seqs
                    .iter()
                    .map(|seq| seq.get_toks().to_vec())
                    .collect::<Vec<_>>(),
                input_seqs,
                device,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                None,
                mapper,
            )
            .next()
            .unwrap()
            .unwrap()
        } else {
            get_completion_input(
                input_seqs
                    .iter()
                    .map(|seq| seq.get_toks().to_vec())
                    .collect::<Vec<_>>(),
                input_seqs,
                device,
                no_kv_cache,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                None,
                mapper,
            )
            .next()
            .unwrap()
            .unwrap()
        };
        let config = other_config.expect("Need a PreProcessorConfig config.");
        let config: &PreProcessorConfig = config.downcast_ref().expect("Downcast failed.");

        let has_images = input_seqs.iter().all(|seq| seq.has_images());

        let (pixel_values, aspect_ratio_ids, aspect_ratio_mask, cross_attn_mask) = if has_images {
            let mut pixel_values_accum = Vec::new();
            let mut aspect_ratio_ids_accum = Vec::new();
            let mut aspect_ratio_mask_accum = Vec::new();
            let mut num_tiles_accum = Vec::new();

            let bs = input_seqs.len();
            let detokenized = tokenizer
                .decode_batch(
                    &input_seqs
                        .iter()
                        .map(|seq| seq.get_toks())
                        .collect::<Vec<_>>(),
                    false,
                )
                .expect("Detokenization failed!");
            let n_images_in_text = detokenized
                .iter()
                .map(|text| text.matches(IMAGE_TOKEN).count())
                .collect::<Vec<_>>();
            let n_images_in_images = input_seqs
                .iter()
                .map(|seq| seq.images().map(|imgs| imgs.len()).unwrap_or(0))
                .collect::<Vec<_>>();
            if n_images_in_text != n_images_in_images {
                return Box::new(std::iter::once(Err(anyhow::Error::msg(format!(
                    "The number of images in each batch {n_images_in_text:?} should be the same as the number of images {n_images_in_images:?}. The model cannot support a different number of images per prompt. Perhaps you forgot a `<|image|>` tag?"
                )))));
            }

            let max_num_images = *n_images_in_images
                .iter()
                .max()
                .expect("No max images per batch!");

            for seq in input_seqs.iter_mut() {
                let PreprocessedImages {
                    pixel_values,
                    pixel_attention_mask: _,
                    image_sizes: _,
                    num_img_tokens: _,
                    aspect_ratio_ids,
                    aspect_ratio_mask,
                    num_tiles,
                    image_grid_thw: _,
                    video_grid_thw: _,
                    rows: _,
                    cols: _,
                    pixel_values_list: _,
                    tgt_sizes: _,
                    image_sizes_all: _,
                    num_crops: _,
                } = self
                    .preprocess(
                        seq.take_images()
                            .expect("Need to have images by this point."),
                        vec![],
                        config,
                        device,
                        (bs, max_num_images),
                    )
                    .expect("Preprocessing failed");
                pixel_values_accum.push(pixel_values.unsqueeze(0).unwrap());
                aspect_ratio_ids_accum.push(aspect_ratio_ids.unwrap().unsqueeze(0).unwrap());
                aspect_ratio_mask_accum.push(aspect_ratio_mask.unwrap().unsqueeze(0).unwrap());
                num_tiles_accum.push(num_tiles.unwrap());
            }

            let image_token_id = tokenizer
                .encode_fast(IMAGE_TOKEN, false)
                .unwrap()
                .get_ids()
                .to_vec();
            let image_token_id = if image_token_id.len() == 1 {
                image_token_id[0]
            } else {
                panic!("{IMAGE_TOKEN} encoding should be one token, got {image_token_id:?}");
            };
            let chunks = input.chunk(input.dim(0).unwrap(), 0).unwrap();
            let cross_attention_token_mask = chunks
                .iter()
                .map(|token_ids| {
                    get_cross_attention_token_mask(
                        token_ids.squeeze(0).unwrap().to_vec1::<u32>().unwrap(),
                        image_token_id,
                    )
                })
                .collect::<Vec<_>>();

            let cross_attn_mask = convert_sparse_cross_attention_mask_to_dense(
                cross_attention_token_mask,
                num_tiles_accum,
                self.max_image_tiles
                    .read()
                    .unwrap()
                    .expect("`max_image_tiles` must be set!"),
                chunks
                    .iter()
                    .map(|input_ids| *input_ids.dims().last().unwrap())
                    .max()
                    .unwrap(),
                chunks[0].device(),
            );

            let cross_attn_mask = match cross_attn_mask {
                Ok(v) => v,
                Err(e) => {
                    return Box::new(std::iter::once(Err(anyhow::Error::msg(e.to_string()))))
                }
            };

            (
                Some(Tensor::cat(&pixel_values_accum, 0).unwrap()),
                Some(Tensor::cat(&aspect_ratio_ids_accum, 0).unwrap()),
                Some(Tensor::cat(&aspect_ratio_mask_accum, 0).unwrap()),
                Some(cross_attn_mask),
            )
        } else {
            (None, None, None, None)
        };

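        // Second pass: rebuild the full input metadata (positions, context lengths,
        // paged-attention/flash metadata) for the model now that images have been
        // taken from the sequences.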
        let text_models_inputs_processor::InnerInputProcessorOutput {
            inputs:
                text_models_inputs_processor::InputMetadata {
                    input,
                    positions,
                    context_lens,
                    position_ids,
                    paged_attn_meta,
                    flash_meta,
                },
            seq_indices,
        } = if is_prompt {
            get_prompt_input(
                input_seqs
                    .iter()
                    .map(|seq| seq.get_toks().to_vec())
                    .collect::<Vec<_>>(),
                input_seqs,
                device,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                None,
                mapper,
            )
            .next()
            .unwrap()
            .unwrap()
        } else {
            get_completion_input(
                input_seqs
                    .iter()
                    .map(|seq| seq.get_toks().to_vec())
                    .collect::<Vec<_>>(),
                input_seqs,
                device,
                no_kv_cache,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                None,
                mapper,
            )
            .next()
            .unwrap()
            .unwrap()
        };

        let inputs: Box<dyn Any> = Box::new(ModelInputs {
            input_ids: input,
            seqlen_offsets: positions,
            context_lens,
            position_ids,
            pixel_values,
            model_specific_args: Box::new(MLlamaSpecificArgs {
                aspect_ratio_ids,
                aspect_ratio_mask,
                cross_attn_mask,
            }),
            paged_attn_meta,
            flash_meta,
        });
        Box::new(std::iter::once(Ok(InputProcessorOutput {
            inputs,
            seq_indices,
        })))
    }
}

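/// Index of the smallest element yielded by `iter`, or `None` if the iterator is
/// empty. Ties keep the earliest index.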
fn argmin<T, I>(iter: I) -> Option<usize>
where
    T: PartialOrd,
    I: Iterator<Item = T>,
{
    iter.enumerate()
        .fold(None, |min, (idx, item)| match min {
            None => Some((idx, item)),
            Some((min_idx, min_item)) => {
                if item < min_item {
                    Some((idx, item))
                } else {
                    Some((min_idx, min_item))
                }
            }
        })
        .map(|(min_idx, _)| min_idx)
}

impl MLlamaImageProcessor {
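    /// Enumerate every tile arrangement `(w, h)` with `w * h <= max_image_tiles`.
    /// For example, `max_image_tiles = 4` yields `(1, 1)`, `(1, 2)`, `(1, 3)`,
    /// `(1, 4)`, `(2, 1)`, `(2, 2)`, `(3, 1)`, `(4, 1)`.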
    fn get_all_supported_aspect_ratios(max_image_tiles: usize) -> Vec<(usize, usize)> {
        (1..=max_image_tiles)
            .flat_map(|width| {
                (1..=max_image_tiles).filter_map(move |height| {
                    if width * height <= max_image_tiles {
                        Some((width, height))
                    } else {
                        None
                    }
                })
            })
            .collect::<Vec<_>>()
    }

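    /// Choose the canvas size (in pixels) that the image will be tiled onto.
    /// Each supported arrangement is scored by the scale factor needed to fit the
    /// image inside it: if any canvas requires upscaling (scale >= 1), the
    /// smallest such scale wins; otherwise the largest downscale wins. Among
    /// canvases with the winning scale, the one with the smallest area is chosen.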
    fn get_optimal_tiled_canvas(
        image_height: u32,
        image_width: u32,
        max_image_tiles: usize,
        tile_size: usize,
    ) -> Result<(usize, usize)> {
        let possible_tile_arrangements = Self::get_all_supported_aspect_ratios(max_image_tiles);
        let (target_heights, target_widths): (Vec<_>, Vec<_>) = possible_tile_arrangements
            .into_iter()
            .map(|(h, w)| (h * tile_size, w * tile_size))
            .unzip();

        let scale_h = target_heights
            .iter()
            .map(|h| *h as f32 / image_height as f32)
            .collect::<Vec<_>>();
        let scale_w = target_widths
            .iter()
            .map(|w| *w as f32 / image_width as f32)
            .collect::<Vec<_>>();

        let scales = scale_h
            .into_iter()
            .zip(scale_w)
            .map(|(scale_h, scale_w)| scale_h.min(scale_w))
            .collect::<Vec<_>>();

        let upscaling_options = scales
            .iter()
            .copied()
            .filter(|scale| *scale >= 1.)
            .collect::<Vec<_>>();
        let selected_scale = if !upscaling_options.is_empty() {
            upscaling_options
                .into_iter()
                .min_by(|x, y| x.partial_cmp(y).expect("No ordering!"))
                .context("No min, upscale")?
        } else {
            let downscaling_options = scales
                .iter()
                .copied()
                .filter(|scale| *scale < 1.)
                .collect::<Vec<_>>();
            downscaling_options
                .into_iter()
                .max_by(|x, y| x.partial_cmp(y).expect("No ordering!"))
                .context("No max, downscale")?
        };

        let chosen_canvas_h = target_heights
            .iter()
            .copied()
            .enumerate()
            .filter_map(|(i, h)| {
                if scales[i] == selected_scale {
                    Some(h)
                } else {
                    None
                }
            })
            .collect::<Vec<_>>();
        let chosen_canvas_w = target_widths
            .iter()
            .copied()
            .enumerate()
            .filter_map(|(i, w)| {
                if scales[i] == selected_scale {
                    Some(w)
                } else {
                    None
                }
            })
            .collect::<Vec<_>>();

        assert_eq!(chosen_canvas_h.len(), chosen_canvas_w.len());
        if chosen_canvas_h.len() > 1 {
            let optimal_idx = argmin(
                chosen_canvas_h
                    .iter()
                    .zip(&chosen_canvas_w)
                    .map(|(h, w)| *h * *w),
            )
            .context("No argmin")?;
            Ok((chosen_canvas_h[optimal_idx], chosen_canvas_w[optimal_idx]))
        } else {
            Ok((chosen_canvas_h[0], chosen_canvas_w[0]))
        }
    }

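    /// Compute the `(height, width)` to resize the image to within the canvas:
    /// each target side is first clamped to `[tile_size, canvas_side]`, then the
    /// smaller of the two resulting scale factors is applied to the other side so
    /// the aspect ratio is preserved.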
    fn get_image_size_fit_to_canvas(
        image_height: u32,
        image_width: u32,
        canvas_height: usize,
        canvas_width: usize,
        tile_size: usize,
    ) -> (usize, usize) {
        let target_width = (image_width as usize).clamp(tile_size, canvas_width);
        let target_height = (image_height as usize).clamp(tile_size, canvas_height);

        let scale_h = (target_height as f32) / (image_height as f32);
        let scale_w = (target_width as f32) / (image_width as f32);

        if scale_w < scale_h {
            (
                target_height.min((image_height as f32 * scale_w).floor() as usize),
                target_width,
            )
        } else {
            (
                target_height,
                target_width.min((image_width as f32 * scale_h).floor() as usize),
            )
        }
    }

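    /// Resize the image onto the optimal tiled canvas, returning the resized image
    /// together with the chosen `(num_tiles_height, num_tiles_width)` arrangement.
    /// Tiles are square with side `size["height"]`.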
    fn resize(
        &self,
        image: DynamicImage,
        size: &HashMap<String, u32>,
        max_image_tiles: usize,
        filter: FilterType,
    ) -> Result<(DynamicImage, (usize, usize))> {
        let image_height = image.height();
        let image_width = image.width();
        let tile_size = size["height"] as usize;

        let (canvas_height, canvas_width) =
            Self::get_optimal_tiled_canvas(image_height, image_width, max_image_tiles, tile_size)?;
        let num_tiles_height = canvas_height / tile_size;
        let num_tiles_width = canvas_width / tile_size;

        let (new_height, new_width) = Self::get_image_size_fit_to_canvas(
            image_height,
            image_width,
            canvas_height,
            canvas_width,
            tile_size,
        );

        Ok((
            image.resize_exact(new_width as u32, new_height as u32, filter),
            (num_tiles_height, num_tiles_width),
        ))
    }

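    /// Pad the image tensor out to the full tile grid,
    /// `(num_tiles_h * size["height"], num_tiles_w * size["width"])`.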
    fn pad(
        &self,
        image: &Tensor,
        size: &HashMap<String, u32>,
        aspect_ratio: (usize, usize),
    ) -> Result<Tensor> {
        let (num_tiles_h, num_tiles_w) = aspect_ratio;
        let padded_height = num_tiles_h * size["height"] as usize;
        let padded_width = num_tiles_w * size["width"] as usize;

        mistralrs_vision::pad(image, padded_height, padded_width)
    }

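    /// Split a `(channels, height, width)` image into a
    /// `(num_tiles, channels, tile_height, tile_width)` stack, with tiles ordered
    /// row-major over the tile grid.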
    fn split_to_tiles(
        &self,
        image: &Tensor,
        num_tiles_height: usize,
        num_tiles_width: usize,
    ) -> Result<Tensor> {
        let (ch, h, w) = image.dims3()?;
        let tile_height = h / num_tiles_height;
        let tile_width = w / num_tiles_width;

        let mut image = image.reshape((
            ch,
            num_tiles_height,
            tile_height,
            num_tiles_width,
            tile_width,
        ))?;

        image = image.permute((1, 3, 0, 2, 4))?;

        image
            .reshape((
                num_tiles_width * num_tiles_height,
                ch,
                tile_height,
                tile_width,
            ))?
            .contiguous()
    }

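    /// Stack per-image tile tensors into one zero-padded
    /// `(max_num_images, max_image_tiles, channels, tile_height, tile_width)`
    /// tensor, also returning each image's real tile count.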
    fn pack_images(
        &self,
        images: Vec<Tensor>,
        max_image_tiles: usize,
        (_bs, max_num_images): (usize, usize),
    ) -> Result<(Tensor, Vec<usize>)> {
        let (_, ch, tile_h, tile_w) = images[0].dims4()?;

        let mut stacked_images = Tensor::zeros(
            (max_num_images, max_image_tiles, ch, tile_h, tile_w),
            images[0].dtype(),
            images[0].device(),
        )?;
        let mut num_sample_tiles = Vec::new();
        for (i, image) in images.into_iter().enumerate() {
            let num_tiles = image.dim(0)?;
            stacked_images = stacked_images
                .slice_assign(&[&i, &(..num_tiles), &.., &.., &..], &image.unsqueeze(0)?)?;
            num_sample_tiles.push(num_tiles)
        }
        Ok((stacked_images, num_sample_tiles))
    }

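    /// Map each image's tile arrangement to its 1-based index in
    /// [`Self::get_all_supported_aspect_ratios`]. Unused image slots keep id 0,
    /// which acts as padding.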
    fn convert_aspect_ratios_to_ids(
        &self,
        aspect_ratios: Vec<(usize, usize)>,
        max_image_tiles: usize,
        (_bs, max_num_images): (usize, usize),
        device: &Device,
    ) -> Result<Tensor> {
        let supported_aspect_ratios = Self::get_all_supported_aspect_ratios(max_image_tiles);

        let mut aspect_ratios_ids = vec![0i64; max_num_images];
        for (i, (num_tiles_h, num_tiles_w)) in aspect_ratios.iter().enumerate() {
            aspect_ratios_ids[i] = (supported_aspect_ratios
                .iter()
                .position(|(h, w)| *h == *num_tiles_h && *w == *num_tiles_w)
                .context("Could not find aspect ratio")?
                + 1) as i64;
        }

        Tensor::new(aspect_ratios_ids, device)
    }

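    /// Build a `(max_num_images, max_image_tiles)` 0/1 mask of the valid tiles for
    /// each image slot. Tile 0 is enabled for every slot (including padding
    /// slots), then the first `num_tiles_h * num_tiles_w` tiles are enabled per
    /// real image.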
    fn build_aspect_ratio_mask(
        &self,
        aspect_ratios: Vec<(usize, usize)>,
        max_image_tiles: usize,
        (_bs, max_num_images): (usize, usize),
        device: &Device,
    ) -> Result<Tensor> {
        let mut aspect_ratio_mask =
            Tensor::zeros((max_num_images, max_image_tiles), DType::I64, device)?;

        aspect_ratio_mask = aspect_ratio_mask.slice_assign(
            &[&.., &0],
            &Tensor::ones((max_num_images, 1), DType::I64, device)?,
        )?;

        for (i, (num_tiles_h, num_tiles_w)) in aspect_ratios.iter().enumerate() {
            aspect_ratio_mask = aspect_ratio_mask.slice_assign(
                &[&i, &(..*num_tiles_h * *num_tiles_w)],
                &Tensor::ones((1, *num_tiles_h * *num_tiles_w), DType::I64, device)?,
            )?;
        }

        Ok(aspect_ratio_mask)
    }
}

impl ImagePreProcessor for MLlamaImageProcessor {
    const DEFAULT_MEAN: [f64; 3] = [0.5, 0.5, 0.5];
    const DEFAULT_STD: [f64; 3] = [0.5, 0.5, 0.5];

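    /// Preprocess images for MLlama: convert to RGB, resize onto the optimal
    /// tiled canvas, pad to the tile grid, optionally rescale and normalize, and
    /// split into tiles; then pack the batch and derive the aspect-ratio ids and
    /// mask. Videos are not supported. Also caches `max_image_tiles` for
    /// `process_inputs`.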
    fn preprocess(
        &self,
        images: Vec<DynamicImage>,
        videos: Vec<Vec<DynamicImage>>,
        config: &PreProcessorConfig,
        device: &Device,
        (bs, max_num_images): (usize, usize),
    ) -> Result<PreprocessedImages> {
        assert!(videos.is_empty());

        let mut sample_images = Vec::new();
        let mut sample_aspect_ratios = Vec::new();
        let max_image_tiles = config
            .max_image_tiles
            .context("`do_resize=false` is not supported, need `max_image_tiles`!")?;
        *self.max_image_tiles.write().unwrap() = Some(max_image_tiles);

        for mut image in images {
            if config.do_convert_rgb.unwrap_or(true) {
                image = DynamicImage::ImageRgb8(image.to_rgb8());
            }

            let size = config
                .size
                .as_ref()
                .context("`do_resize=false` is not supported, need `size`!")?;

            let (image, aspect_ratio) =
                self.resize(image, size, max_image_tiles, config.resampling.to_filter()?)?;

            let to_tensor_rescale = Transforms {
                input: &ToTensorNoNorm,
                inner_transforms: &[],
            };
            let mut image = image.apply(to_tensor_rescale, device)?;

            image = self.pad(&image, size, aspect_ratio)?;

            let transforms = TensorTransforms {
                inner_transforms: &[
                    &config
                        .do_rescale
                        .is_some_and(|x| x)
                        .then(|| Rescale {
                            factor: config.rescale_factor,
                        }),
                    &config
                        .do_normalize
                        .is_some_and(|x| x)
                        .then(|| Normalize {
                            mean: config.image_mean.unwrap_or(Self::DEFAULT_MEAN).to_vec(),
                            std: config.image_std.unwrap_or(Self::DEFAULT_STD).to_vec(),
                        }),
                ],
            };
            image = <Tensor as ApplyTensorTransforms>::apply(&image, transforms, device)?;

            let (num_tiles_height, num_tiles_width) = aspect_ratio;
            image = self.split_to_tiles(&image, num_tiles_height, num_tiles_width)?;

            sample_images.push(image);
            sample_aspect_ratios.push((num_tiles_height, num_tiles_width));
        }

        let (images, num_tiles) =
            self.pack_images(sample_images, max_image_tiles, (bs, max_num_images))?;

        let aspect_ratio_ids = self.convert_aspect_ratios_to_ids(
            sample_aspect_ratios.clone(),
            max_image_tiles,
            (bs, max_num_images),
            device,
        )?;
        let aspect_ratio_mask = self.build_aspect_ratio_mask(
            sample_aspect_ratios,
            max_image_tiles,
            (bs, max_num_images),
            device,
        )?;

        Ok(PreprocessedImages {
            pixel_values: images,
            pixel_attention_mask: None,
            image_sizes: None,
            num_img_tokens: None,
            aspect_ratio_ids: Some(aspect_ratio_ids),
            aspect_ratio_mask: Some(aspect_ratio_mask),
            num_tiles: Some(num_tiles),
            image_grid_thw: None,
            video_grid_thw: None,
            rows: None,
            cols: None,
            pixel_values_list: None,
            tgt_sizes: None,
            image_sizes_all: None,
            num_crops: None,
        })
    }
}