#![allow(clippy::cast_precision_loss, clippy::cast_possible_truncation)]

use std::{any::Any, collections::HashSet, num::NonZeroUsize, sync::Arc};

use candle_core::{DType, Device, IndexOp, Result, Tensor};
use image::{imageops::FilterType, DynamicImage, GenericImage, GenericImageView, Rgba};
use mistralrs_vision::{ApplyTransforms, Normalize, ToTensor, Transforms};
use regex::Regex;
use tokenizers::Tokenizer;
use tracing::warn;

use crate::{
    device_map::DeviceMapper,
    pipeline::{
        text_models_inputs_processor::{
            self, get_completion_input, get_prompt_input, PagedAttentionMeta,
        },
        InputProcessorOutput, InputsProcessor, InputsProcessorType, MessagesAction, Processor,
        ProcessorCreator,
    },
    sequence::Sequence,
};

use crate::vision_models::{
    image_processor::{ImagePreProcessor, PreprocessedImages},
    phi4::Phi4MMVisionSpecificArgs,
    preprocessor_config::PreProcessorConfig,
    processor_config::ProcessorConfig,
    ModelInputs,
};

use super::image_embedding::IMAGE_SPECIAL_TOKEN_ID;

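// The chat template renders image placeholders as `<|image_1|>`, `<|image_2|>`, ...;
// they are rewritten below to the single reserved image token understood by the
// vision embedding layer. `DYHD_BASE_RESOLUTION` is the side length of one crop in
// the dynamic-HD tiling scheme.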
const COMPATIBLE_IMAGE_SPECIAL_TOKEN_PATTERN: &str = r"<\|image_\d+\|>";
const IMAGE_SPECIAL_TOKEN: &str = "<|endoftext10|>";
pub(crate) const DYHD_BASE_RESOLUTION: usize = 448;

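/// Prepares token ids and preprocessed image tensors for the Phi 4 multimodal model.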
pub struct Phi4MMInputsProcessor;
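
/// [`Processor`] wrapper that hands the pipeline a [`Phi4MMInputsProcessor`].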
pub struct Phi4MMProcessor {
    inputs_processor: Arc<Phi4MMInputsProcessor>,
}

impl ProcessorCreator for Phi4MMProcessor {
    fn new_processor(
        _: Option<ProcessorConfig>,
        _: PreProcessorConfig,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(Self {
            inputs_processor: Arc::new(Phi4MMInputsProcessor),
        })
    }
}

impl Processor for Phi4MMProcessor {
    fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
        self.inputs_processor.clone()
    }
    fn get_special_tokens(&self) -> &[&'static str] {
        &[]
    }
    fn template_action(&self) -> MessagesAction {
        MessagesAction::FlattenOnlyText
    }
}

impl InputsProcessor for Phi4MMInputsProcessor {
    fn get_type(&self) -> InputsProcessorType {
        InputsProcessorType::Vision
    }
    fn process_inputs(
        &self,
        tokenizer: Option<Arc<Tokenizer>>,
        input_seqs: &mut [&mut Sequence],
        is_prompt: bool,
        is_xlora: bool,
        device: &Device,
        no_kv_cache: bool,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        other_config: Option<Arc<dyn Any>>,
        mut paged_attn_metadata: Option<PagedAttentionMeta<'_>>,
        prompt_chunksize: Option<NonZeroUsize>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Box<dyn Iterator<Item = anyhow::Result<InputProcessorOutput>>> {
        if is_xlora {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Cannot make inputs for X-LoRA vision model.",
            ))));
        }
        if no_kv_cache {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Vision model must have kv cache.",
            ))));
        }
        if prompt_chunksize.is_some() {
            warn!("`prompt_chunksize` is set. Phi 4 multimodal does not support prompt batching.");
        }
        let Some(tokenizer) = tokenizer else {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Phi4MMInputProcessor requires a specified tokenizer.",
            ))));
        };

        let config = other_config
            .clone()
            .expect("Need a PreProcessorConfig config.");
        let config: &PreProcessorConfig = config.downcast_ref().expect("Downcast failed.");

        let has_images = input_seqs.iter().all(|seq| seq.has_images());

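        // If every sequence carries images, run the dynamic-HD image preprocessor per
        // sequence and accumulate pixel values, attention masks, sizes, and per-image
        // token counts; otherwise defer entirely to the text-only input processor.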
        let (pixel_values, pixel_attention_mask, image_sizes, num_img_tokens) = if has_images {
            let mut pixel_values_accum = Vec::new();
            let mut pixel_attention_masks_accum = Vec::new();
            let mut image_sizes_accum = Vec::new();
            let mut num_img_tokens_accum = Vec::new();
            for seq in input_seqs.iter_mut() {
                let imgs = seq
                    .take_images()
                    .expect("Need to have images by this point.");
                let PreprocessedImages {
                    pixel_values,
                    pixel_attention_mask,
                    image_sizes: _,
                    num_img_tokens,
                    aspect_ratio_ids: _,
                    aspect_ratio_mask: _,
                    num_tiles: _,
                    image_grid_thw: _,
                    video_grid_thw: _,
                    rows: _,
                    cols: _,
                    pixel_values_list: _,
                    tgt_sizes: _,
                    image_sizes_all,
                    num_crops: _,
                } = self
                    .preprocess(
                        imgs,
                        vec![],
                        config,
                        device,
                        (usize::MAX, usize::MAX),
                    )
                    .expect("Preprocessor failed");
                let image_sizes = image_sizes_all.unwrap();
                let pixel_attention_mask = pixel_attention_mask.unwrap();
                pixel_values_accum.push(pixel_values);
                pixel_attention_masks_accum.push(pixel_attention_mask);
                image_sizes_accum.extend(image_sizes);
                num_img_tokens_accum.push(num_img_tokens.unwrap());
            }
            (
                Some(Tensor::cat(&pixel_values_accum, 0).unwrap()),
                Some(Tensor::cat(&pixel_attention_masks_accum, 0).unwrap()),
                Some(image_sizes_accum),
                Some(num_img_tokens_accum),
            )
        } else {
            return Box::new(
                text_models_inputs_processor::TextInputsProcessor
                    .process_inputs(
                        Some(tokenizer),
                        input_seqs,
                        is_prompt,
                        is_xlora,
                        device,
                        no_kv_cache,
                        last_n_context_len,
                        return_raw_logits,
                        other_config,
                        paged_attn_metadata,
                        None,
                        mapper,
                    )
                    .map(|metadata| {
                        let InputProcessorOutput {
                            inputs,
                            seq_indices,
                        } = metadata?;

                        let text_models_inputs_processor::ModelInputs {
                            input_ids,
                            input_ids_full: _,
                            seqlen_offsets,
                            seqlen_offsets_full: _,
                            context_lens,
                            position_ids,
                            paged_attn_meta,
                            flash_meta,
                            flash_meta_full: _,
                        } = *inputs
                            .downcast::<text_models_inputs_processor::ModelInputs>()
                            .expect("Downcast failed.");

                        let inputs: Box<dyn Any> = Box::new(ModelInputs {
                            input_ids,
                            seqlen_offsets,
                            context_lens,
                            position_ids,
                            pixel_values: None,
                            model_specific_args: Box::new(Phi4MMVisionSpecificArgs {
                                image_sizes: None,
                                image_attention_mask: None,
                                input_image_embeds: None,
                            }),
                            paged_attn_meta,
                            flash_meta,
                        });
                        Ok(InputProcessorOutput {
                            inputs,
                            seq_indices,
                        })
                    }),
            );
        };

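        // Rewrite each prompt so every `<|image_k|>` placeholder becomes the reserved
        // image token, then re-encode it (once per sequence) so the token ids match.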
        let detokenized = tokenizer
            .decode_batch(
                &input_seqs
                    .iter()
                    .map(|seq| seq.get_toks())
                    .collect::<Vec<_>>(),
                false,
            )
            .expect("Decode failed");

        let img_token_pattern = Regex::new(COMPATIBLE_IMAGE_SPECIAL_TOKEN_PATTERN).unwrap();

        let mut toks = Vec::new();

        for (mut detokenized, (seq, num_img_tokens)) in detokenized
            .into_iter()
            .zip(input_seqs.iter_mut().zip(num_img_tokens.unwrap()))
        {
            detokenized = img_token_pattern
                .replace_all(&detokenized, IMAGE_SPECIAL_TOKEN)
                .to_string();

            let has_changed_prompt = seq.has_changed_prompt;
            if !has_changed_prompt {
                seq.set_toks_and_reallocate(
                    tokenizer
                        .encode_fast(detokenized.clone(), false)
                        .expect("Encode failed")
                        .get_ids()
                        .to_vec(),
                    paged_attn_metadata.as_mut(),
                );

                seq.set_initial_prompt(detokenized);
            }

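            // Expand each reserved image token in place into `num_img_tokens[k]` copies so
            // the sequence length lines up with the image embeddings produced by the model.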
            let mut i = 0;
            let mut image_token_count_iter = num_img_tokens.iter();
            while i < seq.get_toks().len() {
                let token_id = seq.get_toks()[i];
                let token_count = if token_id == IMAGE_SPECIAL_TOKEN_ID as u32 {
                    image_token_count_iter.next().unwrap()
                } else {
                    i += 1;
                    continue;
                };

                let mut new_ids = seq.get_toks()[..i].to_vec();
                new_ids.extend(vec![token_id; *token_count]);
                new_ids.extend(seq.get_toks()[i + 1..].to_vec());
                if !has_changed_prompt {
                    seq.set_toks_and_reallocate(new_ids, paged_attn_metadata.as_mut());
                }
                i += token_count;
            }
            if !has_changed_prompt {
                seq.has_changed_prompt = true;
            }
            toks.push(seq.get_toks().to_vec());
        }

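        // With image tokens expanded, build the standard prompt or completion input
        // batches and attach the image tensors to each batch below.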
        let iter = if is_prompt {
            get_prompt_input(
                toks,
                input_seqs,
                device,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                None,
                mapper,
            )
        } else {
            get_completion_input(
                toks,
                input_seqs,
                device,
                no_kv_cache,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                None,
                mapper,
            )
        };

        Box::new(iter.into_iter().map(move |metadata| {
            let pixel_values = pixel_values.clone();
            let pixel_attention_mask = pixel_attention_mask.clone();
            let text_models_inputs_processor::InnerInputProcessorOutput {
                inputs:
                    text_models_inputs_processor::InputMetadata {
                        input,
                        positions,
                        context_lens,
                        position_ids,
                        paged_attn_meta,
                        flash_meta,
                    },
                seq_indices,
            } = metadata?;
            let inputs: Box<dyn Any> = Box::new(ModelInputs {
                input_ids: input,
                seqlen_offsets: positions,
                context_lens,
                position_ids,
                pixel_values: pixel_values.clone(),
                model_specific_args: Box::new(Phi4MMVisionSpecificArgs {
                    image_sizes: image_sizes.clone(),
                    image_attention_mask: pixel_attention_mask,
                    input_image_embeds: pixel_values,
                }),
                paged_attn_meta,
                flash_meta,
            });
            Ok(InputProcessorOutput {
                inputs,
                seq_indices,
            })
        }))
    }
}

impl Phi4MMInputsProcessor {
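    /// Pad `image` on the given sides with `pad_color`, keeping the original pixels at
    /// offset `(left, top)` of the new canvas.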
    fn pad_image(
        image: &DynamicImage,
        top: u32,
        bottom: u32,
        left: u32,
        right: u32,
        pad_color: Rgba<u8>,
    ) -> DynamicImage {
        let new_width = image.width() + left + right;
        let new_height = image.height() + top + bottom;

        let mut new_image = DynamicImage::new_rgb8(new_width, new_height);
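        // Fill the canvas with the pad color, then blit the original image at (left, top).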
        for x in 0..new_width {
            for y in 0..new_height {
                new_image.put_pixel(x, y, pad_color);
            }
        }

        new_image
            .copy_from(image, left, top)
            .expect("Failed to copy image");

        new_image
    }

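    /// Enumerate all `(columns, rows)` crop grids whose tile count lies in `[min_num, max_num]`,
    /// sorted by total tile count. For example, `compute_target_ratios(1, 4)` yields
    /// (1, 1), (1, 2), (2, 1), (1, 3), (3, 1), (2, 2), (1, 4), (4, 1), with the order within
    /// equal tile counts unspecified.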
    fn compute_target_ratios(min_num: u32, max_num: u32) -> Vec<(u32, u32)> {
        let mut ratios: HashSet<(u32, u32)> = HashSet::new();
        for n in min_num..=max_num {
            for i in 1..=n {
                for j in 1..=n {
                    if i * j >= min_num && i * j <= max_num {
                        ratios.insert((i, j));
                    }
                }
            }
        }
        let mut sorted_ratios: Vec<(u32, u32)> = ratios.into_iter().collect();
        sorted_ratios.sort_by_key(|&(i, j)| i * j);
        sorted_ratios
    }

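    /// Pick the candidate grid whose aspect ratio is closest to the image's; ties are broken
    /// in favor of larger grids when the image area is big enough to warrant them.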
    fn find_closest_aspect_ratio(
        aspect_ratio: f64,
        target_ratios: Vec<(u32, u32)>,
        width: u32,
        height: u32,
        image_size: usize,
    ) -> (u32, u32) {
        let mut best_ratio_diff = f64::INFINITY;
        let mut best_ratio = (1, 1);
        let area = width * height;
        for ratio in target_ratios {
            let target_aspect_ratio = ratio.0 as f64 / ratio.1 as f64;
            let ratio_diff = (aspect_ratio - target_aspect_ratio).abs();
            if ratio_diff < best_ratio_diff {
                best_ratio_diff = ratio_diff;
                best_ratio = ratio;
            } else if ratio_diff == best_ratio_diff
                && area as f64 > 0.5 * image_size as f64 * ratio.0 as f64 * ratio.1 as f64
            {
                best_ratio = ratio;
            }
        }
        best_ratio
    }

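    /// Resize and pad an image onto a grid of `image_size`-sided crops (dynamic HD), returning
    /// the padded image together with a patch-level attention mask (`mask_size` patches per
    /// crop side) that is zeroed over the padded border.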
    fn dynamic_preprocess(
        &self,
        mut image: DynamicImage,
        min_num: usize,
        max_num: usize,
        image_size: usize,
        mask_size: usize,
        device: &Device,
    ) -> Result<(DynamicImage, Tensor)> {
        let (orig_w, orig_h) = image.dimensions();

        let w_crop_num = (orig_w as f64 / image_size as f64).ceil();
        let h_crop_num = (orig_h as f64 / image_size as f64).ceil();
        let (target_aspect_ratio, target_width, target_height) =
            if w_crop_num * h_crop_num > max_num as f64 {
                let aspect_ratio = orig_w as f64 / orig_h as f64;
                let target_ratios = Self::compute_target_ratios(min_num as u32, max_num as u32);

                let target_aspect_ratio = Self::find_closest_aspect_ratio(
                    aspect_ratio,
                    target_ratios,
                    orig_w,
                    orig_h,
                    image_size,
                );

                let target_width = image_size * target_aspect_ratio.0 as usize;
                let target_height = image_size * target_aspect_ratio.1 as usize;

                (
                    (target_aspect_ratio.0 as f64, target_aspect_ratio.1 as f64),
                    target_width,
                    target_height,
                )
            } else {
                let target_width = (image_size as f64 * w_crop_num) as usize;
                let target_height = (image_size as f64 * h_crop_num) as usize;
                let target_aspect_ratio = (w_crop_num, h_crop_num);

                (target_aspect_ratio, target_width, target_height)
            };

        let ratio_width = target_width as f64 / orig_w as f64;
        let ratio_height = target_height as f64 / orig_h as f64;
        let (new_size, padding_width, padding_height) = if ratio_width < ratio_height {
            (
                (target_width, (orig_h as f64 * ratio_width) as usize),
                0_usize,
                target_height - (orig_h as f64 * ratio_width) as usize,
            )
        } else {
            (
                ((orig_w as f64 * ratio_height) as usize, target_height),
                target_width - (orig_w as f64 * ratio_height) as usize,
                0_usize,
            )
        };

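        // Patch-level (14x14-pixel patches) attention mask over the target canvas; the
        // rightmost columns and bottom rows covered by whole padded patches are zeroed out.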
        let mut attention_mask = Tensor::ones(
            (
                (mask_size as f64 * target_aspect_ratio.1) as usize,
                (mask_size as f64 * target_aspect_ratio.0) as usize,
            ),
            DType::U32,
            device,
        )?;
        if padding_width >= 14 {
            attention_mask = attention_mask.slice_assign(
                &[&.., &(attention_mask.dim(1)? - padding_width / 14..)],
                &Tensor::zeros(
                    (attention_mask.dim(0)?, padding_width / 14),
                    DType::U32,
                    device,
                )?,
            )?;
        }
        if padding_height >= 14 {
            attention_mask = attention_mask.slice_assign(
                &[&(attention_mask.dim(0)? - padding_height / 14..), &..],
                &Tensor::zeros(
                    (padding_height / 14, attention_mask.dim(1)?),
                    DType::U32,
                    device,
                )?,
            )?;
        }

        image = image.resize_exact(new_size.0 as u32, new_size.1 as u32, FilterType::Nearest);
        image = Self::pad_image(
            &image,
            0,
            padding_height as u32,
            padding_width as u32,
            0,
            Rgba([255u8, 255, 255, 255]),
        );

        Ok((image, attention_mask))
    }
}

impl ImagePreProcessor for Phi4MMInputsProcessor {
    #[allow(clippy::excessive_precision)]
    const DEFAULT_MEAN: [f64; 3] = [0.48145466, 0.4578275, 0.40821073];
    #[allow(clippy::excessive_precision)]
    const DEFAULT_STD: [f64; 3] = [0.26862954, 0.26130258, 0.27577711];

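    /// Preprocess a batch of images for Phi 4 multimodal: resize all images to a common size,
    /// dynamic-HD crop each one into base-resolution tiles plus a downsampled global view, and
    /// compute the number of image tokens each image will occupy in the prompt.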
    fn preprocess(
        &self,
        mut images: Vec<DynamicImage>,
        videos: Vec<Vec<DynamicImage>>,
        config: &PreProcessorConfig,
        device: &Device,
        (_, _): (usize, usize),
    ) -> Result<PreprocessedImages> {
        assert!(!images.is_empty());
        assert!(videos.is_empty());

        let mut max_size = None;
        for image in images.iter() {
            if max_size.is_none() {
                max_size = Some((image.dimensions().0 as usize, image.dimensions().1 as usize))
            } else if max_size.is_some_and(|(x, _)| image.dimensions().0 as usize > x) {
                max_size = Some((image.dimensions().0 as usize, max_size.unwrap().1));
            } else if max_size.is_some_and(|(_, y)| image.dimensions().1 as usize > y) {
                max_size = Some((max_size.unwrap().0, image.dimensions().1 as usize));
            }
        }
        let (max_w, max_h) = max_size.unwrap();
        for image in images.iter_mut() {
            *image = image.resize_exact(max_w as u32, max_h as u32, FilterType::Nearest);
        }

        let mut image_sizes = Vec::new();
        let mut padded_images = Vec::new();
        let mut padded_masks = Vec::new();
        let mut num_img_tokens = Vec::new();
        for mut image in images {
            if config.do_convert_rgb.unwrap_or(true) {
                image = DynamicImage::ImageRgb8(image.to_rgb8());
            }

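            // Convert to a CHW float tensor and normalize each channel with mean 0.5 and std 0.5.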
            let transforms = Transforms {
                input: &ToTensor,
                inner_transforms: &[&Normalize {
                    mean: vec![0.5, 0.5, 0.5],
                    std: vec![0.5, 0.5, 0.5],
                }],
            };
            let dyhd_base_resolution = DYHD_BASE_RESOLUTION;
            let base_resolution = dyhd_base_resolution;
            let mask_resolution = base_resolution / 14;
            let min_num = 1;

            let (elem, attention_mask) = self.dynamic_preprocess(
                image,
                min_num,
                config.dynamic_hd.unwrap(),
                base_resolution,
                mask_resolution,
                device,
            )?;

            let hd_image = elem.apply(transforms, device)?;
            let (img_h, img_w) = (hd_image.dim(1)?, hd_image.dim(2)?);
            let (mask_h, mask_w) = (attention_mask.dim(0)?, attention_mask.dim(1)?);

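            // Global view: the whole padded image downsampled to a single base-resolution tile
            // with a fully-unmasked patch mask. Local view: the padded HD image reshaped into
            // (num_crops, 3, base, base) tiles with matching patch-mask tiles.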
            let global_image = hd_image
                .unsqueeze(0)?
                .interpolate2d(base_resolution, base_resolution)?;
            let global_attention_mask =
                Tensor::ones((1, mask_resolution, mask_resolution), DType::U32, device)?;

            let hd_image_reshape = hd_image
                .reshape((
                    1,
                    3,
                    (img_h as f32 / base_resolution as f32) as usize,
                    base_resolution,
                    (img_w as f32 / base_resolution as f32) as usize,
                    base_resolution,
                ))?
                .permute((0, 2, 4, 1, 3, 5))?
                .reshape(((), 3, base_resolution, base_resolution))?;

            let attention_mask_reshape = attention_mask
                .reshape((
                    1,
                    (mask_h as f32 / mask_resolution as f32) as usize,
                    mask_resolution,
                    (mask_w as f32 / mask_resolution as f32) as usize,
                    mask_resolution,
                ))?
                .permute((0, 1, 3, 2, 4))?
                .reshape(((), mask_resolution, mask_resolution))?;

            let downsample_attention_mask = {
                let h_indices =
                    Tensor::arange_step(0, attention_mask_reshape.dim(1)? as u32, 2, device)?;
                let w_indices =
                    Tensor::arange_step(0, attention_mask_reshape.dim(2)? as u32, 2, device)?;
                let selected = attention_mask_reshape
                    .index_select(&h_indices, 1)?
                    .index_select(&w_indices, 2)?;

                let mask = selected
                    .reshape((
                        1,
                        mask_h / mask_resolution,
                        mask_w / mask_resolution,
                        mask_resolution / 2 + mask_resolution % 2,
                        mask_resolution / 2 + mask_resolution % 2,
                    ))?
                    .permute((0, 1, 3, 2, 4))?;
                mask.reshape((mask.dim(1)? * mask.dim(2)?, mask.dim(3)? * mask.dim(4)?))?
            };

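            // Token budget for this image, roughly: 256 pooled tokens plus 16 row separators
            // for the global view, 1 global/local separator, one token per unmasked pooled
            // patch of the HD crops, and one row separator per unmasked pooled row.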
            let img_tokens = 256
                + 1
                + downsample_attention_mask.sum_all()?.to_scalar::<u32>()? as usize
                + downsample_attention_mask
                    .i((.., 0))?
                    .sum_all()?
                    .to_scalar::<u32>()? as usize
                + 16;

            let hd_image_reshape = Tensor::cat(&[global_image, hd_image_reshape], 0)?;
            let hd_mask_reshape = Tensor::cat(&[global_attention_mask, attention_mask_reshape], 0)?;

            image_sizes.push((img_h as u32, img_w as u32));
            padded_images.push(hd_image_reshape);
            padded_masks.push(hd_mask_reshape);
            num_img_tokens.push(img_tokens);
        }
        Ok(PreprocessedImages {
            pixel_values: Tensor::stack(&padded_images, 0)?,
            pixel_attention_mask: Some(Tensor::stack(&padded_masks, 0)?),
            image_sizes: None,
            num_img_tokens: Some(num_img_tokens),
            aspect_ratio_ids: None,
            aspect_ratio_mask: None,
            num_tiles: None,
            image_grid_thw: None,
            video_grid_thw: None,
            rows: None,
            cols: None,
            pixel_values_list: None,
            tgt_sizes: None,
            image_sizes_all: Some(image_sizes),
            num_crops: None,
        })
    }
}