#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::{
    any::Any,
    collections::{HashMap, HashSet},
    num::NonZeroUsize,
    sync::Arc,
};

use candle_core::{Context, Device, IndexOp, Result, Tensor, D};
use image::DynamicImage;
use itertools::Itertools;
use mistralrs_vision::{
    ApplyTensorTransforms, ApplyTransforms, Normalize, Rescale, TensorTransforms, ToTensorNoNorm,
    Transforms,
};
use ordered_float::NotNan;
use tokenizers::Tokenizer;
use tracing::warn;

use crate::{
    device_map::DeviceMapper,
    pipeline::{
        text_models_inputs_processor::{
            self, get_completion_input, get_prompt_input, PagedAttentionMeta,
        },
        InputProcessorOutput, InputsProcessor, InputsProcessorType, MessagesAction, Processor,
    },
    sequence::Sequence,
    vision_models::{
        image_processor::{ImagePreProcessor, PreprocessedImages},
        preprocessor_config::PreProcessorConfig,
        processor_config::ProcessorConfig,
        ModelInputs,
    },
};

use super::Llama4ModelSpecificArgs;

pub(crate) const IMAGE_TOKEN: &str = "<|image|>";
const IMAGE_START: &str = "<|image_start|>";
const IMAGE_END: &str = "<|image_end|>";
const PATCH: &str = "<|patch|>";
const TILE_X_SEP: &str = "<|tile_x_separator|>";
const TILE_Y_SEP: &str = "<|tile_y_separator|>";

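/// Image preprocessor and inputs processor for Llama 4 vision inputs. Converts images into
/// tiled pixel values and expands `<|image|>` placeholders in the prompt into the full
/// patch/tile token sequence expected by the model.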
pub struct Llama4ImageProcessor {
    pub patch_size: usize,
    pub downsample_ratio: usize,
}

impl Llama4ImageProcessor {
    pub fn new(patch_size: Option<usize>, pixel_shuffle_ratio: Option<f32>) -> Self {
        Self {
            patch_size: patch_size.unwrap_or(14),
            downsample_ratio: (1. / pixel_shuffle_ratio.unwrap_or(0.5).powi(2)).round() as usize,
        }
    }
}

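/// Processor for Llama 4 chat templates and special tokens; produces a
/// [`Llama4ImageProcessor`] configured from the model's `ProcessorConfig`.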
pub struct Llama4Processor {
    patch_size: usize,
    downsample_ratio: usize,
}

impl Llama4Processor {
    pub fn new(cfg: &ProcessorConfig) -> Self {
        Self {
            patch_size: cfg.patch_size.unwrap_or(14),
            downsample_ratio: (1. / cfg.pixel_shuffle_ratio.unwrap_or(0.5).powi(2)).round()
                as usize,
        }
    }
}

impl Processor for Llama4Processor {
    fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
        Arc::new(Llama4ImageProcessor {
            patch_size: self.patch_size,
            downsample_ratio: self.downsample_ratio,
        })
    }

    fn get_special_tokens(&self) -> &[&'static str] {
        &[
            IMAGE_START,
            IMAGE_END,
            PATCH,
            TILE_X_SEP,
            TILE_Y_SEP,
            IMAGE_TOKEN,
        ]
    }

    fn template_action(&self) -> MessagesAction {
        MessagesAction::FlattenOnlyText
    }
}

impl Llama4ImageProcessor {
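    /// Build the expanded prompt text for a single image: one run of `<|patch|>` tokens per
    /// tile, separated by tile separators (only when the image spans more than one tile),
    /// followed by the `<|image|>` token and the patch tokens for the global thumbnail, all
    /// wrapped in `<|image_start|>`/`<|image_end|>`.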
    fn prompt_split_image(&self, aspect_ratio: &Tensor, num_patches_per_chunk: usize) -> String {
        let mut img_string = IMAGE_START.to_string();
        let aspect_ratio = aspect_ratio.to_vec1::<u32>().unwrap();
        let (ratio_h, ratio_w) = (aspect_ratio[0] as usize, aspect_ratio[1] as usize);
        if ratio_h * ratio_w > 1 {
            for _yy in 0..ratio_h {
                for xx in 0..ratio_w {
                    img_string.push_str(&PATCH.repeat(num_patches_per_chunk));
                    if xx < ratio_w - 1 {
                        img_string.push_str(TILE_X_SEP);
                    }
                }
                img_string.push_str(TILE_Y_SEP);
            }
        }
        img_string.push_str(IMAGE_TOKEN);
        img_string.push_str(&PATCH.repeat(num_patches_per_chunk));
        img_string.push_str(IMAGE_END);
        img_string
    }
}

impl InputsProcessor for Llama4ImageProcessor {
    fn get_type(&self) -> InputsProcessorType {
        InputsProcessorType::Vision
    }
    fn process_inputs(
        &self,
        tokenizer: Option<Arc<Tokenizer>>,
        input_seqs: &mut [&mut Sequence],
        is_prompt: bool,
        is_xlora: bool,
        device: &Device,
        no_kv_cache: bool,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        other_config: Option<Arc<dyn Any>>,
        mut paged_attn_metadata: Option<PagedAttentionMeta<'_>>,
        prompt_chunksize: Option<NonZeroUsize>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Box<dyn Iterator<Item = anyhow::Result<InputProcessorOutput>>> {
        if is_xlora {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Cannot make inputs for X-LoRA vision model.",
            ))));
        }
        if no_kv_cache {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Vision model must have kv cache.",
            ))));
        }
        if prompt_chunksize.is_some() {
            warn!("`prompt_chunksize` is set. Llama4 does not support prompt batching.");
        }
        let Some(tokenizer) = tokenizer else {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Llama4InputProcessor requires a specified tokenizer.",
            ))));
        };

        let config = other_config.expect("Need a PreProcessorConfig config.");
        let config: &PreProcessorConfig = config.downcast_ref().expect("Downcast failed.");

        let has_images = input_seqs.iter().all(|seq| seq.has_images());

        let pixel_values = if has_images {
            let mut pixel_values_accum = Vec::new();
            let mut aspect_ratios_accum = Vec::new();

            let bs = input_seqs.len();
            let detokenized = tokenizer
                .decode_batch(
                    &input_seqs
                        .iter()
                        .map(|seq| seq.get_toks())
                        .collect::<Vec<_>>(),
                    false,
                )
                .expect("Detokenization failed!");
            let n_images_in_text = detokenized
                .iter()
                .map(|text| text.matches(IMAGE_TOKEN).count())
                .collect::<Vec<_>>();
            let n_images_in_images = input_seqs
                .iter()
                .map(|seq| seq.images().map(|imgs| imgs.len()).unwrap_or(0))
                .collect::<Vec<_>>();
            if n_images_in_text != n_images_in_images {
                return Box::new(std::iter::once(Err(anyhow::Error::msg(format!(
                    "The number of `<|image|>` tags per prompt {n_images_in_text:?} should match the number of images provided {n_images_in_images:?}. The model cannot support a mismatched number of images per prompt. Perhaps you forgot an `<|image|>` tag?"
                )))));
            }

            let max_num_images = *n_images_in_images
                .iter()
                .max()
                .expect("No max images per batch!");

            for seq in input_seqs.iter_mut() {
                let PreprocessedImages {
                    pixel_values,
                    pixel_attention_mask: _,
                    image_sizes: _,
                    num_img_tokens: _,
                    aspect_ratio_ids,
                    aspect_ratio_mask: _,
                    num_tiles: _,
                    image_grid_thw: _,
                    video_grid_thw: _,
                    rows: _,
                    cols: _,
                    pixel_values_list: _,
                    tgt_sizes: _,
                    image_sizes_all: _,
                    num_crops: _,
                } = self
                    .preprocess(
                        seq.take_images()
                            .expect("Need to have images by this point."),
                        vec![],
                        config,
                        device,
                        (bs, max_num_images),
                    )
                    .expect("Preprocessing failed");
                pixel_values_accum.push(pixel_values);
                aspect_ratios_accum.push(aspect_ratio_ids.unwrap());
            }

            let pixel_values = Tensor::cat(&pixel_values_accum, 0).unwrap();
            let aspect_ratios = Tensor::cat(&aspect_ratios_accum, 0).unwrap();

            let (image_h, image_w) = (
                pixel_values.dim(D::Minus2).unwrap(),
                pixel_values.dim(D::Minus1).unwrap(),
            );
            let num_patches_per_chunk =
                (image_h / self.patch_size) * (image_w / self.patch_size) / self.downsample_ratio;

            let placeholder_counts = input_seqs
                .iter()
                .map(|seq| seq.get_initial_prompt().match_indices(IMAGE_TOKEN).count())
                .collect::<Vec<_>>();

            let mut image_index = 0;
            for (seq, placeholder_count) in input_seqs.iter_mut().zip(placeholder_counts) {
                if placeholder_count == 0 {
                    continue;
                }
                let prompt_splits: std::str::Split<'_, &str> =
                    seq.get_initial_prompt().split(IMAGE_TOKEN);
                let mut new_prompt = Vec::new();
                for (local_image_index, split_part) in prompt_splits.enumerate() {
                    new_prompt.push(split_part.to_string());
                    if local_image_index < placeholder_count {
                        let tokens_for_this_image = self.prompt_split_image(
                            &aspect_ratios.i(image_index).unwrap(),
                            num_patches_per_chunk,
                        );
                        image_index += 1;
                        new_prompt.push(tokens_for_this_image);
                    }
                }
                let prompt = new_prompt.join("");

                seq.set_initial_prompt(prompt.clone());
                let toks = tokenizer
                    .encode_fast(prompt, false)
                    .expect("Tokenization failed!");

                let ids = toks.get_ids().to_vec();
                seq.set_toks_and_reallocate(ids, paged_attn_metadata.as_mut());
            }

            Some(pixel_values)
        } else {
            None
        };

        let text_models_inputs_processor::InnerInputProcessorOutput {
            inputs:
                text_models_inputs_processor::InputMetadata {
                    input,
                    positions,
                    context_lens,
                    position_ids,
                    paged_attn_meta,
                    flash_meta,
                },
            seq_indices,
        } = if is_prompt {
            get_prompt_input(
                input_seqs
                    .iter()
                    .map(|seq| seq.get_toks().to_vec())
                    .collect::<Vec<_>>(),
                input_seqs,
                device,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                None,
                mapper,
            )
            .nth(0)
            .unwrap()
            .unwrap()
        } else {
            get_completion_input(
                input_seqs
                    .iter()
                    .map(|seq| seq.get_toks().to_vec())
                    .collect::<Vec<_>>(),
                input_seqs,
                device,
                no_kv_cache,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                None,
                mapper,
            )
            .nth(0)
            .unwrap()
            .unwrap()
        };

        let inputs: Box<dyn Any> = Box::new(ModelInputs {
            input_ids: input,
            seqlen_offsets: positions,
            context_lens,
            position_ids,
            pixel_values,
            model_specific_args: Box::new(Llama4ModelSpecificArgs),
            paged_attn_meta,
            flash_meta,
        });
        Box::new(std::iter::once(Ok(InputProcessorOutput {
            inputs,
            seq_indices,
        })))
    }
}

impl Llama4ImageProcessor {
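    /// Collect all positive divisors of `dividend` via trial division up to its square root
    /// (e.g. 12 yields {1, 2, 3, 4, 6, 12}).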
    fn get_factors(dividend: u32) -> HashSet<u32> {
        let mut factors_set = HashSet::new();

        let sqrt = (dividend as f64).sqrt() as u32;
        for i in 1..=sqrt {
            if dividend % i == 0 {
                factors_set.insert(i);
                factors_set.insert(dividend / i);
            }
        }

        factors_set
    }

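    /// Enumerate the pixel resolutions the model supports: every factor pair (h, w) of every
    /// admissible tile count defines an h x w grid of square tiles of the configured size,
    /// i.e. a candidate canvas of (h * tile_size, w * tile_size).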
    fn find_supported_resolutions(
        &self,
        max_num_chunks: usize,
        size: &HashMap<String, u32>,
    ) -> Result<Vec<(u32, u32)>> {
        let height = size["height"];
        let width = size["width"];
        if height != width {
            candle_core::bail!("Expected config size height==width ({height}!={width})");
        }

        let patch_size = height;

        let mut asp_map = HashMap::new();
        for chunk_size in (1..=max_num_chunks).rev() {
            let factors = Self::get_factors(chunk_size as u32);
            let asp_ratios = factors
                .into_iter()
                .sorted()
                .map(|factor| (factor, chunk_size as u32 / factor));
            for (h, w) in asp_ratios {
                let ratio_float = h as f32 / w as f32;
                asp_map
                    .entry(NotNan::new(ratio_float).context("f32 is NaN")?)
                    .or_insert_with(Vec::new)
                    .push((h, w));
            }
        }

        let possible_resolutions = asp_map
            .into_values()
            .flatten()
            .map(|(height, width)| (height * patch_size, width * patch_size))
            .collect::<Vec<_>>();

        Ok(possible_resolutions)
    }

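    /// Group images by (height, width) so that identically sized images can be stacked and
    /// processed as one batch. Also returns, per original image index, its shape key and its
    /// position within that group so the original ordering can be restored afterwards.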
    #[allow(clippy::type_complexity)]
    fn group_images_by_shape(
        &self,
        images: &[Tensor],
    ) -> Result<(
        HashMap<(usize, usize), Tensor>,
        HashMap<usize, ((usize, usize), usize)>,
    )> {
        let mut grouped_images = HashMap::new();
        let mut grouped_images_index = HashMap::new();
        for (i, image) in images.iter().enumerate() {
            let (_c, h, w) = image.dims3()?;
            let shape = (h, w);
            grouped_images
                .entry(shape)
                .or_insert_with(Vec::new)
                .push(image.clone());
            grouped_images_index.insert(i, (shape, grouped_images[&shape].len() - 1));
        }
        let mut grouped_images_stack = HashMap::new();
        for (shape, images) in grouped_images {
            grouped_images_stack.insert(shape, Tensor::stack(&images, 0)?);
        }

        Ok((grouped_images_stack, grouped_images_index))
    }

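    /// Pick the supported canvas that best fits the original image. Each candidate's limiting
    /// scale is min(width scale, height scale); prefer the smallest upscale (or the largest if
    /// `resize_to_max_canvas`), otherwise the mildest downscale, breaking ties by smallest area.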
    fn get_best_fit(
        &self,
        (original_height, original_width): (u32, u32),
        possible_resolutions: Vec<(u32, u32)>,
        resize_to_max_canvas: bool,
    ) -> Result<(u32, u32)> {
        let (target_heights, target_widths): (Vec<u32>, Vec<u32>) =
            possible_resolutions.iter().copied().unzip();

        let scale_w = target_widths
            .iter()
            .map(|tw| *tw as f32 / original_width as f32);
        let scale_h = target_heights
            .iter()
            .map(|th| *th as f32 / original_height as f32);

        let scales = scale_w.zip(scale_h).map(|(w, h)| if h > w { w } else { h });

        let upscaling_options = scales
            .clone()
            .filter(|s| *s >= 1.)
            .map(|x| NotNan::new(x).unwrap())
            .collect::<Vec<_>>();
        let downscaling_options = scales
            .clone()
            .filter(|s| *s < 1.)
            .map(|x| NotNan::new(x).unwrap())
            .collect::<Vec<_>>();
        let selected_scale = if !upscaling_options.is_empty() {
            if resize_to_max_canvas {
                upscaling_options.into_iter().max().unwrap().into_inner()
            } else {
                upscaling_options.into_iter().min().unwrap().into_inner()
            }
        } else {
            downscaling_options.into_iter().max().unwrap().into_inner()
        };

        let chosen_canvas = possible_resolutions
            .into_iter()
            .zip(scales)
            .filter_map(|(possible, scale)| {
                if scale == selected_scale {
                    Some(possible)
                } else {
                    None
                }
            })
            .sorted_by_key(|(h, w)| h * w)
            .take(1)
            .collect::<Vec<_>>()[0];

        Ok(chosen_canvas)
    }

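    /// Compute the largest (height, width) that fits inside `target_size` while preserving the
    /// aspect ratio of `image_size`.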
    fn get_max_res_without_distortion(
        &self,
        image_size: (u32, u32),
        target_size: (u32, u32),
    ) -> (u32, u32) {
        let (original_height, original_width) = image_size;
        let (target_height, target_width) = target_size;

        let scale_w = target_width as f64 / original_width as f64;
        let scale_h = target_height as f64 / original_height as f64;

        if scale_w < scale_h {
            let new_width = target_width;
            let new_height = std::cmp::min(
                (original_height as f64 * scale_w).floor() as u32,
                target_height,
            );
            (new_height, new_width)
        } else {
            let new_height = target_height;
            let new_width = std::cmp::min(
                (original_width as f64 * scale_h).floor() as u32,
                target_width,
            );
            (new_height, new_width)
        }
    }

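    /// Split a (bs, c, h, w) batch into an evenly spaced grid of `num_tiles_h` x `num_tiles_w`
    /// tiles, returning shape (bs, num_tiles_h * num_tiles_w, c, h / num_tiles_h, w / num_tiles_w).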
    fn split_to_tiles(
        &self,
        images: &Tensor,
        num_tiles_h: usize,
        num_tiles_w: usize,
    ) -> Result<Tensor> {
        let (bs, c, h, w) = images.dims4()?;
        let mut images = images.reshape((
            bs,
            c,
            num_tiles_h,
            h / num_tiles_h,
            num_tiles_w,
            w / num_tiles_w,
        ))?;
        images = images.permute((0, 2, 4, 1, 3, 5))?.contiguous()?;
        images.reshape((
            bs,
            num_tiles_h * num_tiles_w,
            c,
            h / num_tiles_h,
            w / num_tiles_w,
        ))
    }

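    /// Invert the shape grouping: look up, for each original image index, its shape group and
    /// position within that group, returning the processed tensors in the original input order.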
    fn reorder_images(
        &self,
        processed_images: HashMap<(usize, usize), Tensor>,
        grouped_images_index: HashMap<usize, ((usize, usize), usize)>,
    ) -> Result<Vec<Tensor>> {
        (0..grouped_images_index.len())
            .map(|i| {
                let (shape, idx) = grouped_images_index[&i];
                processed_images[&shape].i(idx)
            })
            .collect::<Result<Vec<Tensor>>>()
    }
}

impl ImagePreProcessor for Llama4ImageProcessor {
    const DEFAULT_MEAN: [f64; 3] = [0.5, 0.5, 0.5];
    const DEFAULT_STD: [f64; 3] = [0.5, 0.5, 0.5];

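    /// Preprocess images for Llama 4: convert to RGB tensors, group by shape, resize each group
    /// to the best-fitting supported canvas without distortion, zero-pad to the canvas,
    /// rescale/normalize, split into tiles, and append a global thumbnail tile for multi-tile
    /// images. Returns stacked pixel values and per-image aspect-ratio ids (tiles high, tiles wide).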
    fn preprocess(
        &self,
        images_d: Vec<DynamicImage>,
        videos: Vec<Vec<DynamicImage>>,
        config: &PreProcessorConfig,
        device: &Device,
        (_bs, _max_num_images): (usize, usize),
    ) -> Result<PreprocessedImages> {
        assert!(videos.is_empty());

        let max_patches = config.max_patches.unwrap_or(16);
        let size = config.size.clone().unwrap_or(HashMap::from_iter([
            ("height".to_string(), 336),
            ("width".to_string(), 336),
        ]));
        let resize_to_max_canvas = config.resize_to_max_canvas.unwrap_or(false);
        let do_rescale = config.do_rescale.unwrap_or(true);
        let do_normalize = config.do_normalize.unwrap_or(true);

        let possible_resolutions = self.find_supported_resolutions(max_patches, &size)?;

        let mut images = Vec::new();
        for mut image in images_d {
            if config.do_convert_rgb.unwrap_or(true) {
                image = DynamicImage::ImageRgb8(image.to_rgb8());
            }

            let to_tensor_rescale = Transforms {
                input: &ToTensorNoNorm,
                inner_transforms: &[],
            };
            let image = image.apply(to_tensor_rescale, device)?;
            images.push(image);
        }

        let (grouped_images, grouped_images_index) = self.group_images_by_shape(&images)?;

        let mut grouped_processed_images = HashMap::new();
        let mut grouped_aspect_ratios = HashMap::new();
        for (shape, stacked_images) in grouped_images {
            let image_size = (
                stacked_images.dim(D::Minus2)? as u32,
                stacked_images.dim(D::Minus1)? as u32,
            );
            let target_size = self.get_best_fit(
                image_size,
                possible_resolutions.clone(),
                resize_to_max_canvas,
            )?;
            let max_upscaling_size = if resize_to_max_canvas {
                None
            } else {
                Some(size["height"])
            };
            let target_size_without_distortion =
                if let Some(max_upscaling_size) = max_upscaling_size {
                    let nt_h = image_size.0.max(max_upscaling_size).min(target_size.0);
                    let nt_w = image_size.1.max(max_upscaling_size).min(target_size.1);
                    (nt_h, nt_w)
                } else {
                    candle_core::bail!("`resize_to_max_canvas == true` is not currently supported!");
                };

            let new_size_without_distortion =
                self.get_max_res_without_distortion(image_size, target_size_without_distortion);
            let mut processed_images = stacked_images.interpolate2d(
                new_size_without_distortion.0.max(1) as usize,
                new_size_without_distortion.1.max(1) as usize,
            )?;

            processed_images = {
                let (target_h, target_w) = target_size;
                let (h, w) = (
                    processed_images.dim(D::Minus2)?,
                    processed_images.dim(D::Minus1)?,
                );
                let paste_x_r = target_w as usize - w;
                let paste_y_r = target_h as usize - h;
                processed_images
                    .pad_with_zeros(D::Minus2, 0, paste_y_r)?
                    .pad_with_zeros(D::Minus1, 0, paste_x_r)?
            };

            let rescale_and_norm_transforms = TensorTransforms {
                inner_transforms: &[
                    &do_rescale.then_some(Rescale {
                        factor: config.rescale_factor,
                    }),
                    &do_normalize.then_some(Normalize {
                        mean: config.image_mean.unwrap_or(Self::DEFAULT_MEAN).to_vec(),
                        std: config.image_std.unwrap_or(Self::DEFAULT_STD).to_vec(),
                    }),
                ],
            };
            processed_images = <Tensor as ApplyTensorTransforms>::apply(
                &processed_images,
                rescale_and_norm_transforms,
                device,
            )?;

            let (ratio_h, ratio_w) = (
                target_size.0 / size["height"],
                target_size.1 / size["width"],
            );
            processed_images =
                self.split_to_tiles(&processed_images, ratio_h as usize, ratio_w as usize)?;
            grouped_processed_images.insert(shape, processed_images.clone());
            grouped_aspect_ratios.insert(
                shape,
                Tensor::new(vec![vec![ratio_h, ratio_w]; stacked_images.dim(0)?], device)?,
            );

            if ratio_h * ratio_w > 1 {
                let mut global_tiles = stacked_images
                    .interpolate2d(size["height"] as usize, size["width"] as usize)?;
                global_tiles = <Tensor as ApplyTensorTransforms>::apply(
                    &global_tiles,
                    rescale_and_norm_transforms,
                    device,
                )?;
                grouped_processed_images.insert(
                    shape,
                    Tensor::cat(&[processed_images, global_tiles.unsqueeze(1)?], 1)?,
                );
            }
        }

        let processed_images =
            self.reorder_images(grouped_processed_images, grouped_images_index.clone())?;
        let aspect_ratios_list =
            self.reorder_images(grouped_aspect_ratios, grouped_images_index.clone())?;

        let processed_images = Tensor::cat(&processed_images, 0)?;
        let aspect_ratios = Tensor::stack(&aspect_ratios_list, 0)?;

        Ok(PreprocessedImages {
            pixel_values: processed_images,
            pixel_attention_mask: None,
            image_sizes: None,
            num_img_tokens: None,
            aspect_ratio_ids: Some(aspect_ratios),
            aspect_ratio_mask: None,
            num_tiles: None,
            image_grid_thw: None,
            video_grid_thw: None,
            rows: None,
            cols: None,
            pixel_values_list: None,
            tgt_sizes: None,
            image_sizes_all: None,
            num_crops: None,
        })
    }
}