#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::{
    any::Any,
    collections::{HashMap, HashSet},
    num::NonZeroUsize,
    sync::Arc,
};

use candle_core::{Context, Device, IndexOp, Result, Tensor, D};
use image::DynamicImage;
use itertools::Itertools;
use mistralrs_vision::{
    ApplyTensorTransforms, ApplyTransforms, Normalize, Rescale, TensorTransforms, ToTensorNoNorm,
    Transforms,
};
use ordered_float::NotNan;
use tokenizers::Tokenizer;
use tracing::warn;

use crate::{
    device_map::DeviceMapper,
    pipeline::{
        text_models_inputs_processor::{
            self, get_completion_input, get_prompt_input, PagedAttentionMeta,
        },
        InputProcessorOutput, InputsProcessor, InputsProcessorType, MessagesAction, Processor,
    },
    sequence::Sequence,
    vision_models::{
        image_processor::{ImagePreProcessor, PreprocessedImages},
        preprocessor_config::PreProcessorConfig,
        processor_config::ProcessorConfig,
        ModelInputs,
    },
};

use super::Llama4ModelSpecificArgs;

pub(crate) const IMAGE_TOKEN: &str = "<|image|>";
const IMAGE_START: &str = "<|image_start|>";
const IMAGE_END: &str = "<|image_end|>";
const PATCH: &str = "<|patch|>";
const TILE_X_SEP: &str = "<|tile_x_separator|>";
const TILE_Y_SEP: &str = "<|tile_y_separator|>";

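/// Image processor for Llama 4 vision inputs: tiles each image onto a supported
/// canvas and expands `<|image|>` placeholders into per-tile patch tokens.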
pub struct Llama4ImageProcessor {
    pub patch_size: usize,
    pub downsample_ratio: usize,
}

impl Llama4ImageProcessor {
    pub fn new(patch_size: Option<usize>, pixel_shuffle_ratio: Option<f32>) -> Self {
        Self {
            patch_size: patch_size.unwrap_or(14),
            downsample_ratio: (1. / pixel_shuffle_ratio.unwrap_or(0.5).powi(2)).round() as usize,
        }
    }
}

pub struct Llama4Processor {
    patch_size: usize,
    downsample_ratio: usize,
}

impl Llama4Processor {
    pub fn new(cfg: &ProcessorConfig) -> Self {
        Self {
            patch_size: cfg.patch_size.unwrap_or(14),
            downsample_ratio: (1. / cfg.pixel_shuffle_ratio.unwrap_or(0.5).powi(2)).round()
                as usize,
        }
    }
}

impl Processor for Llama4Processor {
    fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
        Arc::new(Llama4ImageProcessor {
            patch_size: self.patch_size,
            downsample_ratio: self.downsample_ratio,
        })
    }

    fn get_special_tokens(&self) -> &[&'static str] {
        &[
            IMAGE_START,
            IMAGE_END,
            PATCH,
            TILE_X_SEP,
            TILE_Y_SEP,
            IMAGE_TOKEN,
        ]
    }

    fn template_action(&self) -> MessagesAction {
        MessagesAction::FlattenOnlyText
    }
}

impl Llama4ImageProcessor {
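    /// Build the expanded prompt text for a single image: per-tile `<|patch|>` runs with
    /// tile separators, followed by `<|image|>` and the patches for the global tile.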
    fn prompt_split_image(&self, aspect_ratio: &Tensor, num_patches_per_chunk: usize) -> String {
        let mut img_string = IMAGE_START.to_string();
        let aspect_ratio = aspect_ratio.to_vec1::<u32>().unwrap();
        let (ratio_h, ratio_w) = (aspect_ratio[0] as usize, aspect_ratio[1] as usize);
        if ratio_h * ratio_w > 1 {
            for _yy in 0..ratio_h {
                for xx in 0..ratio_w {
                    img_string.push_str(&PATCH.repeat(num_patches_per_chunk));
                    if xx < ratio_w - 1 {
                        img_string.push_str(TILE_X_SEP);
                    }
                }
                img_string.push_str(TILE_Y_SEP);
            }
        }
        img_string.push_str(IMAGE_TOKEN);
        img_string.push_str(&PATCH.repeat(num_patches_per_chunk));
        img_string.push_str(IMAGE_END);
        img_string
    }
}

impl InputsProcessor for Llama4ImageProcessor {
    fn get_type(&self) -> InputsProcessorType {
        InputsProcessorType::Vision
    }
    fn process_inputs(
        &self,
        tokenizer: Option<Arc<Tokenizer>>,
        input_seqs: &mut [&mut Sequence],
        is_prompt: bool,
        is_xlora: bool,
        device: &Device,
        no_kv_cache: bool,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        other_config: Option<Arc<dyn Any>>,
        mut paged_attn_metadata: Option<PagedAttentionMeta>,
        prompt_chunksize: Option<NonZeroUsize>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Box<dyn Iterator<Item = anyhow::Result<InputProcessorOutput>>> {
        if is_xlora {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Cannot make inputs for X-LoRA vision model.",
            ))));
        }
        if no_kv_cache {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Vision model must have kv cache.",
            ))));
        }
        if prompt_chunksize.is_some() {
            warn!("`prompt_chunksize` is set. Llama4 does not support prompt batching.");
        }
        let Some(tokenizer) = tokenizer else {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Llama4InputProcessor requires a specified tokenizer.",
            ))));
        };

        let config = other_config.expect("Need a PreProcessorConfig config.");
        let config: &PreProcessorConfig = config.downcast_ref().expect("Downcast failed.");

        let has_images = input_seqs.iter().all(|seq| seq.has_images());

        let pixel_values = if has_images {
            let mut pixel_values_accum = Vec::new();
            let mut aspect_ratios_accum = Vec::new();

            let bs = input_seqs.len();
            let detokenized = tokenizer
                .decode_batch(
                    &input_seqs
                        .iter()
                        .map(|seq| seq.get_toks())
                        .collect::<Vec<_>>(),
                    false,
                )
                .expect("Detokenization failed!");
            let n_images_in_text = detokenized
                .iter()
                .map(|text| text.matches(IMAGE_TOKEN).count())
                .collect::<Vec<_>>();
            let n_images_in_images = input_seqs
                .iter()
                .map(|seq| seq.images().map(|imgs| imgs.len()).unwrap_or(0))
                .collect::<Vec<_>>();

            if n_images_in_text != n_images_in_images {
                return Box::new(std::iter::once(Err(anyhow::Error::msg(format!(
194 "The number of images in each batch {n_images_in_text:?} should be the same as the number of images {n_images_in_images:?}. The model cannot support a different number of images per patch. Perhaps you forgot a `<|image|>` tag?"
                )))));
            }

            let max_num_images = *n_images_in_images
                .iter()
                .max()
                .expect("No max images per batch!");

            for seq in input_seqs.iter_mut() {
                let PreprocessedImages {
                    pixel_values,
                    pixel_attention_mask: _,
                    image_sizes: _,
                    num_img_tokens: _,
                    aspect_ratio_ids,
                    aspect_ratio_mask: _,
                    num_tiles: _,
                    image_grid_thw: _,
                    video_grid_thw: _,
                    rows: _,
                    cols: _,
                    pixel_values_list: _,
                    tgt_sizes: _,
                    image_sizes_all: _,
                    num_crops: _,
                } = self
                    .preprocess(
                        seq.take_images()
                            .expect("Need to have images by this point."),
                        vec![],
                        config,
                        device,
                        (bs, max_num_images),
                    )
                    .expect("Preprocessing failed");
                pixel_values_accum.push(pixel_values);
                aspect_ratios_accum.push(aspect_ratio_ids.unwrap());
            }

            let pixel_values = Tensor::cat(&pixel_values_accum, 0).unwrap();
            let aspect_ratios = Tensor::cat(&aspect_ratios_accum, 0).unwrap();

            let (image_h, image_w) = (
                pixel_values.dim(D::Minus2).unwrap(),
                pixel_values.dim(D::Minus1).unwrap(),
            );
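            // Number of `<|patch|>` tokens per tile: ViT patches per tile, reduced by the
            // pixel-shuffle downsample ratio.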
            let num_patches_per_chunk =
                (image_h / self.patch_size) * (image_w / self.patch_size) / self.downsample_ratio;

            let placeholder_counts = input_seqs
                .iter()
                .map(|seq| seq.get_initial_prompt().match_indices(IMAGE_TOKEN).count())
                .collect::<Vec<_>>();

            let mut image_index = 0;
            for (seq, placeholder_count) in input_seqs.iter_mut().zip(placeholder_counts) {
                if placeholder_count == 0 {
                    continue;
                }
                let prompt_splits: std::str::Split<'_, &str> =
                    seq.get_initial_prompt().split(IMAGE_TOKEN);
                let mut new_prompt = Vec::new();
                for (local_image_index, split_part) in prompt_splits.enumerate() {
                    new_prompt.push(split_part.to_string());
                    if local_image_index < placeholder_count {
                        let tokens_for_this_image = self.prompt_split_image(
                            &aspect_ratios.i(image_index).unwrap(),
                            num_patches_per_chunk,
                        );
                        image_index += 1;
                        new_prompt.push(tokens_for_this_image);
                    }
                }
                let prompt = new_prompt.join("");

                if !seq.multimodal.has_changed_prompt {
                    seq.set_initial_prompt(prompt.clone());
                    let toks = tokenizer
                        .encode_fast(prompt, false)
                        .expect("Tokenization failed!");

                    let ids = toks.get_ids().to_vec();
                    seq.set_toks_and_reallocate(ids, paged_attn_metadata.as_mut());
                    seq.multimodal.has_changed_prompt = true;
                }
            }

            Some(pixel_values)
        } else {
            None
        };

        let text_models_inputs_processor::InnerInputProcessorOutput {
            inputs:
                text_models_inputs_processor::InputMetadata {
                    input,
                    positions,
                    context_lens,
                    position_ids,
                    paged_attn_meta,
                    flash_meta,
                },
            seq_indices,
        } = if is_prompt {
            get_prompt_input(
                input_seqs
                    .iter()
                    .map(|seq| seq.get_toks())
                    .collect::<Vec<_>>(),
                input_seqs,
                device,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                None,
                mapper,
            )
            .nth(0)
            .unwrap()
            .unwrap()
        } else {
            get_completion_input(
                input_seqs
                    .iter()
                    .map(|seq| seq.get_toks())
                    .collect::<Vec<_>>(),
                input_seqs,
                device,
                no_kv_cache,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                None,
                mapper,
            )
            .nth(0)
            .unwrap()
            .unwrap()
        };

        let inputs: Box<dyn Any> = Box::new(ModelInputs {
            input_ids: input,
            seqlen_offsets: positions,
            context_lens,
            position_ids,
            pixel_values,
            model_specific_args: Box::new(Llama4ModelSpecificArgs),
            paged_attn_meta,
            flash_meta,
        });
        Box::new(std::iter::once(Ok(InputProcessorOutput {
            inputs,
            seq_indices,
        })))
    }
}

impl Llama4ImageProcessor {
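    /// Return all positive factors of `dividend`, used to enumerate the possible
    /// (rows, cols) tile grids for a given chunk count.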
    fn get_factors(dividend: u32) -> HashSet<u32> {
        let mut factors_set = HashSet::new();

        let sqrt = (dividend as f64).sqrt() as u32;
        for i in 1..=sqrt {
            if dividend % i == 0 {
                factors_set.insert(i);
                factors_set.insert(dividend / i);
            }
        }

        factors_set
    }

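    /// Enumerate every supported canvas resolution: for each chunk count up to
    /// `max_num_chunks`, every factor pair (h, w) yields a canvas of
    /// (h * patch_size, w * patch_size) pixels.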
    fn find_supported_resolutions(
        &self,
        max_num_chunks: usize,
        size: &HashMap<String, u32>,
    ) -> Result<Vec<(u32, u32)>> {
        let height = size["height"];
        let width = size["width"];
        if height != width {
            candle_core::bail!("Expected config size height==width ({height}!={width})");
        }

        let patch_size = height;

        let mut asp_map = HashMap::new();
        // Iterate chunk counts from max_num_chunks down to 1 (0 has no factors, and
        // max_num_chunks itself must be included).
        for chunk_size in (1..=max_num_chunks).rev() {
            let factors = Self::get_factors(chunk_size as u32);
            let asp_ratios = factors
                .into_iter()
                .sorted()
                .map(|factors| (factors, chunk_size as u32 / factors));
            for (h, w) in asp_ratios {
                let ratio_float = h as f32 / w as f32;
                asp_map
                    .entry(NotNan::new(ratio_float).context("f32 is NaN")?)
                    .or_insert_with(Vec::new)
                    .push((h, w));
            }
        }

        let possible_resolutions = asp_map
            .into_values()
            .flatten()
            .map(|(height, width)| (height * patch_size, width * patch_size))
            .collect::<Vec<_>>();

        Ok(possible_resolutions)
    }

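    /// Group images by (height, width) so same-shaped images can be stacked into one
    /// tensor; also return, per original index, the shape key and position within that group.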
    #[allow(clippy::type_complexity)]
    fn group_images_by_shape(
        &self,
        images: &[Tensor],
    ) -> Result<(
        HashMap<(usize, usize), Tensor>,
        HashMap<usize, ((usize, usize), usize)>,
    )> {
        let mut grouped_images = HashMap::new();
        let mut grouped_images_index = HashMap::new();
        for (i, image) in images.iter().enumerate() {
            let (_c, h, w) = image.dims3()?;
            let shape = (h, w);
            grouped_images
                .entry(shape)
                .or_insert_with(Vec::new)
                .push(image.clone());
            grouped_images_index.insert(i, (shape, grouped_images[&shape].len() - 1));
        }
        let mut grouped_images_stack = HashMap::new();
        for (shape, images) in grouped_images {
            grouped_images_stack.insert(shape, Tensor::stack(&images, 0)?);
        }

        Ok((grouped_images_stack, grouped_images_index))
    }

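    /// Pick the target canvas for an image: compute the scale each candidate resolution
    /// implies, prefer the smallest upscale (or the largest if `resize_to_max_canvas`),
    /// otherwise the mildest downscale.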
    fn get_best_fit(
        &self,
        (original_height, original_width): (u32, u32),
        possible_resolutions: Vec<(u32, u32)>,
        resize_to_max_canvas: bool,
    ) -> Result<(u32, u32)> {
        let (target_heights, target_widths): (Vec<u32>, Vec<u32>) =
            possible_resolutions.iter().copied().unzip();

        let scale_w = target_widths
            .iter()
            .map(|tw| *tw as f32 / original_width as f32);
        let scale_h = target_heights
            .iter()
            .map(|th| *th as f32 / original_height as f32);

        let scales = scale_w.zip(scale_h).map(|(w, h)| if h > w { w } else { h });

        let upscaling_options = scales
            .clone()
            .filter(|s| *s >= 1.)
            .map(|x| NotNan::new(x).unwrap())
            .collect::<Vec<_>>();
        let downscaling_options = scales
            .clone()
            .filter(|s| *s < 1.)
            .map(|x| NotNan::new(x).unwrap())
            .collect::<Vec<_>>();
        let selected_scale = if !upscaling_options.is_empty() {
            if resize_to_max_canvas {
                upscaling_options.into_iter().max().unwrap().into_inner()
            } else {
                upscaling_options.into_iter().min().unwrap().into_inner()
            }
        } else {
            downscaling_options.into_iter().max().unwrap().into_inner()
        };

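        // Among the canvases that achieve the selected scale, keep the one with the
        // smallest area.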
        let chosen_canvas = possible_resolutions
            .into_iter()
            .zip(scales)
            .filter_map(|(possible, scale)| {
                if scale == selected_scale {
                    Some(possible)
                } else {
                    None
                }
            })
            .sorted_by_key(|(h, w)| h * w)
            .take(1)
            .collect::<Vec<_>>()[0];

        Ok(chosen_canvas)
    }

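    /// Largest size that fits inside `target_size` while preserving the aspect ratio of
    /// `image_size` (no distortion, no cropping).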
    fn get_max_res_without_distortion(
        &self,
        image_size: (u32, u32),
        target_size: (u32, u32),
    ) -> (u32, u32) {
        let (original_height, original_width) = image_size;
        let (target_height, target_width) = target_size;

        let scale_w = target_width as f64 / original_width as f64;
        let scale_h = target_height as f64 / original_height as f64;

        if scale_w < scale_h {
            let new_width = target_width;
            let new_height = std::cmp::min(
                (original_height as f64 * scale_w).floor() as u32,
                target_height,
            );
            (new_height, new_width)
        } else {
            let new_height = target_height;
            let new_width = std::cmp::min(
                (original_width as f64 * scale_h).floor() as u32,
                target_width,
            );
            (new_height, new_width)
        }
    }

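    /// Split a (bs, c, h, w) batch into non-overlapping tiles, returning
    /// (bs, num_tiles_h * num_tiles_w, c, h / num_tiles_h, w / num_tiles_w).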
    fn split_to_tiles(
        &self,
        images: &Tensor,
        num_tiles_h: usize,
        num_tiles_w: usize,
    ) -> Result<Tensor> {
        let (bs, c, h, w) = images.dims4()?;
        let mut images = images.reshape((
            bs,
            c,
            num_tiles_h,
            h / num_tiles_h,
            num_tiles_w,
            w / num_tiles_w,
        ))?;
        images = images.permute((0, 2, 4, 1, 3, 5))?.contiguous()?;
        images.reshape((
            bs,
            num_tiles_h * num_tiles_w,
            c,
            h / num_tiles_h,
            w / num_tiles_w,
        ))
    }

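    /// Undo `group_images_by_shape`: return the processed tensors in the original image order.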
    fn reorder_images(
        &self,
        processed_images: HashMap<(usize, usize), Tensor>,
        grouped_images_index: HashMap<usize, ((usize, usize), usize)>,
    ) -> Result<Vec<Tensor>> {
        // Index by the original position rather than relying on HashMap iteration order,
        // which is unspecified.
        (0..grouped_images_index.len())
            .map(|i| {
                let (shape, idx) = grouped_images_index[&i];
                processed_images[&shape].i(idx)
            })
            .collect::<Result<Vec<Tensor>>>()
    }
}

impl ImagePreProcessor for Llama4ImageProcessor {
    const DEFAULT_MEAN: [f64; 3] = [0.5, 0.5, 0.5];
    const DEFAULT_STD: [f64; 3] = [0.5, 0.5, 0.5];

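    /// Preprocess a batch of images: convert to RGB tensors, pick the best-fit canvas,
    /// resize without distortion, pad, rescale/normalize, split into tiles, and (for
    /// multi-tile images) append a downscaled global tile. Videos are not supported.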
    fn preprocess(
        &self,
        images_d: Vec<DynamicImage>,
        videos: Vec<Vec<DynamicImage>>,
        config: &PreProcessorConfig,
        device: &Device,
        (_bs, _max_num_images): (usize, usize),
    ) -> Result<PreprocessedImages> {
        assert!(videos.is_empty());

        let max_patches = config.max_patches.unwrap_or(16);
        let size = config.size.clone().unwrap_or(HashMap::from_iter([
            ("height".to_string(), 336),
            ("width".to_string(), 336),
        ]));
        let resize_to_max_canvas = config.resize_to_max_canvas.unwrap_or(false);
        let do_rescale = config.do_rescale.unwrap_or(true);
        let do_normalize = config.do_normalize.unwrap_or(true);

        let possible_resolutions = self.find_supported_resolutions(max_patches, &size)?;

        let mut images = Vec::new();
        for mut image in images_d {
            if config.do_convert_rgb.unwrap_or(true) {
                image = DynamicImage::ImageRgb8(image.to_rgb8());
            }

            let to_tensor_rescale = Transforms {
                input: &ToTensorNoNorm,
                inner_transforms: &[],
            };
            let image = image.apply(to_tensor_rescale, device)?;
            images.push(image);
        }

        let (grouped_images, grouped_images_index) = self.group_images_by_shape(&images)?;

        let mut grouped_processed_images = HashMap::new();
        let mut grouped_aspect_ratios = HashMap::new();
        for (shape, stacked_images) in grouped_images {
            let image_size = (
                stacked_images.dim(D::Minus2)? as u32,
                stacked_images.dim(D::Minus1)? as u32,
            );
            let target_size = self.get_best_fit(
                image_size,
                possible_resolutions.clone(),
                resize_to_max_canvas,
            )?;
            let max_upscaling_size = if resize_to_max_canvas {
                None
            } else {
                Some(size["height"])
            };
            let target_size_without_distortion =
                if let Some(max_upscaling_size) = max_upscaling_size {
                    let nt_h = image_size.0.max(max_upscaling_size).min(target_size.0);
                    let nt_w = image_size.1.max(max_upscaling_size).min(target_size.1);
                    (nt_h, nt_w)
                } else {
                    candle_core::bail!("`resize_to_max_canvas == true` is not currently supported!");
                };

            let new_size_without_distortion =
                self.get_max_res_without_distortion(image_size, target_size_without_distortion);
            let mut processed_images = stacked_images.interpolate2d(
                new_size_without_distortion.0.max(1) as usize,
                new_size_without_distortion.1.max(1) as usize,
            )?;

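            // Pad on the bottom and right so the resized image exactly fills the target canvas.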
            processed_images = {
                let (target_h, target_w) = target_size;
                let (h, w) = (
                    processed_images.dim(D::Minus2)?,
                    processed_images.dim(D::Minus1)?,
                );
                let paste_x_r = target_w as usize - w;
                let paste_y_r = target_h as usize - h;
                processed_images
                    .pad_with_zeros(D::Minus2, 0, paste_y_r)?
                    .pad_with_zeros(D::Minus1, 0, paste_x_r)?
            };

            let rescale_and_norm_transforms = TensorTransforms {
                inner_transforms: &[
                    &do_rescale.then_some(Rescale {
                        factor: config.rescale_factor,
                    }),
                    &do_normalize.then_some(Normalize {
                        mean: config.image_mean.unwrap_or(Self::DEFAULT_MEAN).to_vec(),
                        std: config.image_std.unwrap_or(Self::DEFAULT_STD).to_vec(),
                    }),
                ],
            };
            processed_images = <Tensor as ApplyTensorTransforms>::apply(
                &processed_images,
                rescale_and_norm_transforms,
                device,
            )?;

            let (ratio_h, ratio_w) = (
                target_size.0 / size["height"],
                target_size.1 / size["width"],
            );
            processed_images =
                self.split_to_tiles(&processed_images, ratio_h as usize, ratio_w as usize)?;
            grouped_processed_images.insert(shape, processed_images.clone());
            grouped_aspect_ratios.insert(
                shape,
                Tensor::new(vec![vec![ratio_h, ratio_w]; stacked_images.dim(0)?], device)?,
            );

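            // Images split into more than one tile also get a downscaled "global" tile of
            // the full image appended.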
            if ratio_h * ratio_w > 1 {
                let mut global_tiles = stacked_images
                    .interpolate2d(size["height"] as usize, size["width"] as usize)?;
                global_tiles = <Tensor as ApplyTensorTransforms>::apply(
                    &global_tiles,
                    rescale_and_norm_transforms,
                    device,
                )?;
                grouped_processed_images.insert(
                    shape,
                    Tensor::cat(&[processed_images, global_tiles.unsqueeze(1)?], 1)?,
                );
            }
        }

        let processed_images =
            self.reorder_images(grouped_processed_images, grouped_images_index.clone())?;
        let aspect_ratios_list =
            self.reorder_images(grouped_aspect_ratios, grouped_images_index.clone())?;

        let processed_images = Tensor::cat(&processed_images, 0)?;
        let aspect_ratios = Tensor::stack(&aspect_ratios_list, 0)?;

        Ok(PreprocessedImages {
            pixel_values: processed_images,
            pixel_attention_mask: None,
            image_sizes: None,
            num_img_tokens: None,
            aspect_ratio_ids: Some(aspect_ratios),
            aspect_ratio_mask: None,
            num_tiles: None,
            image_grid_thw: None,
            video_grid_thw: None,
            rows: None,
            cols: None,
            pixel_values_list: None,
            tgt_sizes: None,
            image_sizes_all: None,
            num_crops: None,
        })
    }
}