#![allow(clippy::cast_precision_loss, clippy::cast_possible_truncation)]

use std::{any::Any, collections::HashSet, sync::Arc};

use candle_core::{DType, Device, IndexOp, Result, Tensor};
use image::{imageops::FilterType, DynamicImage, GenericImage, GenericImageView, Rgba};
use mistralrs_vision::{ApplyTransforms, Normalize, ToTensor, Transforms};
use regex::Regex;
use tokenizers::Tokenizer;

use apodize::hanning_iter;
use rubato::{
    Resampler, SincFixedIn, SincInterpolationParameters, SincInterpolationType, WindowFunction,
};
use rustfft::{num_complex::Complex32, FftPlanner};

use crate::{
    device_map::DeviceMapper,
    pipeline::{
        text_models_inputs_processor::{
            self, get_completion_input, get_prompt_input, PagedAttentionMeta,
        },
        InputProcessorOutput, InputsProcessor, InputsProcessorType, MessagesAction, Processor,
        ProcessorCreator,
    },
    sequence::Sequence,
};

use crate::vision_models::{
    image_processor::{ImagePreProcessor, PreprocessedImages},
    phi4::Phi4MMVisionSpecificArgs,
    preprocessor_config::PreProcessorConfig,
    processor_config::ProcessorConfig,
    ModelInputs,
};

use super::audio_embedding::AUDIO_SPECIAL_TOKEN_ID;
use super::image_embedding::IMAGE_SPECIAL_TOKEN_ID;

const COMPATIBLE_IMAGE_SPECIAL_TOKEN_PATTERN: &str = r"<\|image_\d+\|>";
const COMPATIBLE_AUDIO_SPECIAL_TOKEN_PATTERN: &str = r"<\|audio_\d+\|>";
const IMAGE_SPECIAL_TOKEN: &str = "<|endoftext10|>";
const AUDIO_SPECIAL_TOKEN: &str = "<|endoftext11|>";
pub(crate) const DYHD_BASE_RESOLUTION: usize = 448;

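/// Number of mel filterbank bins in each audio feature frame.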
const AUDIO_FEATURE_SIZE: usize = 80;

type AudioProcessingResult = Result<(Option<Tensor>, Option<Vec<usize>>, Option<Tensor>)>;

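/// Prepares Phi-4 multimodal inputs: dynamic-HD image crops with attention masks,
/// log-mel audio features, and prompts whose special tokens are expanded to match the
/// multimodal embeddings.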
pub struct Phi4MMInputsProcessor {
    audio_compression_rate: usize,
    audio_downsample_rate: usize,
    audio_feat_stride: usize,
    eightk_method: String,
}

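/// [`Processor`] wrapper that constructs and shares the [`Phi4MMInputsProcessor`].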
pub struct Phi4MMProcessor {
    inputs_processor: Arc<Phi4MMInputsProcessor>,
}

impl ProcessorCreator for Phi4MMProcessor {
    fn new_processor(
        _: Option<ProcessorConfig>,
        pre_processor_config: PreProcessorConfig,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(Self {
            inputs_processor: Arc::new(Phi4MMInputsProcessor {
                audio_compression_rate: pre_processor_config
                    .audio_compression_rate
                    .expect("audio_compression_rate"),
                audio_downsample_rate: pre_processor_config
                    .audio_downsample_rate
                    .expect("audio_downsample_rate"),
                audio_feat_stride: pre_processor_config
                    .audio_feat_stride
                    .expect("audio_feat_stride"),
                eightk_method: "fillzero".to_string(),
            }),
        })
    }
}

impl Processor for Phi4MMProcessor {
    fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
        self.inputs_processor.clone()
    }
    fn get_special_tokens(&self) -> &[&'static str] {
        &[]
    }
    fn template_action(&self) -> MessagesAction {
        MessagesAction::FlattenOnlyText
    }
}

impl InputsProcessor for Phi4MMInputsProcessor {
    fn get_type(&self) -> InputsProcessorType {
        InputsProcessorType::Vision
    }
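    // Builds Phi-4 multimodal `ModelInputs`: falls back to the plain text processor when
    // the batch has neither images nor audio; otherwise rewrites `<|image_k|>` /
    // `<|audio_k|>` placeholders to the model's special tokens, expands each special
    // token to the number of positions its embeddings occupy, and attaches the
    // preprocessed image/audio tensors.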
    fn process_inputs(
        &self,
        tokenizer: Option<Arc<Tokenizer>>,
        input_seqs: &mut [&mut Sequence],
        is_prompt: bool,
        is_xlora: bool,
        device: &Device,
        no_kv_cache: bool,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        other_config: Option<Arc<dyn Any>>,
        mut paged_attn_metadata: Option<PagedAttentionMeta>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> anyhow::Result<InputProcessorOutput> {
        if is_xlora {
            return Err(anyhow::Error::msg(
                "Cannot make inputs for X-LoRA vision model.",
            ));
        }
        if no_kv_cache {
            return Err(anyhow::Error::msg("Vision model must have kv cache."));
        }
        let Some(tokenizer) = tokenizer else {
            return Err(anyhow::Error::msg(
                "Phi4MMInputProcessor requires a specified tokenizer.",
            ));
        };

        let config = other_config
            .clone()
            .expect("Need a PreProcessorConfig config.");
        let config: &PreProcessorConfig = config.downcast_ref().expect("Downcast failed.");

        let has_audios = input_seqs.iter().all(|seq| seq.has_audios());
        let has_images = input_seqs.iter().all(|seq| seq.has_images());

        let (pixel_values, pixel_attention_mask, image_sizes, num_img_tokens) = if has_images {
            let mut pixel_values_accum = Vec::new();
            let mut pixel_attention_masks_accum = Vec::new();
            let mut image_sizes_accum = Vec::new();
            let mut num_img_tokens_accum = Vec::new();
            for seq in input_seqs.iter_mut() {
                let imgs = seq
                    .take_images()
                    .expect("Need to have images by this point.");
                let PreprocessedImages {
                    pixel_values,
                    pixel_attention_mask,
                    image_sizes: _,
                    num_img_tokens,
                    aspect_ratio_ids: _,
                    aspect_ratio_mask: _,
                    num_tiles: _,
                    image_grid_thw: _,
                    video_grid_thw: _,
                    rows: _,
                    cols: _,
                    pixel_values_list: _,
                    tgt_sizes: _,
                    image_sizes_all,
                    num_crops: _,
                } = self
                    .preprocess(
                        imgs,
                        vec![],
                        config,
                        device,
                        (usize::MAX, usize::MAX),
                    )
                    .expect("Preprocessor failed");
                let image_sizes = image_sizes_all.unwrap();
                let pixel_attention_mask = pixel_attention_mask.unwrap();
                pixel_values_accum.push(pixel_values);
                pixel_attention_masks_accum.push(pixel_attention_mask);
                image_sizes_accum.extend(image_sizes);
                num_img_tokens_accum.push(num_img_tokens.unwrap());
            }
            (
                Some(Tensor::cat(&pixel_values_accum, 0).unwrap()),
                Some(Tensor::cat(&pixel_attention_masks_accum, 0).unwrap()),
                Some(image_sizes_accum),
                Some(num_img_tokens_accum),
            )
        } else if has_audios {
            (None, None, None, Some(vec![vec![]; input_seqs.len()]))
        } else {
            return text_models_inputs_processor::TextInputsProcessor
                .process_inputs(
                    Some(tokenizer),
                    input_seqs,
                    is_prompt,
                    is_xlora,
                    device,
                    no_kv_cache,
                    last_n_context_len,
                    return_raw_logits,
                    other_config,
                    paged_attn_metadata,
                    mapper,
                )
                .map(|metadata| {
                    let InputProcessorOutput {
                        inputs,
                        seq_indices,
                    } = metadata;

                    let text_models_inputs_processor::ModelInputs {
                        input_ids,
                        input_ids_full: _,
                        seqlen_offsets,
                        seqlen_offsets_full: _,
                        context_lens,
                        position_ids,
                        paged_attn_meta,
                        flash_meta,
                        flash_meta_full: _,
                    } = *inputs
                        .downcast::<text_models_inputs_processor::ModelInputs>()
                        .expect("Downcast failed.");

                    let inputs: Box<dyn Any> = Box::new(ModelInputs {
                        input_ids,
                        seqlen_offsets,
                        context_lens,
                        position_ids,
                        pixel_values: None,
                        model_specific_args: Box::new(Phi4MMVisionSpecificArgs {
                            input_image_embeds: None,
                            image_attention_mask: None,
                            image_sizes: None,
                            input_audio_embeds: None,
                            audio_embed_sizes: None,
                            audio_attention_mask: None,
                        }),
                        paged_attn_meta,
                        flash_meta,
                    });
                    InputProcessorOutput {
                        inputs,
                        seq_indices,
                    }
                });
        };

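        // Rewrite HF-style `<|image_k|>` / `<|audio_k|>` placeholders to the single
        // special tokens the embedding layers expect, then re-tokenize the prompt (only
        // on its first pass through the processor).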
        let detokenized = tokenizer
            .decode_batch(
                &input_seqs
                    .iter()
                    .map(|seq| seq.get_toks())
                    .collect::<Vec<_>>(),
                false,
            )
            .expect("Decode failed");

        let img_token_pattern = Regex::new(COMPATIBLE_IMAGE_SPECIAL_TOKEN_PATTERN).unwrap();
        let audio_token_pattern = Regex::new(COMPATIBLE_AUDIO_SPECIAL_TOKEN_PATTERN).unwrap();

        for (mut detokenized, seq) in detokenized.into_iter().zip(input_seqs.iter_mut()) {
            detokenized = img_token_pattern
                .replace_all(&detokenized, IMAGE_SPECIAL_TOKEN)
                .to_string();
            detokenized = audio_token_pattern
                .replace_all(&detokenized, AUDIO_SPECIAL_TOKEN)
                .to_string();

            let has_changed_prompt = seq.multimodal.has_changed_prompt;
            if !has_changed_prompt {
                seq.set_toks_and_reallocate(
                    tokenizer
                        .encode_fast(detokenized.clone(), false)
                        .expect("Encode failed")
                        .get_ids()
                        .to_vec(),
                    paged_attn_metadata.as_mut(),
                );

                seq.set_initial_prompt(detokenized);
            }
        }

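        // Convert any attached audio clips into padded log-mel feature tensors.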
        let (input_audio_embeds, audio_embed_sizes, audio_attention_mask) =
            match self.process_audio_for_sequences(input_seqs, device) {
                Ok(result) => result,
                Err(e) => return Err(anyhow::Error::new(e)),
            };

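        // Expand each image/audio special token so it occupies one token position per
        // embedding the model will produce for that image or audio clip.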
        let mut toks = Vec::new();

        for (seq, num_img_tokens) in input_seqs.iter_mut().zip(num_img_tokens.unwrap()) {
            let has_changed_prompt = seq.multimodal.has_changed_prompt;

            let mut i = 0;
            let mut image_token_count_iter = num_img_tokens.iter();
            let audio_sizes_tmp = audio_embed_sizes.clone().unwrap_or(vec![]);
            let mut audio_embed_sizes = audio_sizes_tmp.iter();
            while i < seq.get_toks().len() {
                let token_id = seq.get_toks()[i];
                let token_count = if token_id == IMAGE_SPECIAL_TOKEN_ID as u32 {
                    image_token_count_iter.next().unwrap()
                } else if token_id == AUDIO_SPECIAL_TOKEN_ID as u32 {
                    audio_embed_sizes.next().unwrap()
                } else {
                    i += 1;
                    continue;
                };

                let mut new_ids = seq.get_toks()[..i].to_vec();
                new_ids.extend(vec![token_id; *token_count]);
                new_ids.extend(seq.get_toks()[i + 1..].to_vec());
                if !has_changed_prompt {
                    seq.set_toks_and_reallocate(new_ids, paged_attn_metadata.as_mut());
                }
                i += token_count;
            }
            if !has_changed_prompt {
                seq.multimodal.has_changed_prompt = true;
            }
            toks.push(seq.get_toks().to_vec());
        }

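        // With the token streams finalized, build the standard prompt/completion inputs.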
        let result = if is_prompt {
            get_prompt_input(
                toks.iter().map(Vec::as_slice).collect(),
                input_seqs,
                device,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                mapper,
            )
        } else {
            get_completion_input(
                toks.iter().map(Vec::as_slice).collect(),
                input_seqs,
                device,
                no_kv_cache,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                mapper,
            )
        };

        result.map(move |metadata| {
            let pixel_values = pixel_values.clone();
            let pixel_attention_mask = pixel_attention_mask.clone();
            let text_models_inputs_processor::InnerInputProcessorOutput {
                inputs:
                    text_models_inputs_processor::InputMetadata {
                        input,
                        positions,
                        context_lens,
                        position_ids,
                        paged_attn_meta,
                        flash_meta,
                    },
                seq_indices,
            } = metadata;
            let inputs: Box<dyn Any> = Box::new(ModelInputs {
                input_ids: input,
                seqlen_offsets: positions,
                context_lens,
                position_ids,
                pixel_values: pixel_values.clone(),
                model_specific_args: Box::new(Phi4MMVisionSpecificArgs {
                    input_image_embeds: pixel_values,
                    image_attention_mask: pixel_attention_mask,
                    image_sizes: image_sizes.clone(),
                    input_audio_embeds: input_audio_embeds.clone(),
                    audio_embed_sizes: audio_embed_sizes.clone(),
                    audio_attention_mask: audio_attention_mask.clone(),
                }),
                paged_attn_meta,
                flash_meta,
            });
            InputProcessorOutput {
                inputs,
                seq_indices,
            }
        })
    }
}

impl Phi4MMInputsProcessor {
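    /// Convert a mono waveform into log-mel spectrogram frames: resample to a supported
    /// rate, then run the mel front end.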
    fn extract_audio_features(
        &self,
        audio_data: &[f32],
        sample_rate: u32,
    ) -> Result<Vec<Vec<f32>>> {
        let (resampled_audio, final_sample_rate) =
            self.resample_audio_with_rubato(audio_data, sample_rate)?;

        let mel_features =
            self.extract_mel_spectrogram_rustfft(&resampled_audio, final_sample_rate)?;

        Ok(mel_features)
    }

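    /// Resample with rubato's sinc resampler: inputs above 16 kHz go to 16 kHz, inputs
    /// strictly between 8 and 16 kHz go to 8 kHz, 8/16 kHz inputs are returned unchanged,
    /// and anything below 8 kHz is rejected.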
    fn resample_audio_with_rubato(&self, wav: &[f32], fs: u32) -> Result<(Vec<f32>, u32)> {
        let target_fs = if fs > 16000 {
            16000
        } else if fs > 8000 && fs < 16000 {
            8000
        } else if fs < 8000 {
            return Err(candle_core::Error::Msg(format!(
                "Unsupported sample rate: {fs}"
            )));
        } else {
            return Ok((wav.to_vec(), fs));
        };

        if fs == target_fs {
            return Ok((wav.to_vec(), fs));
        }

        if fs == 8000 && self.eightk_method == "resample" {
            let params = SincInterpolationParameters {
                sinc_len: 256,
                f_cutoff: 0.95,
                interpolation: SincInterpolationType::Linear,
                oversampling_factor: 256,
                window: WindowFunction::BlackmanHarris2,
            };

            let mut resampler = SincFixedIn::<f32>::new(
                2.0,
                2.0,
                params,
                wav.len(),
                1,
            )
            .map_err(|e| candle_core::Error::Msg(format!("Resampler creation failed: {e}")))?;

            let input = vec![wav.to_vec()];
            let output = resampler
                .process(&input, None)
                .map_err(|e| candle_core::Error::Msg(format!("Resampling failed: {e}")))?;

            return Ok((output[0].clone(), 16000));
        }

        let resample_ratio = target_fs as f64 / fs as f64;

        let params = SincInterpolationParameters {
            sinc_len: 256,
            f_cutoff: 0.95,
            interpolation: SincInterpolationType::Linear,
            oversampling_factor: 256,
            window: WindowFunction::BlackmanHarris2,
        };

        let mut resampler = SincFixedIn::<f32>::new(
            resample_ratio,
            2.0,
            params,
            wav.len(),
            1,
        )
        .map_err(|e| candle_core::Error::Msg(format!("Resampler creation failed: {e}")))?;

        let input = vec![wav.to_vec()];
        let output = resampler
            .process(&input, None)
            .map_err(|e| candle_core::Error::Msg(format!("Resampling failed: {e}")))?;

        Ok((output[0].clone(), target_fs))
    }

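    /// Compute log-mel features: pre-emphasis, Hann-windowed FFT frames (400/160 samples
    /// window/hop at 16 kHz, 200/80 at 8 kHz), power spectrum, triangular mel filterbank,
    /// then a natural log floored at zero. With the "fillzero" method, 8 kHz frames are
    /// extended with zeros.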
    fn extract_mel_spectrogram_rustfft(&self, wav: &[f32], fs: u32) -> Result<Vec<Vec<f32>>> {
        let (n_fft, win_length, hop_length) = if fs == 8000 {
            (256, 200, 80)
        } else if fs == 16000 {
            (512, 400, 160)
        } else {
            return Err(candle_core::Error::Msg(format!(
                "Unsupported sample rate: {fs}"
            )));
        };

        let preemphasized = self.apply_preemphasis(wav, 0.97);

        let mut planner = FftPlanner::<f32>::new();
        let fft = planner.plan_fft_forward(n_fft);

        let window: Vec<f64> = hanning_iter(win_length).collect();

        let mel_filters = self.create_mel_filterbank(AUDIO_FEATURE_SIZE, n_fft, fs as f32)?;

        let n_batch = (preemphasized.len() - win_length) / hop_length + 1;
        let mut mel_features = Vec::new();

        for i in 0..n_batch {
            let start = i * hop_length;
            let end = start + win_length;
            if end > preemphasized.len() {
                break;
            }

            let mut windowed: Vec<Complex32> = preemphasized[start..end]
                .iter()
                .zip(window.iter())
                .map(|(s, w)| Complex32::new(s * *w as f32, 0.0))
                .collect();

            windowed.resize(n_fft, Complex32::new(0.0, 0.0));

            fft.process(&mut windowed);

            let power_spectrum: Vec<f32> = windowed[0..n_fft / 2 + 1]
                .iter()
                .map(|c| c.norm_sqr())
                .collect();

            let mut mel_frame = vec![0.0; AUDIO_FEATURE_SIZE];
            for (mel_idx, filter) in mel_filters.iter().enumerate() {
                let mut sum = 0.0;
                for (freq_idx, &coeff) in filter.iter().enumerate() {
                    if freq_idx < power_spectrum.len() {
                        sum += power_spectrum[freq_idx] * coeff;
                    }
                }
                mel_frame[mel_idx] = (sum.max(1.0)).ln();
            }

            mel_features.push(mel_frame);
        }

        if fs == 8000 && self.eightk_method == "fillzero" {
            for frame in &mut mel_features {
                let original_len = frame.len();
                frame.extend(vec![0.0; original_len]);
            }
        }

        Ok(mel_features)
    }

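    /// First-order pre-emphasis, y[n] = (x[n] - a * x[n-1]) * 32768, scaling samples to
    /// 16-bit range; the first sample is only scaled.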
    fn apply_preemphasis(&self, wav: &[f32], preemphasis: f32) -> Vec<f32> {
        if wav.is_empty() {
            return vec![];
        }

        let mut preemphasized = Vec::with_capacity(wav.len());

        preemphasized.push(wav[0] * 32768.0);

        for i in 1..wav.len() {
            let filtered = (wav[i] - preemphasis * wav[i - 1]) * 32768.0;
            preemphasized.push(filtered);
        }

        preemphasized
    }

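    /// Build a triangular mel filterbank of `n_mels` filters spanning 0 Hz to Nyquist,
    /// using the mel scale 1127 * ln(1 + f / 700).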
    fn create_mel_filterbank(
        &self,
        n_mels: usize,
        n_fft: usize,
        sample_rate: f32,
    ) -> Result<Vec<Vec<f32>>> {
        let bank_width = n_fft / 2 + 1;
        let fmax = sample_rate / 2.0;
        let fmin = 0.0;

        let hz_to_mel = |f: f32| 1127.0 * (1.0 + f / 700.0).ln();
        let mel_to_hz = |mel: f32| 700.0 * (mel / 1127.0).exp() - 700.0;

        let mel_low = hz_to_mel(fmin);
        let mel_high = hz_to_mel(fmax);

        let mel_centers: Vec<f32> = (0..=n_mels + 1)
            .map(|i| mel_low + (mel_high - mel_low) * i as f32 / (n_mels + 1) as f32)
            .collect();

        let hz_centers: Vec<f32> = mel_centers.iter().map(|&mel| mel_to_hz(mel)).collect();

        let bin_centers: Vec<usize> = hz_centers
            .iter()
            .map(|&f| ((f * n_fft as f32 / sample_rate) + 0.5) as usize)
            .collect();

        let mut filters = Vec::new();
        for m in 0..n_mels {
            let mut filter = vec![0.0; bank_width];

            let left_bin = bin_centers[m];
            let center_bin = bin_centers[m + 1];
            let right_bin = bin_centers[m + 2];

            for (bin, filter) in filter
                .iter_mut()
                .enumerate()
                .take(center_bin)
                .skip(left_bin)
            {
                if bin < bank_width {
                    *filter = (bin - left_bin) as f32 / (center_bin - left_bin) as f32;
                }
            }

            for (bin, filter) in filter
                .iter_mut()
                .enumerate()
                .take(right_bin)
                .skip(center_bin)
            {
                if bin < bank_width {
                    *filter = (right_bin - bin) as f32 / (right_bin - center_bin) as f32;
                }
            }

            filters.push(filter);
        }

        Ok(filters)
    }

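    /// Number of embedding positions for an audio clip: ceil(frames / compression_rate),
    /// then ceil of that result divided by downsample_rate.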
    fn compute_audio_embed_size(
        &self,
        audio_frames: usize,
        compression_rate: usize,
        downsample_rate: usize,
    ) -> usize {
        let integer = audio_frames / compression_rate;
        let remainder = audio_frames % compression_rate;
        let result = if remainder == 0 { integer } else { integer + 1 };

        let integer = result / downsample_rate;
        let remainder = result % downsample_rate;
        if remainder == 0 {
            integer
        } else {
            integer + 1
        }
    }

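    /// Build a (batch, max_frames) mask with 1s over each clip's valid frames and 0s over
    /// the padding.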
    fn create_audio_attention_mask(
        &self,
        audio_frames_list: &[usize],
        device: &Device,
    ) -> Result<Tensor> {
        let max_frames = *audio_frames_list.iter().max().unwrap_or(&0);
        let batch_size = audio_frames_list.len();

        let mut mask_data = vec![0u8; batch_size * max_frames];
        for (batch_idx, &frames) in audio_frames_list.iter().enumerate() {
            for frame_idx in 0..frames.min(max_frames) {
                mask_data[batch_idx * max_frames + frame_idx] = 1;
            }
        }

        Tensor::from_slice(&mask_data, (batch_size, max_frames), device)?.to_dtype(DType::F32)
    }

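    /// Extract, pad, and batch log-mel features for every sequence containing the audio
    /// special token. Returns the stacked feature tensor, per-clip embedding sizes, and
    /// an attention mask (only when more than one clip is batched).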
    fn process_audio_for_sequences(
        &self,
        input_seqs: &mut [&mut Sequence],
        device: &Device,
    ) -> AudioProcessingResult {
        let has_audio_tokens = input_seqs
            .iter()
            .any(|seq| seq.get_toks().contains(&(AUDIO_SPECIAL_TOKEN_ID as u32)));

        if !has_audio_tokens {
            return Ok((None, None, None));
        }

        let mut audio_features_list = Vec::new();
        let mut audio_embed_sizes_list = Vec::new();
        let mut audio_frames_list = Vec::new();

        for seq in input_seqs.iter_mut() {
            let has_audio = seq.get_toks().contains(&(AUDIO_SPECIAL_TOKEN_ID as u32));

            if has_audio {
                if let Some(audios) = seq.take_audios() {
                    for audio in audios.into_iter() {
                        let samples = audio.to_mono();

                        let features = self.extract_audio_features(&samples, audio.sample_rate)?;
                        let audio_frames = features.len() * self.audio_feat_stride;

                        let embed_size = self.compute_audio_embed_size(
                            audio_frames,
                            self.audio_compression_rate,
                            self.audio_downsample_rate,
                        );

                        let features_len = features.len();
                        let features_flat: Vec<f32> = features.into_iter().flatten().collect();
                        let features_tensor = Tensor::from_slice(
                            &features_flat,
                            (features_len, AUDIO_FEATURE_SIZE),
                            device,
                        )?;

                        audio_features_list.push(features_tensor);
                        audio_embed_sizes_list.push(embed_size);
                        audio_frames_list.push(audio_frames);
                    }
                } else {
                    candle_core::bail!("No audios in `process_audio_for_sequences`");
                };
            }
        }

        if audio_features_list.is_empty() {
            return Ok((None, None, None));
        }

        let max_len = audio_features_list
            .iter()
            .map(|t| t.dim(0).unwrap_or(0))
            .max()
            .unwrap_or(0);

        let mut padded_features = Vec::new();
        for features in audio_features_list {
            let seq_len = features.dim(0)?;
            if seq_len < max_len {
                let padding =
                    Tensor::zeros((max_len - seq_len, AUDIO_FEATURE_SIZE), DType::F32, device)?;
                let padded = Tensor::cat(&[features, padding], 0)?;
                padded_features.push(padded);
            } else {
                padded_features.push(features);
            }
        }

        let input_audio_embeds = Tensor::stack(&padded_features, 0)?;

        let audio_attention_mask = if audio_frames_list.len() > 1 {
            Some(self.create_audio_attention_mask(&audio_frames_list, device)?)
        } else {
            None
        };

        Ok((
            Some(input_audio_embeds),
            Some(audio_embed_sizes_list),
            audio_attention_mask,
        ))
    }
}

impl Phi4MMInputsProcessor {
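    /// Pad an image to (width + left + right, height + top + bottom) with a solid color.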
    fn pad_image(
        image: &DynamicImage,
        top: u32,
        bottom: u32,
        left: u32,
        right: u32,
        pad_color: Rgba<u8>,
    ) -> DynamicImage {
        let new_width = image.width() + left + right;
        let new_height = image.height() + top + bottom;

        let mut new_image = DynamicImage::new_rgb8(new_width, new_height);
        for x in 0..new_width {
            for y in 0..new_height {
                new_image.put_pixel(x, y, pad_color);
            }
        }

        // Paste the original image at the (left, top) offset so the requested padding
        // actually surrounds it.
        new_image
            .copy_from(image, left, top)
            .expect("Failed to copy image");

        new_image
    }

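    /// Enumerate all (w, h) crop grids whose tile count lies in [min_num, max_num],
    /// sorted by total tile count.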
    fn compute_target_ratios(min_num: u32, max_num: u32) -> Vec<(u32, u32)> {
        let mut ratios: HashSet<(u32, u32)> = HashSet::new();
        for n in min_num..=max_num {
            for i in 1..=n {
                for j in 1..=n {
                    if i * j >= min_num && i * j <= max_num {
                        ratios.insert((i, j));
                    }
                }
            }
        }
        let mut sorted_ratios: Vec<(u32, u32)> = ratios.into_iter().collect();
        sorted_ratios.sort_by_key(|&(i, j)| i * j);
        sorted_ratios
    }

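    /// Pick the crop grid whose aspect ratio is closest to the image's, preferring grids
    /// covering more area when the ratio difference ties.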
    fn find_closest_aspect_ratio(
        aspect_ratio: f64,
        target_ratios: Vec<(u32, u32)>,
        width: u32,
        height: u32,
        image_size: usize,
    ) -> (u32, u32) {
        let mut best_ratio_diff = f64::INFINITY;
        let mut best_ratio = (1, 1);
        let area = width * height;
        for ratio in target_ratios {
            let target_aspect_ratio = ratio.0 as f64 / ratio.1 as f64;
            let ratio_diff = (aspect_ratio - target_aspect_ratio).abs();
            if ratio_diff < best_ratio_diff {
                best_ratio_diff = ratio_diff;
                best_ratio = ratio;
            } else if ratio_diff == best_ratio_diff
                && area as f64
                    > 0.5 * image_size as f64 * image_size as f64 * ratio.0 as f64 * ratio.1 as f64
            {
                best_ratio = ratio;
            }
        }
        best_ratio
    }

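    /// Dynamic-HD preprocessing for one image: pick a crop grid, resize while keeping the
    /// aspect ratio, pad the remainder with white, and return the padded image plus a
    /// patch-level attention mask that zeros out the padded region.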
    fn dynamic_preprocess(
        &self,
        mut image: DynamicImage,
        min_num: usize,
        max_num: usize,
        image_size: usize,
        mask_size: usize,
        device: &Device,
    ) -> Result<(DynamicImage, Tensor)> {
        let (orig_w, orig_h) = image.dimensions();

        let w_crop_num = (orig_w as f64 / image_size as f64).ceil();
        let h_crop_num = (orig_h as f64 / image_size as f64).ceil();
        let (target_aspect_ratio, target_width, target_height) =
            if w_crop_num * h_crop_num > max_num as f64 {
                let aspect_ratio = orig_w as f64 / orig_h as f64;
                let target_ratios = Self::compute_target_ratios(min_num as u32, max_num as u32);

                let target_aspect_ratio = Self::find_closest_aspect_ratio(
                    aspect_ratio,
                    target_ratios,
                    orig_w,
                    orig_h,
                    image_size,
                );

                let target_width = image_size * target_aspect_ratio.0 as usize;
                let target_height = image_size * target_aspect_ratio.1 as usize;

                (
                    (target_aspect_ratio.0 as f64, target_aspect_ratio.1 as f64),
                    target_width,
                    target_height,
                )
            } else {
                let target_width = (image_size as f64 * w_crop_num) as usize;
                let target_height = (image_size as f64 * h_crop_num) as usize;
                let target_aspect_ratio = (w_crop_num, h_crop_num);

                (target_aspect_ratio, target_width, target_height)
            };

        let ratio_width = target_width as f64 / orig_w as f64;
        let ratio_height = target_height as f64 / orig_h as f64;
        let (new_size, padding_width, padding_height) = if ratio_width < ratio_height {
            (
                (target_width, (orig_h as f64 * ratio_width) as usize),
                0_usize,
                target_height - (orig_h as f64 * ratio_width) as usize,
            )
        } else {
            (
                ((orig_w as f64 * ratio_height) as usize, target_height),
                target_width - (orig_w as f64 * ratio_height) as usize,
                0_usize,
            )
        };

        if new_size.1.min(target_height) < 10 || new_size.0.min(target_width) < 10 {
            candle_core::bail!(
                "Image aspect ratio too extreme; resulting size below minimum threshold",
            );
        }

        let mut attention_mask = Tensor::ones(
            (
                (mask_size as f64 * target_aspect_ratio.1) as usize,
                (mask_size as f64 * target_aspect_ratio.0) as usize,
            ),
            DType::U32,
            device,
        )?;
        if padding_width >= 14 {
            attention_mask = attention_mask.slice_assign(
                &[
                    0..attention_mask.dim(0)?,
                    (attention_mask.dim(1)? - padding_width / 14)..attention_mask.dim(1)?,
                ],
                &Tensor::zeros(
                    (attention_mask.dim(0)?, padding_width / 14),
                    DType::U32,
                    device,
                )?,
            )?;
        }
        if padding_height >= 14 {
            attention_mask = attention_mask.slice_assign(
                &[
                    (attention_mask.dim(0)? - padding_height / 14)..attention_mask.dim(0)?,
                    0..attention_mask.dim(1)?,
                ],
                &Tensor::zeros(
                    (padding_height / 14, attention_mask.dim(1)?),
                    DType::U32,
                    device,
                )?,
            )?;
        }

        let mask_sum: u32 = attention_mask.sum_all()?.to_scalar::<u32>()?;
        if mask_sum == 0 {
            candle_core::bail!("dynamic_preprocess produced an attention mask with zero sum");
        }

        image = image.resize_exact(new_size.0 as u32, new_size.1 as u32, FilterType::Nearest);
        image = Self::pad_image(
            &image,
            0,
            padding_height as u32,
            0,
            padding_width as u32,
            Rgba([255u8, 255, 255, 255]),
        );

        Ok((image, attention_mask))
    }
}

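// Dynamic-HD image preprocessing: every image in the batch is resized to a common size,
// split into `DYHD_BASE_RESOLUTION`-pixel crops plus a downsampled global crop, and the
// per-image token count is derived from the downsampled crop attention mask.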
impl ImagePreProcessor for Phi4MMInputsProcessor {
    #[allow(clippy::excessive_precision)]
    const DEFAULT_MEAN: [f64; 3] = [0.48145466, 0.4578275, 0.40821073];
    #[allow(clippy::excessive_precision)]
    const DEFAULT_STD: [f64; 3] = [0.26862954, 0.26130258, 0.27577711];

    fn preprocess(
        &self,
        mut images: Vec<DynamicImage>,
        videos: Vec<Vec<DynamicImage>>,
        config: &PreProcessorConfig,
        device: &Device,
        (_, _): (usize, usize),
    ) -> Result<PreprocessedImages> {
        assert!(!images.is_empty());
        assert!(videos.is_empty());

        // Track the maximum width and height independently so every image can be
        // resized to the batch-wide maximum before cropping.
        let mut max_size = None;
        for image in images.iter() {
            if max_size.is_none() {
                max_size = Some((image.dimensions().0 as usize, image.dimensions().1 as usize))
            }
            if max_size.is_some_and(|(x, _)| image.dimensions().0 as usize > x) {
                max_size = Some((image.dimensions().0 as usize, max_size.unwrap().1));
            }
            if max_size.is_some_and(|(_, y)| image.dimensions().1 as usize > y) {
                max_size = Some((max_size.unwrap().0, image.dimensions().1 as usize));
            }
        }
        let (max_w, max_h) = max_size.unwrap();
        for image in images.iter_mut() {
            *image = image.resize_exact(max_w as u32, max_h as u32, FilterType::Nearest);
        }

        let mut image_sizes = Vec::new();
        let mut padded_images = Vec::new();
        let mut padded_masks = Vec::new();
        let mut num_img_tokens = Vec::new();
        for mut image in images {
            if config.do_convert_rgb.unwrap_or(true) {
                image = DynamicImage::ImageRgb8(image.to_rgb8());
            }

            let transforms = Transforms {
                input: &ToTensor,
                inner_transforms: &[&Normalize {
                    mean: vec![0.5, 0.5, 0.5],
                    std: vec![0.5, 0.5, 0.5],
                }],
            };
            let dyhd_base_resolution = DYHD_BASE_RESOLUTION;
            let base_resolution = dyhd_base_resolution;
            let mask_resolution = base_resolution / 14;
            let min_num = 1;

            let (elem, attention_mask) = self.dynamic_preprocess(
                image,
                min_num,
                config.dynamic_hd.unwrap(),
                base_resolution,
                mask_resolution,
                device,
            )?;

            let hd_image = elem.apply(transforms, device)?;
            let (img_h, img_w) = (hd_image.dim(1)?, hd_image.dim(2)?);
            let (mask_h, mask_w) = (attention_mask.dim(0)?, attention_mask.dim(1)?);

            let global_image = hd_image
                .unsqueeze(0)?
                .interpolate2d(base_resolution, base_resolution)?;
            let global_attention_mask =
                Tensor::ones((1, mask_resolution, mask_resolution), DType::U32, device)?;

            let hd_image_reshape = hd_image
                .reshape((
                    1,
                    3,
                    (img_h as f32 / base_resolution as f32) as usize,
                    base_resolution,
                    (img_w as f32 / base_resolution as f32) as usize,
                    base_resolution,
                ))?
                .permute((0, 2, 4, 1, 3, 5))?
                .reshape(((), 3, base_resolution, base_resolution))?;

            let attention_mask_reshape = attention_mask
                .reshape((
                    1,
                    (mask_h as f32 / mask_resolution as f32) as usize,
                    mask_resolution,
                    (mask_w as f32 / mask_resolution as f32) as usize,
                    mask_resolution,
                ))?
                .permute((0, 1, 3, 2, 4))?
                .reshape(((), mask_resolution, mask_resolution))?;

            let downsample_attention_mask = {
                let h_indices =
                    Tensor::arange_step(0, attention_mask_reshape.dim(1)? as u32, 2, device)?;
                let w_indices =
                    Tensor::arange_step(0, attention_mask_reshape.dim(2)? as u32, 2, device)?;
                let selected = attention_mask_reshape
                    .index_select(&h_indices, 1)?
                    .index_select(&w_indices, 2)?;

                let mask = selected
                    .reshape((
                        1,
                        mask_h / mask_resolution,
                        mask_w / mask_resolution,
                        mask_resolution / 2 + mask_resolution % 2,
                        mask_resolution / 2 + mask_resolution % 2,
                    ))?
                    .permute((0, 1, 3, 2, 4))?;
                mask.reshape((mask.dim(1)? * mask.dim(2)?, mask.dim(3)? * mask.dim(4)?))?
            };
1095
1096 let img_tokens = 256
1097 + 1
1098 + downsample_attention_mask.sum_all()?.to_scalar::<u32>()? as usize
1099 + downsample_attention_mask
1100 .i((.., 0))?
1101 .sum_all()?
1102 .to_scalar::<u32>()? as usize
1103 + 16;
1104
1105 let hd_image_reshape = Tensor::cat(&[global_image, hd_image_reshape], 0)?;
1106 let hd_mask_reshape = Tensor::cat(&[global_attention_mask, attention_mask_reshape], 0)?;
1107
1108 image_sizes.push((img_h as u32, img_w as u32));
1109 padded_images.push(hd_image_reshape);
1110 padded_masks.push(hd_mask_reshape);
1111 num_img_tokens.push(img_tokens);
1112 }
1113 Ok(PreprocessedImages {
1114 pixel_values: Tensor::stack(&padded_images, 0)?,
1115 pixel_attention_mask: Some(Tensor::stack(&padded_masks, 0)?),
1116 image_sizes: None,
1117 num_img_tokens: Some(num_img_tokens),
1118 aspect_ratio_ids: None,
1119 aspect_ratio_mask: None,
1120 num_tiles: None,
1121 image_grid_thw: None,
1122 video_grid_thw: None,
1123 rows: None,
1124 cols: None,
1125 pixel_values_list: None,
1126 tgt_sizes: None,
1127 image_sizes_all: Some(image_sizes),
1128 num_crops: None,
1129 })
1130 }
1131}