#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::{
    any::Any,
    collections::{HashMap, HashSet},
    sync::Arc,
};

use candle_core::{Context, Device, IndexOp, Result, Tensor, D};
use image::DynamicImage;
use itertools::Itertools;
use mistralrs_vision::{
    ApplyTensorTransforms, ApplyTransforms, Normalize, Rescale, TensorTransforms, ToTensorNoNorm,
    Transforms,
};
use ordered_float::NotNan;
use tokenizers::Tokenizer;

use crate::{
    device_map::DeviceMapper,
    pipeline::{
        text_models_inputs_processor::{
            self, get_completion_input, get_prompt_input, PagedAttentionMeta,
        },
        InputProcessorOutput, InputsProcessor, InputsProcessorType, MessagesAction, Processor,
    },
    sequence::Sequence,
    vision_models::{
        image_processor::{ImagePreProcessor, PreprocessedImages},
        preprocessor_config::PreProcessorConfig,
        processor_config::ProcessorConfig,
        ModelInputs,
    },
};

use super::Llama4ModelSpecificArgs;

pub(crate) const IMAGE_TOKEN: &str = "<|image|>";
const IMAGE_START: &str = "<|image_start|>";
const IMAGE_END: &str = "<|image_end|>";
const PATCH: &str = "<|patch|>";
const TILE_X_SEP: &str = "<|tile_x_separator|>";
const TILE_Y_SEP: &str = "<|tile_y_separator|>";

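/// Image inputs processor for Llama 4 vision models.
///
/// `patch_size` is the side length of a vision patch; `downsample_ratio` is the
/// pixel-shuffle reduction applied to the patch grid (a pixel-shuffle ratio of 0.5
/// yields a downsample ratio of 4).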
pub struct Llama4ImageProcessor {
    pub patch_size: usize,
    pub downsample_ratio: usize,
}

impl Llama4ImageProcessor {
    pub fn new(patch_size: Option<usize>, pixel_shuffle_ratio: Option<f32>) -> Self {
        Self {
            patch_size: patch_size.unwrap_or(14),
            downsample_ratio: (1. / pixel_shuffle_ratio.unwrap_or(0.5).powi(2)).round() as usize,
        }
    }
}

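/// Chat processor for Llama 4. It only carries the patch size and pixel-shuffle
/// downsample ratio needed to build the [`Llama4ImageProcessor`] inputs processor.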
pub struct Llama4Processor {
    patch_size: usize,
    downsample_ratio: usize,
}

impl Llama4Processor {
    pub fn new(cfg: &ProcessorConfig) -> Self {
        Self {
            patch_size: cfg.patch_size.unwrap_or(14),
            downsample_ratio: (1. / cfg.pixel_shuffle_ratio.unwrap_or(0.5).powi(2)).round()
                as usize,
        }
    }
}

impl Processor for Llama4Processor {
    fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
        Arc::new(Llama4ImageProcessor {
            patch_size: self.patch_size,
            downsample_ratio: self.downsample_ratio,
        })
    }

    fn get_special_tokens(&self) -> &[&'static str] {
        &[
            IMAGE_START,
            IMAGE_END,
            PATCH,
            TILE_X_SEP,
            TILE_Y_SEP,
            IMAGE_TOKEN,
        ]
    }

    fn template_action(&self) -> MessagesAction {
        MessagesAction::FlattenOnlyText
    }
}

impl Llama4ImageProcessor {
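    /// Render the token string that replaces a single `<|image|>` placeholder.
    ///
    /// For a tiled image (`ratio_h * ratio_w > 1`), each tile contributes
    /// `num_patches_per_chunk` `<|patch|>` tokens; tiles within a row are joined by
    /// `<|tile_x_separator|>` and each row ends with `<|tile_y_separator|>`. A global
    /// image chunk always follows: `<|image|>`, the patch tokens, then `<|image_end|>`.
    /// For example, a 1x2 aspect ratio with 1 patch per chunk produces
    /// `<|image_start|><|patch|><|tile_x_separator|><|patch|><|tile_y_separator|><|image|><|patch|><|image_end|>`.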
    fn prompt_split_image(&self, aspect_ratio: &Tensor, num_patches_per_chunk: usize) -> String {
        let mut img_string = IMAGE_START.to_string();
        let aspect_ratio = aspect_ratio.to_vec1::<u32>().unwrap();
        let (ratio_h, ratio_w) = (aspect_ratio[0] as usize, aspect_ratio[1] as usize);
        if ratio_h * ratio_w > 1 {
            for _yy in 0..ratio_h {
                for xx in 0..ratio_w {
                    img_string.push_str(&PATCH.repeat(num_patches_per_chunk));
                    if xx < ratio_w - 1 {
                        img_string.push_str(TILE_X_SEP);
                    }
                }
                img_string.push_str(TILE_Y_SEP);
            }
        }
        img_string.push_str(IMAGE_TOKEN);
        img_string.push_str(&PATCH.repeat(num_patches_per_chunk));
        img_string.push_str(IMAGE_END);
        img_string
    }
}

impl InputsProcessor for Llama4ImageProcessor {
    fn get_type(&self) -> InputsProcessorType {
        InputsProcessorType::Vision
    }
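
    /// Build the Llama 4 [`ModelInputs`] for a batch of sequences: preprocess any
    /// images, expand each `<|image|>` placeholder in the prompt into the full
    /// patch/tile token sequence, re-tokenize the updated prompts, and assemble the
    /// text inputs (positions, context lengths, paged-attention/flash metadata)
    /// alongside `pixel_values`.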
    fn process_inputs(
        &self,
        tokenizer: Option<Arc<Tokenizer>>,
        input_seqs: &mut [&mut Sequence],
        is_prompt: bool,
        is_xlora: bool,
        device: &Device,
        no_kv_cache: bool,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        other_config: Option<Arc<dyn Any>>,
        mut paged_attn_metadata: Option<PagedAttentionMeta>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> anyhow::Result<InputProcessorOutput> {
        if is_xlora {
            return Err(anyhow::Error::msg(
                "Cannot make inputs for X-LoRA vision model.",
            ));
        }
        if no_kv_cache {
            return Err(anyhow::Error::msg("Vision model must have kv cache."));
        }
        let Some(tokenizer) = tokenizer else {
            return Err(anyhow::Error::msg(
                "Llama4InputProcessor requires a specified tokenizer.",
            ));
        };

        let config = other_config.expect("Need a PreProcessorConfig config.");
        let config: &PreProcessorConfig = config.downcast_ref().expect("Downcast failed.");

        let has_images = input_seqs.iter().all(|seq| seq.has_images());

        let pixel_values = if has_images {
            let mut pixel_values_accum = Vec::new();
            let mut aspect_ratios_accum = Vec::new();

            let bs = input_seqs.len();
            let detokenized = tokenizer
                .decode_batch(
                    &input_seqs
                        .iter()
                        .map(|seq| seq.get_toks())
                        .collect::<Vec<_>>(),
                    false,
                )
                .expect("Detokenization failed!");
            let n_images_in_text = detokenized
                .iter()
                .map(|text| text.matches(IMAGE_TOKEN).count())
                .collect::<Vec<_>>();
            let n_images_in_images = input_seqs
                .iter()
                .map(|seq| seq.images().map(|imgs| imgs.len()).unwrap_or(0))
                .collect::<Vec<_>>();

            if n_images_in_text != n_images_in_images {
                return Err(anyhow::Error::msg(format!(
                    "The number of image tokens in each prompt {n_images_in_text:?} should match the number of images provided {n_images_in_images:?}. The model cannot support a different number of images per prompt. Perhaps you forgot a `<|image|>` tag?"
                )));
            }

            let max_num_images = *n_images_in_images
                .iter()
                .max()
                .expect("No max images per batch!");

            for seq in input_seqs.iter_mut() {
                let PreprocessedImages {
                    pixel_values,
                    pixel_attention_mask: _,
                    image_sizes: _,
                    num_img_tokens: _,
                    aspect_ratio_ids,
                    aspect_ratio_mask: _,
                    num_tiles: _,
                    image_grid_thw: _,
                    video_grid_thw: _,
                    rows: _,
                    cols: _,
                    pixel_values_list: _,
                    tgt_sizes: _,
                    image_sizes_all: _,
                    num_crops: _,
                } = self
                    .preprocess(
                        seq.take_images()
                            .expect("Need to have images by this point."),
                        vec![],
                        config,
                        device,
                        (bs, max_num_images),
                    )
                    .expect("Preprocessing failed");
                pixel_values_accum.push(pixel_values);
                aspect_ratios_accum.push(aspect_ratio_ids.unwrap());
            }

            let pixel_values = Tensor::cat(&pixel_values_accum, 0).unwrap();
            let aspect_ratios = Tensor::cat(&aspect_ratios_accum, 0).unwrap();

            let (image_h, image_w) = (
                pixel_values.dim(D::Minus2).unwrap(),
                pixel_values.dim(D::Minus1).unwrap(),
            );
            // Each tile yields (h / patch_size) * (w / patch_size) patches, reduced by the
            // pixel-shuffle downsample ratio.
            let num_patches_per_chunk =
                (image_h / self.patch_size) * (image_w / self.patch_size) / self.downsample_ratio;

            let placeholder_counts = input_seqs
                .iter()
                .map(|seq| seq.get_initial_prompt().match_indices(IMAGE_TOKEN).count())
                .collect::<Vec<_>>();

            let mut image_index = 0;
            for (seq, placeholder_count) in input_seqs.iter_mut().zip(placeholder_counts) {
                if placeholder_count == 0 {
                    continue;
                }
                // Replace every `<|image|>` placeholder with the full token sequence for the
                // corresponding preprocessed image.
                let prompt_splits: std::str::Split<'_, &str> =
                    seq.get_initial_prompt().split(IMAGE_TOKEN);
                let mut new_prompt = Vec::new();
                for (local_image_index, split_part) in prompt_splits.enumerate() {
                    new_prompt.push(split_part.to_string());
                    if local_image_index < placeholder_count {
                        let tokens_for_this_image = self.prompt_split_image(
                            &aspect_ratios.i(image_index).unwrap(),
                            num_patches_per_chunk,
                        );
                        image_index += 1;
                        new_prompt.push(tokens_for_this_image);
                    }
                }
                let prompt = new_prompt.join("");

                if !seq.multimodal.has_changed_prompt {
                    seq.set_initial_prompt(prompt.clone());
                    let toks = tokenizer
                        .encode_fast(prompt, false)
                        .expect("Tokenization failed!");

                    let ids = toks.get_ids().to_vec();
                    seq.set_toks_and_reallocate(ids, paged_attn_metadata.as_mut());
                    seq.multimodal.has_changed_prompt = true;
                }
            }

            Some(pixel_values)
        } else {
            None
        };

        let text_models_inputs_processor::InnerInputProcessorOutput {
            inputs:
                text_models_inputs_processor::InputMetadata {
                    input,
                    positions,
                    context_lens,
                    position_ids,
                    paged_attn_meta,
                    flash_meta,
                },
            seq_indices,
        } = if is_prompt {
            get_prompt_input(
                input_seqs
                    .iter()
                    .map(|seq| seq.get_toks())
                    .collect::<Vec<_>>(),
                input_seqs,
                device,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                mapper,
            )
            .unwrap()
        } else {
            get_completion_input(
                input_seqs
                    .iter()
                    .map(|seq| seq.get_toks())
                    .collect::<Vec<_>>(),
                input_seqs,
                device,
                no_kv_cache,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                mapper,
            )
            .unwrap()
        };

        let inputs: Box<dyn Any> = Box::new(ModelInputs {
            input_ids: input,
            seqlen_offsets: positions,
            context_lens,
            position_ids,
            pixel_values,
            model_specific_args: Box::new(Llama4ModelSpecificArgs),
            paged_attn_meta,
            flash_meta,
        });
        Ok(InputProcessorOutput {
            inputs,
            seq_indices,
        })
    }
}

impl Llama4ImageProcessor {
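    /// Collect all positive divisors of `dividend` (e.g. 12 -> {1, 2, 3, 4, 6, 12}).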
    fn get_factors(dividend: u32) -> HashSet<u32> {
        let mut factors_set = HashSet::new();

        let sqrt = (dividend as f64).sqrt() as u32;
        for i in 1..=sqrt {
            if dividend % i == 0 {
                factors_set.insert(i);
                factors_set.insert(dividend / i);
            }
        }

        factors_set
    }

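    /// Enumerate the (height, width) canvases supported by the tiling scheme: every
    /// aspect ratio `(h, w)` whose tile count `h * w` does not exceed `max_num_chunks`,
    /// scaled by the configured square tile size. For a 336x336 tile, the 2x1 ratio
    /// contributes the canvas (672, 336), for instance.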
    fn find_supported_resolutions(
        &self,
        max_num_chunks: usize,
        size: &HashMap<String, u32>,
    ) -> Result<Vec<(u32, u32)>> {
        let height = size["height"];
        let width = size["width"];
        if height != width {
            candle_core::bail!("Expected config size height==width ({height}!={width})");
        }

        let patch_size = height;

        let mut asp_map = HashMap::new();
        // Consider every chunk count from `max_num_chunks` down to 1.
        for chunk_size in (1..=max_num_chunks).rev() {
            let factors = Self::get_factors(chunk_size as u32);
            let asp_ratios = factors
                .into_iter()
                .sorted()
                .map(|factor| (factor, chunk_size as u32 / factor));
            for (h, w) in asp_ratios {
                let ratio_float = h as f32 / w as f32;
                asp_map
                    .entry(NotNan::new(ratio_float).context("f32 is NaN")?)
                    .or_insert_with(Vec::new)
                    .push((h, w));
            }
        }

        let possible_resolutions = asp_map
            .into_values()
            .flatten()
            .map(|(height, width)| (height * patch_size, width * patch_size))
            .collect::<Vec<_>>();

        Ok(possible_resolutions)
    }

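    /// Group images by their (height, width) so that same-shaped images can be
    /// processed as one batched tensor. Also returns, for each original image index,
    /// its shape key and its position within that shape's stack.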
    #[allow(clippy::type_complexity)]
    fn group_images_by_shape(
        &self,
        images: &[Tensor],
    ) -> Result<(
        HashMap<(usize, usize), Tensor>,
        HashMap<usize, ((usize, usize), usize)>,
    )> {
        let mut grouped_images = HashMap::new();
        let mut grouped_images_index = HashMap::new();
        for (i, image) in images.iter().enumerate() {
            let (_c, h, w) = image.dims3()?;
            let shape = (h, w);
            grouped_images
                .entry(shape)
                .or_insert_with(Vec::new)
                .push(image.clone());
            grouped_images_index.insert(i, (shape, grouped_images[&shape].len() - 1));
        }
        let mut grouped_images_stack = HashMap::new();
        for (shape, images) in grouped_images {
            grouped_images_stack.insert(shape, Tensor::stack(&images, 0)?);
        }

        Ok((grouped_images_stack, grouped_images_index))
    }

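    /// Choose the target canvas from `possible_resolutions` for an image of size
    /// `(original_height, original_width)`. Canvases the image fits into without
    /// cropping are preferred: the smallest such upscale is chosen unless
    /// `resize_to_max_canvas` is set, in which case the largest is used. If no canvas
    /// is large enough, the gentlest downscale wins. Ties are broken by smallest area.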
    fn get_best_fit(
        &self,
        (original_height, original_width): (u32, u32),
        possible_resolutions: Vec<(u32, u32)>,
        resize_to_max_canvas: bool,
    ) -> Result<(u32, u32)> {
        let (target_heights, target_widths): (Vec<u32>, Vec<u32>) =
            possible_resolutions.iter().copied().unzip();

        // Per-canvas scale factor: the limiting (smaller) of the width and height scales.
        let scale_w = target_widths
            .iter()
            .map(|tw| *tw as f32 / original_width as f32);
        let scale_h = target_heights
            .iter()
            .map(|th| *th as f32 / original_height as f32);

        let scales = scale_w.zip(scale_h).map(|(w, h)| if h > w { w } else { h });

        let upscaling_options = scales
            .clone()
            .filter(|s| *s >= 1.)
            .map(|x| NotNan::new(x).unwrap())
            .collect::<Vec<_>>();
        let downscaling_options = scales
            .clone()
            .filter(|s| *s < 1.)
            .map(|x| NotNan::new(x).unwrap())
            .collect::<Vec<_>>();
        // Prefer a canvas the image can be upscaled (or kept as-is) to fit; otherwise
        // fall back to the gentlest possible downscale.
        let selected_scale = if !upscaling_options.is_empty() {
            if resize_to_max_canvas {
                upscaling_options.into_iter().max().unwrap().into_inner()
            } else {
                upscaling_options.into_iter().min().unwrap().into_inner()
            }
        } else {
            downscaling_options.into_iter().max().unwrap().into_inner()
        };

        // Among canvases with the selected scale, pick the one with the smallest area.
        let chosen_canvas = possible_resolutions
            .into_iter()
            .zip(scales)
            .filter_map(|(possible, scale)| {
                if scale == selected_scale {
                    Some(possible)
                } else {
                    None
                }
            })
            .sorted_by_key(|(h, w)| h * w)
            .take(1)
            .collect::<Vec<_>>()[0];

        Ok(chosen_canvas)
    }

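    /// Compute the largest (height, width) that fits inside `target_size` while
    /// preserving the aspect ratio of `image_size`. For example, an 800x600 image
    /// targeted at a 336x336 canvas resolves to 336x252.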
    fn get_max_res_without_distortion(
        &self,
        image_size: (u32, u32),
        target_size: (u32, u32),
    ) -> (u32, u32) {
        let (original_height, original_width) = image_size;
        let (target_height, target_width) = target_size;

        let scale_w = target_width as f64 / original_width as f64;
        let scale_h = target_height as f64 / original_height as f64;

        if scale_w < scale_h {
            let new_width = target_width;
            let new_height = std::cmp::min(
                (original_height as f64 * scale_w).floor() as u32,
                target_height,
            );
            (new_height, new_width)
        } else {
            let new_height = target_height;
            let new_width = std::cmp::min(
                (original_width as f64 * scale_h).floor() as u32,
                target_width,
            );
            (new_height, new_width)
        }
    }

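    /// Split a batch of images of shape `(bs, c, h, w)` into a grid of tiles, returning
    /// shape `(bs, num_tiles_h * num_tiles_w, c, h / num_tiles_h, w / num_tiles_w)`.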
    fn split_to_tiles(
        &self,
        images: &Tensor,
        num_tiles_h: usize,
        num_tiles_w: usize,
    ) -> Result<Tensor> {
        let (bs, c, h, w) = images.dims4()?;
        let mut images = images.reshape((
            bs,
            c,
            num_tiles_h,
            h / num_tiles_h,
            num_tiles_w,
            w / num_tiles_w,
        ))?;
        images = images.permute((0, 2, 4, 1, 3, 5))?.contiguous()?;
        images.reshape((
            bs,
            num_tiles_h * num_tiles_w,
            c,
            h / num_tiles_h,
            w / num_tiles_w,
        ))
    }

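    /// Undo [`Self::group_images_by_shape`]: pull each processed image back out of its
    /// per-shape stack, in the original input order.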
    fn reorder_images(
        &self,
        processed_images: HashMap<(usize, usize), Tensor>,
        grouped_images_index: HashMap<usize, ((usize, usize), usize)>,
    ) -> Result<Vec<Tensor>> {
        // Iterate by original image index; iterating the HashMap's values directly would
        // yield an arbitrary order.
        (0..grouped_images_index.len())
            .map(|i| {
                let (shape, idx) = grouped_images_index[&i];
                processed_images[&shape].i(idx)
            })
            .collect::<Result<Vec<Tensor>>>()
    }
}

impl ImagePreProcessor for Llama4ImageProcessor {
    const DEFAULT_MEAN: [f64; 3] = [0.5, 0.5, 0.5];
    const DEFAULT_STD: [f64; 3] = [0.5, 0.5, 0.5];

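    /// Preprocess images for Llama 4: pick the best-fit tiled canvas for each image,
    /// resize without distortion, zero-pad to the canvas, rescale and normalize, split
    /// into tiles, and, when more than one tile is used, append a global thumbnail
    /// tile. Returns the tile pixel values along with per-image `(ratio_h, ratio_w)`
    /// aspect-ratio ids.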
    fn preprocess(
        &self,
        images_d: Vec<DynamicImage>,
        videos: Vec<Vec<DynamicImage>>,
        config: &PreProcessorConfig,
        device: &Device,
        (_bs, _max_num_images): (usize, usize),
    ) -> Result<PreprocessedImages> {
        assert!(videos.is_empty());

        let max_patches = config.max_patches.unwrap_or(16);
        let size = config.size.clone().unwrap_or(HashMap::from_iter([
            ("height".to_string(), 336),
            ("width".to_string(), 336),
        ]));
        let resize_to_max_canvas = config.resize_to_max_canvas.unwrap_or(false);
        let do_rescale = config.do_rescale.unwrap_or(true);
        let do_normalize = config.do_normalize.unwrap_or(true);

        let possible_resolutions = self.find_supported_resolutions(max_patches, &size)?;

        let mut images = Vec::new();
        for mut image in images_d {
            if config.do_convert_rgb.unwrap_or(true) {
                image = DynamicImage::ImageRgb8(image.to_rgb8());
            }

            let to_tensor_rescale = Transforms {
                input: &ToTensorNoNorm,
                inner_transforms: &[],
            };
            let image = image.apply(to_tensor_rescale, device)?;
            images.push(image);
        }

        let (grouped_images, grouped_images_index) = self.group_images_by_shape(&images)?;

        let mut grouped_processed_images = HashMap::new();
        let mut grouped_aspect_ratios = HashMap::new();
        for (shape, stacked_images) in grouped_images {
            let image_size = (
                stacked_images.dim(D::Minus2)? as u32,
                stacked_images.dim(D::Minus1)? as u32,
            );
            let target_size = self.get_best_fit(
                image_size,
                possible_resolutions.clone(),
                resize_to_max_canvas,
            )?;
            // If we are not resizing to the max canvas, cap upscaling at the tile size.
            let max_upscaling_size = if resize_to_max_canvas {
                None
            } else {
                Some(size["height"])
            };
            let target_size_without_distortion =
                if let Some(max_upscaling_size) = max_upscaling_size {
                    let nt_h = image_size.0.max(max_upscaling_size).min(target_size.0);
                    let nt_w = image_size.1.max(max_upscaling_size).min(target_size.1);
                    (nt_h, nt_w)
                } else {
                    candle_core::bail!("Currently `resize_to_max_canvas == false` is assumed!");
                };

            // Resize without distortion, then zero-pad on the bottom/right up to the
            // target canvas size.
            let new_size_without_distortion =
                self.get_max_res_without_distortion(image_size, target_size_without_distortion);
            let mut processed_images = stacked_images.interpolate2d(
                new_size_without_distortion.0.max(1) as usize,
                new_size_without_distortion.1.max(1) as usize,
            )?;

            processed_images = {
                let (target_h, target_w) = target_size;
                let (h, w) = (
                    processed_images.dim(D::Minus2)?,
                    processed_images.dim(D::Minus1)?,
                );
                let paste_x_r = target_w as usize - w;
                let paste_y_r = target_h as usize - h;
                processed_images
                    .pad_with_zeros(D::Minus2, 0, paste_y_r)?
                    .pad_with_zeros(D::Minus1, 0, paste_x_r)?
            };

            let rescale_and_norm_transforms = TensorTransforms {
                inner_transforms: &[
                    &do_rescale.then_some(Rescale {
                        factor: config.rescale_factor,
                    }),
                    &do_normalize.then_some(Normalize {
                        mean: config.image_mean.unwrap_or(Self::DEFAULT_MEAN).to_vec(),
                        std: config.image_std.unwrap_or(Self::DEFAULT_STD).to_vec(),
                    }),
                ],
            };
            processed_images = <Tensor as ApplyTensorTransforms>::apply(
                &processed_images,
                rescale_and_norm_transforms,
                device,
            )?;

            // Number of tiles along each axis, then split the padded canvas into tiles.
            let (ratio_h, ratio_w) = (
                target_size.0 / size["height"],
                target_size.1 / size["width"],
            );
            processed_images =
                self.split_to_tiles(&processed_images, ratio_h as usize, ratio_w as usize)?;
            grouped_processed_images.insert(shape, processed_images.clone());
            grouped_aspect_ratios.insert(
                shape,
                Tensor::new(vec![vec![ratio_h, ratio_w]; stacked_images.dim(0)?], device)?,
            );

            // For tiled images, append a downsized "global" view of the whole image as
            // an extra tile.
            if ratio_h * ratio_w > 1 {
                let mut global_tiles = stacked_images
                    .interpolate2d(size["height"] as usize, size["width"] as usize)?;
                global_tiles = <Tensor as ApplyTensorTransforms>::apply(
                    &global_tiles,
                    rescale_and_norm_transforms,
                    device,
                )?;
                grouped_processed_images.insert(
                    shape,
                    Tensor::cat(&[processed_images, global_tiles.unsqueeze(1)?], 1)?,
                );
            }
        }

        let processed_images =
            self.reorder_images(grouped_processed_images, grouped_images_index.clone())?;
        let aspect_ratios_list =
            self.reorder_images(grouped_aspect_ratios, grouped_images_index.clone())?;

        let processed_images = Tensor::cat(&processed_images, 0)?;
        let aspect_ratios = Tensor::stack(&aspect_ratios_list, 0)?;

        Ok(PreprocessedImages {
            pixel_values: processed_images,
            pixel_attention_mask: None,
            image_sizes: None,
            num_img_tokens: None,
            aspect_ratio_ids: Some(aspect_ratios),
            aspect_ratio_mask: None,
            num_tiles: None,
            image_grid_thw: None,
            video_grid_thw: None,
            rows: None,
            cols: None,
            pixel_values_list: None,
            tgt_sizes: None,
            image_sizes_all: None,
            num_crops: None,
        })
    }
}