#![allow(clippy::cast_precision_loss, clippy::cast_possible_truncation)]

use std::{any::Any, collections::HashSet, num::NonZeroUsize, sync::Arc};

use candle_core::{DType, Device, IndexOp, Result, Tensor};
use image::{imageops::FilterType, DynamicImage, GenericImage, GenericImageView, Rgba};
use mistralrs_vision::{ApplyTransforms, Normalize, ToTensor, Transforms};
use regex::Regex;
use tokenizers::Tokenizer;
use tracing::warn;

use crate::{
    device_map::DeviceMapper,
    pipeline::{
        text_models_inputs_processor::{
            self, get_completion_input, get_prompt_input, PagedAttentionMeta,
        },
        InputProcessorOutput, InputsProcessor, InputsProcessorType, MessagesAction, Processor,
        ProcessorCreator,
    },
    sequence::Sequence,
};

use crate::vision_models::{
    image_processor::{ImagePreProcessor, PreprocessedImages},
    phi4::Phi4MMVisionSpecificArgs,
    preprocessor_config::PreProcessorConfig,
    processor_config::ProcessorConfig,
    ModelInputs,
};

use super::image_embedding::IMAGE_SPECIAL_TOKEN_ID;

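// Prompts reference images as `<|image_1|>`, `<|image_2|>`, ...; each reference is rewritten to
// the placeholder token below, which the tokenizer maps to `IMAGE_SPECIAL_TOKEN_ID`.
// `DYHD_BASE_RESOLUTION` is the side length of each crop produced by dynamic-HD preprocessing.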
const COMPATIBLE_IMAGE_SPECIAL_TOKEN_PATTERN: &str = r"<\|image_\d+\|>";
const IMAGE_SPECIAL_TOKEN: &str = "<|endoftext10|>";
pub(crate) const DYHD_BASE_RESOLUTION: usize = 448;

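/// Prepares model inputs for Phi 4 multimodal: tokenized text plus pixel values, image
/// attention masks, and image sizes.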
pub struct Phi4MMInputsProcessor;
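/// Chat-template-level processor that hands out a [`Phi4MMInputsProcessor`].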
pub struct Phi4MMProcessor {
    inputs_processor: Arc<Phi4MMInputsProcessor>,
}

impl ProcessorCreator for Phi4MMProcessor {
    fn new_processor(
        _: Option<ProcessorConfig>,
        _: PreProcessorConfig,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(Self {
            inputs_processor: Arc::new(Phi4MMInputsProcessor),
        })
    }
}

impl Processor for Phi4MMProcessor {
    fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
        self.inputs_processor.clone()
    }
    fn get_special_tokens(&self) -> &[&'static str] {
        &[]
    }
    fn template_action(&self) -> MessagesAction {
        MessagesAction::FlattenOnlyText
    }
}

impl InputsProcessor for Phi4MMInputsProcessor {
    fn get_type(&self) -> InputsProcessorType {
        InputsProcessorType::Vision
    }
    fn process_inputs(
        &self,
        tokenizer: Option<Arc<Tokenizer>>,
        input_seqs: &mut [&mut Sequence],
        is_prompt: bool,
        is_xlora: bool,
        device: &Device,
        no_kv_cache: bool,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        other_config: Option<Arc<dyn Any>>,
        mut paged_attn_metadata: Option<PagedAttentionMeta<'_>>,
        prompt_chunksize: Option<NonZeroUsize>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Box<dyn Iterator<Item = anyhow::Result<InputProcessorOutput>>> {
        if is_xlora {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Cannot make inputs for X-LoRA vision model.",
            ))));
        }
        if no_kv_cache {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Vision model must have kv cache.",
            ))));
        }
        if prompt_chunksize.is_some() {
            warn!("`prompt_chunksize` is set. Phi 4 multimodal does not support prompt batching.");
        }
        let Some(tokenizer) = tokenizer else {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Phi4MMInputsProcessor requires a specified tokenizer.",
            ))));
        };

        let config = other_config
            .clone()
            .expect("Need a PreProcessorConfig config.");
        let config: &PreProcessorConfig = config.downcast_ref().expect("Downcast failed.");

        let has_images = input_seqs.iter().all(|seq| seq.has_images());

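        // Preprocess every sequence's images up front, accumulating pixel values, per-crop
        // attention masks, image sizes, and per-image token counts across the whole batch.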
        let (pixel_values, pixel_attention_mask, image_sizes, num_img_tokens) = if has_images {
            let mut pixel_values_accum = Vec::new();
            let mut pixel_attention_masks_accum = Vec::new();
            let mut image_sizes_accum = Vec::new();
            let mut num_img_tokens_accum = Vec::new();
            for seq in input_seqs.iter_mut() {
                let imgs = seq
                    .take_images()
                    .expect("Need to have images by this point.");
                let PreprocessedImages {
                    pixel_values,
                    pixel_attention_mask,
                    image_sizes: _,
                    num_img_tokens,
                    aspect_ratio_ids: _,
                    aspect_ratio_mask: _,
                    num_tiles: _,
                    image_grid_thw: _,
                    video_grid_thw: _,
                    rows: _,
                    cols: _,
                    pixel_values_list: _,
                    tgt_sizes: _,
                    image_sizes_all,
                    num_crops: _,
                } = self
                    .preprocess(
                        imgs,
                        vec![],
                        config,
                        device,
                        (usize::MAX, usize::MAX),
                    )
                    .expect("Preprocessor failed");
                let image_sizes = image_sizes_all.unwrap();
                let pixel_attention_mask = pixel_attention_mask.unwrap();
                pixel_values_accum.push(pixel_values);
                pixel_attention_masks_accum.push(pixel_attention_mask);
                image_sizes_accum.extend(image_sizes);
                num_img_tokens_accum.push(num_img_tokens.unwrap());
            }
            (
                Some(Tensor::cat(&pixel_values_accum, 0).unwrap()),
                Some(Tensor::cat(&pixel_attention_masks_accum, 0).unwrap()),
                Some(image_sizes_accum),
                Some(num_img_tokens_accum),
            )
        } else {
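            // No sequence carries images: delegate to the plain-text input processor and wrap
            // its outputs with empty vision-specific arguments.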
            return Box::new(
                text_models_inputs_processor::TextInputsProcessor
                    .process_inputs(
                        Some(tokenizer),
                        input_seqs,
                        is_prompt,
                        is_xlora,
                        device,
                        no_kv_cache,
                        last_n_context_len,
                        return_raw_logits,
                        other_config,
                        paged_attn_metadata,
                        None,
                        mapper,
                    )
                    .map(|metadata| {
                        let InputProcessorOutput {
                            inputs,
                            seq_indices,
                        } = metadata?;

                        let text_models_inputs_processor::ModelInputs {
                            input_ids,
                            input_ids_full: _,
                            seqlen_offsets,
                            seqlen_offsets_full: _,
                            context_lens,
                            position_ids,
                            paged_attn_meta,
                            flash_meta,
                            flash_meta_full: _,
                        } = *inputs
                            .downcast::<text_models_inputs_processor::ModelInputs>()
                            .expect("Downcast failed.");

                        let inputs: Box<dyn Any> = Box::new(ModelInputs {
                            input_ids,
                            seqlen_offsets,
                            context_lens,
                            position_ids,
                            pixel_values: None,
                            model_specific_args: Box::new(Phi4MMVisionSpecificArgs {
                                image_sizes: None,
                                image_attention_mask: None,
                                input_image_embeds: None,
                            }),
                            paged_attn_meta,
                            flash_meta,
                        });
                        Ok(InputProcessorOutput {
                            inputs,
                            seq_indices,
                        })
                    }),
            );
        };

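        // Decode each sequence back to text, rewrite `<|image_k|>` references to the single
        // image placeholder token, and re-tokenize.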
        let detokenized = tokenizer
            .decode_batch(
                &input_seqs
                    .iter()
                    .map(|seq| seq.get_toks())
                    .collect::<Vec<_>>(),
                false,
            )
            .expect("Decode failed");

        let img_token_pattern = Regex::new(COMPATIBLE_IMAGE_SPECIAL_TOKEN_PATTERN).unwrap();

        let mut toks = Vec::new();

        for (mut detokenized, (seq, num_img_tokens)) in detokenized
            .into_iter()
            .zip(input_seqs.iter_mut().zip(num_img_tokens.unwrap()))
        {
            detokenized = img_token_pattern
                .replace_all(&detokenized, IMAGE_SPECIAL_TOKEN)
                .to_string();

            seq.set_toks_and_reallocate(
                tokenizer
                    .encode_fast(detokenized.clone(), false)
                    .expect("Encode failed")
                    .get_ids()
                    .to_vec(),
                paged_attn_metadata.as_mut(),
            );

            seq.set_initial_prompt(detokenized);

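            // Expand every image placeholder token in place to the per-image token count so the
            // sequence reserves one position per image feature vector.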
            let mut i = 0;
            let mut image_token_count_iter = num_img_tokens.iter();
            while i < seq.get_toks().len() {
                let token_id = seq.get_toks()[i];
                let token_count = if token_id == IMAGE_SPECIAL_TOKEN_ID as u32 {
                    image_token_count_iter.next().unwrap()
                } else {
                    i += 1;
                    continue;
                };

                let mut new_ids = seq.get_toks()[..i].to_vec();
                new_ids.extend(vec![token_id; *token_count]);
                new_ids.extend(seq.get_toks()[i + 1..].to_vec());
                seq.set_toks_and_reallocate(new_ids, paged_attn_metadata.as_mut());
                i += token_count;
            }
            toks.push(seq.get_toks().to_vec());
        }

        let iter = if is_prompt {
            get_prompt_input(
                toks,
                input_seqs,
                device,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                None,
                mapper,
            )
        } else {
            get_completion_input(
                toks,
                input_seqs,
                device,
                no_kv_cache,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                None,
                mapper,
            )
        };

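        // Attach the batched pixel values, crop attention masks, and image sizes to every chunk
        // of model inputs produced above.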
        Box::new(iter.into_iter().map(move |metadata| {
            let pixel_values = pixel_values.clone();
            let pixel_attention_mask = pixel_attention_mask.clone();
            let text_models_inputs_processor::InnerInputProcessorOutput {
                inputs:
                    text_models_inputs_processor::InputMetadata {
                        input,
                        positions,
                        context_lens,
                        position_ids,
                        paged_attn_meta,
                        flash_meta,
                    },
                seq_indices,
            } = metadata?;
            let inputs: Box<dyn Any> = Box::new(ModelInputs {
                input_ids: input,
                seqlen_offsets: positions,
                context_lens,
                position_ids,
                pixel_values: pixel_values.clone(),
                model_specific_args: Box::new(Phi4MMVisionSpecificArgs {
                    image_sizes: image_sizes.clone(),
                    image_attention_mask: pixel_attention_mask,
                    input_image_embeds: pixel_values,
                }),
                paged_attn_meta,
                flash_meta,
            });
            Ok(InputProcessorOutput {
                inputs,
                seq_indices,
            })
        }))
    }
}

impl Phi4MMInputsProcessor {
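    /// Pad `image` on each side by the given pixel counts, filling with `pad_color`.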
    fn pad_image(
        image: &DynamicImage,
        top: u32,
        bottom: u32,
        left: u32,
        right: u32,
        pad_color: Rgba<u8>,
    ) -> DynamicImage {
        let new_width = image.width() + left + right;
        let new_height = image.height() + top + bottom;

        let mut new_image = DynamicImage::new_rgb8(new_width, new_height);
        for x in 0..new_width {
            for y in 0..new_height {
                new_image.put_pixel(x, y, pad_color);
            }
        }

        new_image
            .copy_from(image, left, top)
            .expect("Failed to copy image");

        new_image
    }

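    /// Enumerate every (tiles across, tiles down) grid whose tile count lies in
    /// `min_num..=max_num`, sorted by total number of tiles.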
    fn compute_target_ratios(min_num: u32, max_num: u32) -> Vec<(u32, u32)> {
        let mut ratios: HashSet<(u32, u32)> = HashSet::new();
        for n in min_num..=max_num {
            for i in 1..=n {
                for j in 1..=n {
                    if i * j >= min_num && i * j <= max_num {
                        ratios.insert((i, j));
                    }
                }
            }
        }
        let mut sorted_ratios: Vec<(u32, u32)> = ratios.into_iter().collect();
        sorted_ratios.sort_by_key(|&(i, j)| i * j);
        sorted_ratios
    }

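    /// Pick the candidate grid whose aspect ratio is closest to the image's; ties go to the
    /// larger grid when the image area is big enough to make use of it.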
    fn find_closest_aspect_ratio(
        aspect_ratio: f64,
        target_ratios: Vec<(u32, u32)>,
        width: u32,
        height: u32,
        image_size: usize,
    ) -> (u32, u32) {
        let mut best_ratio_diff = f64::INFINITY;
        let mut best_ratio = (1, 1);
        let area = width * height;
        for ratio in target_ratios {
            let target_aspect_ratio = ratio.0 as f64 / ratio.1 as f64;
            let ratio_diff = (aspect_ratio - target_aspect_ratio).abs();
            if ratio_diff < best_ratio_diff {
                best_ratio_diff = ratio_diff;
                best_ratio = ratio;
            } else if ratio_diff == best_ratio_diff
                && area as f64
                    > 0.5 * (image_size * image_size) as f64 * ratio.0 as f64 * ratio.1 as f64
            {
                best_ratio = ratio;
            }
        }
        best_ratio
    }

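    /// Dynamic-HD preprocessing: choose a tile grid for the image, resize it to fit that grid
    /// while keeping its aspect ratio, pad the remainder with white, and build a patch-level
    /// (14 px) attention mask with the padded region zeroed out.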
    fn dynamic_preprocess(
        &self,
        mut image: DynamicImage,
        min_num: usize,
        max_num: usize,
        image_size: usize,
        mask_size: usize,
        device: &Device,
    ) -> Result<(DynamicImage, Tensor)> {
        let (orig_w, orig_h) = image.dimensions();

        let w_crop_num = (orig_w as f64 / image_size as f64).ceil();
        let h_crop_num = (orig_h as f64 / image_size as f64).ceil();
        let (target_aspect_ratio, target_width, target_height) =
            if w_crop_num * h_crop_num > max_num as f64 {
                let aspect_ratio = orig_w as f64 / orig_h as f64;
                let target_ratios = Self::compute_target_ratios(min_num as u32, max_num as u32);

                let target_aspect_ratio = Self::find_closest_aspect_ratio(
                    aspect_ratio,
                    target_ratios,
                    orig_w,
                    orig_h,
                    image_size,
                );

                let target_width = image_size * target_aspect_ratio.0 as usize;
                let target_height = image_size * target_aspect_ratio.1 as usize;

                (
                    (target_aspect_ratio.0 as f64, target_aspect_ratio.1 as f64),
                    target_width,
                    target_height,
                )
            } else {
                let target_width = (image_size as f64 * w_crop_num) as usize;
                let target_height = (image_size as f64 * h_crop_num) as usize;
                let target_aspect_ratio = (w_crop_num, h_crop_num);

                (target_aspect_ratio, target_width, target_height)
            };

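        // Resize to fit the chosen grid while preserving aspect ratio; whichever dimension comes
        // up short is padded so the result fills the grid exactly.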
        let ratio_width = target_width as f64 / orig_w as f64;
        let ratio_height = target_height as f64 / orig_h as f64;
        let (new_size, padding_width, padding_height) = if ratio_width < ratio_height {
            (
                (target_width, (orig_h as f64 * ratio_width) as usize),
                0_usize,
                target_height - (orig_h as f64 * ratio_width) as usize,
            )
        } else {
            (
                ((orig_w as f64 * ratio_height) as usize, target_height),
                target_width - (orig_w as f64 * ratio_height) as usize,
                0_usize,
            )
        };

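        // One mask entry per 14 px patch of the padded target; columns and rows that fall
        // entirely inside the padding are zeroed out.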
        let mut attention_mask = Tensor::ones(
            (
                (mask_size as f64 * target_aspect_ratio.1) as usize,
                (mask_size as f64 * target_aspect_ratio.0) as usize,
            ),
            DType::U32,
            device,
        )?;
        if padding_width >= 14 {
            attention_mask = attention_mask.slice_assign(
                &[&.., &(attention_mask.dim(1)? - padding_width / 14..)],
                &Tensor::zeros(
                    (attention_mask.dim(0)?, padding_width / 14),
                    DType::U32,
                    device,
                )?,
            )?;
        }
        if padding_height >= 14 {
            attention_mask = attention_mask.slice_assign(
                &[&(attention_mask.dim(0)? - padding_height / 14..), &..],
                &Tensor::zeros(
                    (padding_height / 14, attention_mask.dim(1)?),
                    DType::U32,
                    device,
                )?,
            )?;
        }

        image = image.resize_exact(new_size.0 as u32, new_size.1 as u32, FilterType::Nearest);
        image = Self::pad_image(
            &image,
            0,
            padding_height as u32,
            0,
            padding_width as u32,
            Rgba([255u8, 255, 255, 255]),
        );

        Ok((image, attention_mask))
    }
}

impl ImagePreProcessor for Phi4MMInputsProcessor {
    #[allow(clippy::excessive_precision)]
    const DEFAULT_MEAN: [f64; 3] = [0.48145466, 0.4578275, 0.40821073];
    #[allow(clippy::excessive_precision)]
    const DEFAULT_STD: [f64; 3] = [0.26862954, 0.26130258, 0.27577711];

    fn preprocess(
        &self,
        mut images: Vec<DynamicImage>,
        videos: Vec<Vec<DynamicImage>>,
        config: &PreProcessorConfig,
        device: &Device,
        (_, _): (usize, usize),
    ) -> Result<PreprocessedImages> {
        assert!(!images.is_empty());
        assert!(videos.is_empty());

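        // Resize every image to the largest width and height in the batch so the per-image
        // crops can later be stacked into a single tensor.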
        let mut max_size = None;
        for image in images.iter() {
            if max_size.is_none() {
                max_size = Some((image.dimensions().0 as usize, image.dimensions().1 as usize))
            } else if max_size.is_some_and(|(x, _)| image.dimensions().0 as usize > x) {
                max_size = Some((image.dimensions().0 as usize, max_size.unwrap().1));
            } else if max_size.is_some_and(|(_, y)| image.dimensions().1 as usize > y) {
                max_size = Some((max_size.unwrap().0, image.dimensions().1 as usize));
            }
        }
        let (max_w, max_h) = max_size.unwrap();
        for image in images.iter_mut() {
            *image = image.resize_exact(max_w as u32, max_h as u32, FilterType::Nearest);
        }

        let mut image_sizes = Vec::new();
        let mut padded_images = Vec::new();
        let mut padded_masks = Vec::new();
        let mut num_img_tokens = Vec::new();
        for mut image in images {
            if config.do_convert_rgb.unwrap_or(true) {
                image = DynamicImage::ImageRgb8(image.to_rgb8());
            }

            let transforms = Transforms {
                input: &ToTensor,
                inner_transforms: &[&Normalize {
                    mean: vec![0.5, 0.5, 0.5],
                    std: vec![0.5, 0.5, 0.5],
                }],
            };
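            // Each 448x448 crop is described by 14 px patches, so the per-crop mask resolution
            // is 448 / 14 = 32.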
            let base_resolution = DYHD_BASE_RESOLUTION;
            let mask_resolution = base_resolution / 14;
            let min_num = 1;

            let (elem, attention_mask) = self.dynamic_preprocess(
                image,
                min_num,
                config.dynamic_hd.unwrap(),
                base_resolution,
                mask_resolution,
                device,
            )?;

            let hd_image = elem.apply(transforms, device)?;
            let (img_h, img_w) = (hd_image.dim(1)?, hd_image.dim(2)?);
            let (mask_h, mask_w) = (attention_mask.dim(0)?, attention_mask.dim(1)?);

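            // Global view: the whole image downscaled to a single base-resolution crop, with an
            // all-ones attention mask.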
            let global_image = hd_image
                .unsqueeze(0)?
                .interpolate2d(base_resolution, base_resolution)?;
            let global_attention_mask =
                Tensor::ones((1, mask_resolution, mask_resolution), DType::U32, device)?;

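            // Split the HD image into a batch of base_resolution x base_resolution crops and
            // reshape the mask into matching per-crop tiles.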
            let hd_image_reshape = hd_image
                .reshape((
                    1,
                    3,
                    (img_h as f32 / base_resolution as f32) as usize,
                    base_resolution,
                    (img_w as f32 / base_resolution as f32) as usize,
                    base_resolution,
                ))?
                .permute((0, 2, 4, 1, 3, 5))?
                .reshape(((), 3, base_resolution, base_resolution))?;

            let attention_mask_reshape = attention_mask
                .reshape((
                    1,
                    (mask_h as f32 / mask_resolution as f32) as usize,
                    mask_resolution,
                    (mask_w as f32 / mask_resolution as f32) as usize,
                    mask_resolution,
                ))?
                .permute((0, 1, 3, 2, 4))?
                .reshape(((), mask_resolution, mask_resolution))?;

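            // Downsample the mask by taking every other row and column before counting image
            // tokens.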
            let downsample_attention_mask = {
                let h_indices =
                    Tensor::arange_step(0, attention_mask_reshape.dim(1)? as u32, 2, device)?;
                let w_indices =
                    Tensor::arange_step(0, attention_mask_reshape.dim(2)? as u32, 2, device)?;
                let selected = attention_mask_reshape
                    .index_select(&h_indices, 1)?
                    .index_select(&w_indices, 2)?;

                let mask = selected
                    .reshape((
                        1,
                        mask_h / mask_resolution,
                        mask_w / mask_resolution,
                        mask_resolution / 2 + mask_resolution % 2,
                        mask_resolution / 2 + mask_resolution % 2,
                    ))?
                    .permute((0, 1, 3, 2, 4))?;
                mask.reshape((mask.dim(1)? * mask.dim(2)?, mask.dim(3)? * mask.dim(4)?))?
            };

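            // Placeholder count for this image: a fixed 256 + 1 + 16 tokens plus one per active
            // entry of the downsampled mask and one per active entry of its first column.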
            let img_tokens = 256
                + 1
                + downsample_attention_mask.sum_all()?.to_scalar::<u32>()? as usize
                + downsample_attention_mask
                    .i((.., 0))?
                    .sum_all()?
                    .to_scalar::<u32>()? as usize
                + 16;

            let hd_image_reshape = Tensor::cat(&[global_image, hd_image_reshape], 0)?;
            let hd_mask_reshape = Tensor::cat(&[global_attention_mask, attention_mask_reshape], 0)?;

            image_sizes.push((img_h as u32, img_w as u32));
            padded_images.push(hd_image_reshape);
            padded_masks.push(hd_mask_reshape);
            num_img_tokens.push(img_tokens);
        }
        Ok(PreprocessedImages {
            pixel_values: Tensor::stack(&padded_images, 0)?,
            pixel_attention_mask: Some(Tensor::stack(&padded_masks, 0)?),
            image_sizes: None,
            num_img_tokens: Some(num_img_tokens),
            aspect_ratio_ids: None,
            aspect_ratio_mask: None,
            num_tiles: None,
            image_grid_thw: None,
            video_grid_thw: None,
            rows: None,
            cols: None,
            pixel_values_list: None,
            tgt_sizes: None,
            image_sizes_all: Some(image_sizes),
            num_crops: None,
        })
    }
}