#![allow(clippy::cast_precision_loss, clippy::cast_possible_truncation)]

use std::{any::Any, collections::HashSet, num::NonZeroUsize, sync::Arc};

use candle_core::{DType, Device, IndexOp, Result, Tensor};
use image::{imageops::FilterType, DynamicImage, GenericImage, GenericImageView, Rgba};
use mistralrs_vision::{ApplyTransforms, Normalize, ToTensor, Transforms};
use regex::Regex;
use tokenizers::Tokenizer;
use tracing::warn;

use apodize::hanning_iter;
use rubato::{
    Resampler, SincFixedIn, SincInterpolationParameters, SincInterpolationType, WindowFunction,
};
use rustfft::{num_complex::Complex32, FftPlanner};

use crate::{
    device_map::DeviceMapper,
    pipeline::{
        text_models_inputs_processor::{
            self, get_completion_input, get_prompt_input, PagedAttentionMeta,
        },
        InputProcessorOutput, InputsProcessor, InputsProcessorType, MessagesAction, Processor,
        ProcessorCreator,
    },
    sequence::Sequence,
};

use crate::vision_models::{
    image_processor::{ImagePreProcessor, PreprocessedImages},
    phi4::Phi4MMVisionSpecificArgs,
    preprocessor_config::PreProcessorConfig,
    processor_config::ProcessorConfig,
    ModelInputs,
};

use super::audio_embedding::AUDIO_SPECIAL_TOKEN_ID;
use super::image_embedding::IMAGE_SPECIAL_TOKEN_ID;

const COMPATIBLE_IMAGE_SPECIAL_TOKEN_PATTERN: &str = r"<\|image_\d+\|>";
const COMPATIBLE_AUDIO_SPECIAL_TOKEN_PATTERN: &str = r"<\|audio_\d+\|>";
const IMAGE_SPECIAL_TOKEN: &str = "<|endoftext10|>";
const AUDIO_SPECIAL_TOKEN: &str = "<|endoftext11|>";
pub(crate) const DYHD_BASE_RESOLUTION: usize = 448;

const AUDIO_FEATURE_SIZE: usize = 80;

type AudioProcessingResult = Result<(Option<Tensor>, Option<Vec<usize>>, Option<Tensor>)>;

pub struct Phi4MMInputsProcessor {
    audio_compression_rate: usize,
    audio_downsample_rate: usize,
    audio_feat_stride: usize,
    eightk_method: String,
}

pub struct Phi4MMProcessor {
    inputs_processor: Arc<Phi4MMInputsProcessor>,
}

impl ProcessorCreator for Phi4MMProcessor {
    fn new_processor(
        _: Option<ProcessorConfig>,
        pre_processor_config: PreProcessorConfig,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(Self {
            inputs_processor: Arc::new(Phi4MMInputsProcessor {
                audio_compression_rate: pre_processor_config
                    .audio_compression_rate
                    .expect("audio_compression_rate"),
                audio_downsample_rate: pre_processor_config
                    .audio_downsample_rate
                    .expect("audio_downsample_rate"),
                audio_feat_stride: pre_processor_config
                    .audio_feat_stride
                    .expect("audio_feat_stride"),
                eightk_method: "fillzero".to_string(),
            }),
        })
    }
}

impl Processor for Phi4MMProcessor {
    fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
        self.inputs_processor.clone()
    }
    fn get_special_tokens(&self) -> &[&'static str] {
        &[]
    }
    fn template_action(&self) -> MessagesAction {
        MessagesAction::FlattenOnlyText
    }
}

impl InputsProcessor for Phi4MMInputsProcessor {
    fn get_type(&self) -> InputsProcessorType {
        InputsProcessorType::Vision
    }
    fn process_inputs(
        &self,
        tokenizer: Option<Arc<Tokenizer>>,
        input_seqs: &mut [&mut Sequence],
        is_prompt: bool,
        is_xlora: bool,
        device: &Device,
        no_kv_cache: bool,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        other_config: Option<Arc<dyn Any>>,
        mut paged_attn_metadata: Option<PagedAttentionMeta>,
        prompt_chunksize: Option<NonZeroUsize>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Box<dyn Iterator<Item = anyhow::Result<InputProcessorOutput>>> {
        if is_xlora {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Cannot make inputs for X-LoRA vision model.",
            ))));
        }
        if no_kv_cache {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Vision model must have kv cache.",
            ))));
        }
        if prompt_chunksize.is_some() {
            warn!("`prompt_chunksize` is set. Phi 4 multimodal does not support prompt batching.");
        }
        let Some(tokenizer) = tokenizer else {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Phi4MMInputsProcessor requires a specified tokenizer.",
            ))));
        };

        let config = other_config
            .clone()
            .expect("Need a PreProcessorConfig config.");
        let config: &PreProcessorConfig = config.downcast_ref().expect("Downcast failed.");

        let has_audios = input_seqs.iter().all(|seq| seq.has_audios());
        let has_images = input_seqs.iter().all(|seq| seq.has_images());

        let (pixel_values, pixel_attention_mask, image_sizes, num_img_tokens) = if has_images {
            let mut pixel_values_accum = Vec::new();
            let mut pixel_attention_masks_accum = Vec::new();
            let mut image_sizes_accum = Vec::new();
            let mut num_img_tokens_accum = Vec::new();
            for seq in input_seqs.iter_mut() {
                let imgs = seq
                    .take_images()
                    .expect("Need to have images by this point.");
                let PreprocessedImages {
                    pixel_values,
                    pixel_attention_mask,
                    image_sizes: _,
                    num_img_tokens,
                    aspect_ratio_ids: _,
                    aspect_ratio_mask: _,
                    num_tiles: _,
                    image_grid_thw: _,
                    video_grid_thw: _,
                    rows: _,
                    cols: _,
                    pixel_values_list: _,
                    tgt_sizes: _,
                    image_sizes_all,
                    num_crops: _,
                } = self
                    .preprocess(
                        imgs,
                        vec![],
                        config,
                        device,
                        (usize::MAX, usize::MAX),
                    )
                    .expect("Preprocessor failed");
                let image_sizes = image_sizes_all.unwrap();
                let pixel_attention_mask = pixel_attention_mask.unwrap();
                pixel_values_accum.push(pixel_values);
                pixel_attention_masks_accum.push(pixel_attention_mask);
                image_sizes_accum.extend(image_sizes);
                num_img_tokens_accum.push(num_img_tokens.unwrap());
            }
            (
                Some(Tensor::cat(&pixel_values_accum, 0).unwrap()),
                Some(Tensor::cat(&pixel_attention_masks_accum, 0).unwrap()),
                Some(image_sizes_accum),
                Some(num_img_tokens_accum),
            )
        } else if has_audios {
            (None, None, None, Some(vec![vec![]; input_seqs.len()]))
        } else {
            return Box::new(
                text_models_inputs_processor::TextInputsProcessor
                    .process_inputs(
                        Some(tokenizer),
                        input_seqs,
                        is_prompt,
                        is_xlora,
                        device,
                        no_kv_cache,
                        last_n_context_len,
                        return_raw_logits,
                        other_config,
                        paged_attn_metadata,
                        None,
                        mapper,
                    )
                    .map(|metadata| {
                        let InputProcessorOutput {
                            inputs,
                            seq_indices,
                        } = metadata?;

                        let text_models_inputs_processor::ModelInputs {
                            input_ids,
                            input_ids_full: _,
                            seqlen_offsets,
                            seqlen_offsets_full: _,
                            context_lens,
                            position_ids,
                            paged_attn_meta,
                            flash_meta,
                            flash_meta_full: _,
                        } = *inputs
                            .downcast::<text_models_inputs_processor::ModelInputs>()
                            .expect("Downcast failed.");

                        let inputs: Box<dyn Any> = Box::new(ModelInputs {
                            input_ids,
                            seqlen_offsets,
                            context_lens,
                            position_ids,
                            pixel_values: None,
                            model_specific_args: Box::new(Phi4MMVisionSpecificArgs {
                                input_image_embeds: None,
                                image_attention_mask: None,
                                image_sizes: None,
                                input_audio_embeds: None,
                                audio_embed_sizes: None,
                                audio_attention_mask: None,
                            }),
                            paged_attn_meta,
                            flash_meta,
                        });
                        Ok(InputProcessorOutput {
                            inputs,
                            seq_indices,
                        })
                    }),
            );
        };

        let detokenized = tokenizer
            .decode_batch(
                &input_seqs
                    .iter()
                    .map(|seq| seq.get_toks())
                    .collect::<Vec<_>>(),
                false,
            )
            .expect("Decode failed");

        let img_token_pattern = Regex::new(COMPATIBLE_IMAGE_SPECIAL_TOKEN_PATTERN).unwrap();
        let audio_token_pattern = Regex::new(COMPATIBLE_AUDIO_SPECIAL_TOKEN_PATTERN).unwrap();

        for (mut detokenized, seq) in detokenized.into_iter().zip(input_seqs.iter_mut()) {
            detokenized = img_token_pattern
                .replace_all(&detokenized, IMAGE_SPECIAL_TOKEN)
                .to_string();
            detokenized = audio_token_pattern
                .replace_all(&detokenized, AUDIO_SPECIAL_TOKEN)
                .to_string();

            let has_changed_prompt = seq.multimodal.has_changed_prompt;
            if !has_changed_prompt {
                seq.set_toks_and_reallocate(
                    tokenizer
                        .encode_fast(detokenized.clone(), false)
                        .expect("Encode failed")
                        .get_ids()
                        .to_vec(),
                    paged_attn_metadata.as_mut(),
                );

                seq.set_initial_prompt(detokenized);
            }
        }

        let (input_audio_embeds, audio_embed_sizes, audio_attention_mask) =
            match self.process_audio_for_sequences(input_seqs, device) {
                Ok(result) => result,
                Err(e) => return Box::new(std::iter::once(Err(anyhow::Error::new(e)))),
            };

        let mut toks = Vec::new();

        for (seq, num_img_tokens) in input_seqs.iter_mut().zip(num_img_tokens.unwrap()) {
            let has_changed_prompt = seq.multimodal.has_changed_prompt;

            let mut i = 0;
            let mut image_token_count_iter = num_img_tokens.iter();
            let audio_sizes_tmp = audio_embed_sizes.clone().unwrap_or(vec![]);
            let mut audio_embed_sizes = audio_sizes_tmp.iter();
            while i < seq.get_toks().len() {
                let token_id = seq.get_toks()[i];
                let token_count = if token_id == IMAGE_SPECIAL_TOKEN_ID as u32 {
                    image_token_count_iter.next().unwrap()
                } else if token_id == AUDIO_SPECIAL_TOKEN_ID as u32 {
                    audio_embed_sizes.next().unwrap()
                } else {
                    i += 1;
                    continue;
                };

                let mut new_ids = seq.get_toks()[..i].to_vec();
                new_ids.extend(vec![token_id; *token_count]);
                new_ids.extend(seq.get_toks()[i + 1..].to_vec());
                if !has_changed_prompt {
                    seq.set_toks_and_reallocate(new_ids, paged_attn_metadata.as_mut());
                }
                i += token_count;
            }
            if !has_changed_prompt {
                seq.multimodal.has_changed_prompt = true;
            }
            toks.push(seq.get_toks().to_vec());
        }

        let iter = if is_prompt {
            get_prompt_input(
                toks.iter().map(Vec::as_slice).collect(),
                input_seqs,
                device,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                None,
                mapper,
            )
        } else {
            get_completion_input(
                toks.iter().map(Vec::as_slice).collect(),
                input_seqs,
                device,
                no_kv_cache,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                None,
                mapper,
            )
        };

        Box::new(iter.into_iter().map(move |metadata| {
            let pixel_values = pixel_values.clone();
            let pixel_attention_mask = pixel_attention_mask.clone();
            let text_models_inputs_processor::InnerInputProcessorOutput {
                inputs:
                    text_models_inputs_processor::InputMetadata {
                        input,
                        positions,
                        context_lens,
                        position_ids,
                        paged_attn_meta,
                        flash_meta,
                    },
                seq_indices,
            } = metadata?;
            let inputs: Box<dyn Any> = Box::new(ModelInputs {
                input_ids: input,
                seqlen_offsets: positions,
                context_lens,
                position_ids,
                pixel_values: pixel_values.clone(),
                model_specific_args: Box::new(Phi4MMVisionSpecificArgs {
                    input_image_embeds: pixel_values,
                    image_attention_mask: pixel_attention_mask,
                    image_sizes: image_sizes.clone(),
                    input_audio_embeds: input_audio_embeds.clone(),
                    audio_embed_sizes: audio_embed_sizes.clone(),
                    audio_attention_mask: audio_attention_mask.clone(),
                }),
                paged_attn_meta,
                flash_meta,
            });
            Ok(InputProcessorOutput {
                inputs,
                seq_indices,
            })
        }))
    }
}

impl Phi4MMInputsProcessor {
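    /// Convert raw mono audio samples into model-ready features: the signal is
    /// first resampled to a supported rate and then turned into a log-mel
    /// spectrogram with `AUDIO_FEATURE_SIZE` (80) bins per frame.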
    fn extract_audio_features(
        &self,
        audio_data: &[f32],
        sample_rate: u32,
    ) -> Result<Vec<Vec<f32>>> {
        let (resampled_audio, final_sample_rate) =
            self.resample_audio_with_rubato(audio_data, sample_rate)?;

        let mel_features =
            self.extract_mel_spectrogram_rustfft(&resampled_audio, final_sample_rate)?;

        Ok(mel_features)
    }

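    /// Resample audio to a rate the feature extractor supports using rubato's
    /// sinc-interpolation resampler. Rates above 16 kHz are brought down to
    /// 16 kHz, rates strictly between 8 kHz and 16 kHz are brought down to
    /// 8 kHz, inputs already at 8 kHz or 16 kHz are returned unchanged, and
    /// rates below 8 kHz are rejected.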
    fn resample_audio_with_rubato(&self, wav: &[f32], fs: u32) -> Result<(Vec<f32>, u32)> {
        let target_fs = if fs > 16000 {
            16000
        } else if fs > 8000 && fs < 16000 {
            8000
        } else if fs < 8000 {
            return Err(candle_core::Error::Msg(format!(
                "Unsupported sample rate: {fs}"
            )));
        } else {
            return Ok((wav.to_vec(), fs));
        };

        if fs == target_fs {
            return Ok((wav.to_vec(), fs));
        }

        if fs == 8000 && self.eightk_method == "resample" {
            let params = SincInterpolationParameters {
                sinc_len: 256,
                f_cutoff: 0.95,
                interpolation: SincInterpolationType::Linear,
                oversampling_factor: 256,
                window: WindowFunction::BlackmanHarris2,
            };

            let mut resampler = SincFixedIn::<f32>::new(
                2.0,
                2.0,
                params,
                wav.len(),
                1,
            )
            .map_err(|e| candle_core::Error::Msg(format!("Resampler creation failed: {e}")))?;

            let input = vec![wav.to_vec()];
            let output = resampler
                .process(&input, None)
                .map_err(|e| candle_core::Error::Msg(format!("Resampling failed: {e}")))?;

            return Ok((output[0].clone(), 16000));
        }

        let resample_ratio = target_fs as f64 / fs as f64;

        let params = SincInterpolationParameters {
            sinc_len: 256,
            f_cutoff: 0.95,
            interpolation: SincInterpolationType::Linear,
            oversampling_factor: 256,
            window: WindowFunction::BlackmanHarris2,
        };

        let mut resampler = SincFixedIn::<f32>::new(
            resample_ratio,
            2.0,
            params,
            wav.len(),
            1,
        )
        .map_err(|e| candle_core::Error::Msg(format!("Resampler creation failed: {e}")))?;

        let input = vec![wav.to_vec()];
        let output = resampler
            .process(&input, None)
            .map_err(|e| candle_core::Error::Msg(format!("Resampling failed: {e}")))?;

        Ok((output[0].clone(), target_fs))
    }

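    /// Compute a log-mel spectrogram with rustfft. The signal is pre-emphasized,
    /// split into Hann-windowed frames (400 samples with a 160-sample hop at
    /// 16 kHz; 200/80 at 8 kHz), transformed with an FFT (512 or 256 points),
    /// projected onto 80 triangular mel filters, and log-compressed with a
    /// floor of 1.0. With `eightk_method == "fillzero"` at 8 kHz, each frame is
    /// zero-extended to twice its width.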
    fn extract_mel_spectrogram_rustfft(&self, wav: &[f32], fs: u32) -> Result<Vec<Vec<f32>>> {
        let (n_fft, win_length, hop_length) = if fs == 8000 {
            (256, 200, 80)
        } else if fs == 16000 {
            (512, 400, 160)
        } else {
            return Err(candle_core::Error::Msg(format!(
                "Unsupported sample rate: {fs}"
            )));
        };

        let preemphasized = self.apply_preemphasis(wav, 0.97);

        let mut planner = FftPlanner::<f32>::new();
        let fft = planner.plan_fft_forward(n_fft);

        let window: Vec<f64> = hanning_iter(win_length).collect();

        let mel_filters = self.create_mel_filterbank(AUDIO_FEATURE_SIZE, n_fft, fs as f32)?;

        let n_batch = (preemphasized.len() - win_length) / hop_length + 1;
        let mut mel_features = Vec::new();

        for i in 0..n_batch {
            let start = i * hop_length;
            let end = start + win_length;
            if end > preemphasized.len() {
                break;
            }

            let mut windowed: Vec<Complex32> = preemphasized[start..end]
                .iter()
                .zip(window.iter())
                .map(|(s, w)| Complex32::new(s * *w as f32, 0.0))
                .collect();

            windowed.resize(n_fft, Complex32::new(0.0, 0.0));

            fft.process(&mut windowed);

            let power_spectrum: Vec<f32> = windowed[0..n_fft / 2 + 1]
                .iter()
                .map(|c| c.norm_sqr())
                .collect();

            let mut mel_frame = vec![0.0; AUDIO_FEATURE_SIZE];
            for (mel_idx, filter) in mel_filters.iter().enumerate() {
                let mut sum = 0.0;
                for (freq_idx, &coeff) in filter.iter().enumerate() {
                    if freq_idx < power_spectrum.len() {
                        sum += power_spectrum[freq_idx] * coeff;
                    }
                }
                mel_frame[mel_idx] = (sum.max(1.0)).ln();
            }

            mel_features.push(mel_frame);
        }

        if fs == 8000 && self.eightk_method == "fillzero" {
            for frame in &mut mel_features {
                let original_len = frame.len();
                frame.extend(vec![0.0; original_len]);
            }
        }

        Ok(mel_features)
    }

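    /// Apply a first-order pre-emphasis filter, y[n] = x[n] - a * x[n - 1],
    /// and rescale from [-1, 1] floats to 16-bit sample magnitudes (* 32768);
    /// the first sample is only rescaled.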
    fn apply_preemphasis(&self, wav: &[f32], preemphasis: f32) -> Vec<f32> {
        if wav.is_empty() {
            return vec![];
        }

        let mut preemphasized = Vec::with_capacity(wav.len());

        preemphasized.push(wav[0] * 32768.0);

        for i in 1..wav.len() {
            let filtered = (wav[i] - preemphasis * wav[i - 1]) * 32768.0;
            preemphasized.push(filtered);
        }

        preemphasized
    }

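    /// Build a bank of `n_mels` triangular filters over the FFT bins, with
    /// centers spaced uniformly on the mel scale (HTK convention,
    /// mel = 1127 * ln(1 + f / 700)) between 0 Hz and the Nyquist frequency.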
    fn create_mel_filterbank(
        &self,
        n_mels: usize,
        n_fft: usize,
        sample_rate: f32,
    ) -> Result<Vec<Vec<f32>>> {
        let bank_width = n_fft / 2 + 1;
        let fmax = sample_rate / 2.0;
        let fmin = 0.0;

        let hz_to_mel = |f: f32| 1127.0 * (1.0 + f / 700.0).ln();
        let mel_to_hz = |mel: f32| 700.0 * (mel / 1127.0).exp() - 700.0;

        let mel_low = hz_to_mel(fmin);
        let mel_high = hz_to_mel(fmax);

        let mel_centers: Vec<f32> = (0..=n_mels + 1)
            .map(|i| mel_low + (mel_high - mel_low) * i as f32 / (n_mels + 1) as f32)
            .collect();

        let hz_centers: Vec<f32> = mel_centers.iter().map(|&mel| mel_to_hz(mel)).collect();

        let bin_centers: Vec<usize> = hz_centers
            .iter()
            .map(|&f| ((f * n_fft as f32 / sample_rate) + 0.5) as usize)
            .collect();

        let mut filters = Vec::new();
        for m in 0..n_mels {
            let mut filter = vec![0.0; bank_width];

            let left_bin = bin_centers[m];
            let center_bin = bin_centers[m + 1];
            let right_bin = bin_centers[m + 2];

            for (bin, filter) in filter
                .iter_mut()
                .enumerate()
                .take(center_bin)
                .skip(left_bin)
            {
                if bin < bank_width {
                    *filter = (bin - left_bin) as f32 / (center_bin - left_bin) as f32;
                }
            }

            for (bin, filter) in filter
                .iter_mut()
                .enumerate()
                .take(right_bin)
                .skip(center_bin)
            {
                if bin < bank_width {
                    *filter = (right_bin - bin) as f32 / (right_bin - center_bin) as f32;
                }
            }

            filters.push(filter);
        }

        Ok(filters)
    }

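    /// Number of audio placeholder tokens for a clip: the frame count is
    /// divided by the compression rate and then by the downsample rate, each
    /// time rounding up (ceiling division).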
    fn compute_audio_embed_size(
        &self,
        audio_frames: usize,
        compression_rate: usize,
        downsample_rate: usize,
    ) -> usize {
        let integer = audio_frames / compression_rate;
        let remainder = audio_frames % compression_rate;
        let result = if remainder == 0 { integer } else { integer + 1 };

        let integer = result / downsample_rate;
        let remainder = result % downsample_rate;
        if remainder == 0 {
            integer
        } else {
            integer + 1
        }
    }

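    /// Build a (batch, max_frames) mask with 1.0 for valid audio frames and
    /// 0.0 for padding, so shorter clips in the batch are ignored.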
    fn create_audio_attention_mask(
        &self,
        audio_frames_list: &[usize],
        device: &Device,
    ) -> Result<Tensor> {
        let max_frames = *audio_frames_list.iter().max().unwrap_or(&0);
        let batch_size = audio_frames_list.len();

        let mut mask_data = vec![0u8; batch_size * max_frames];
        for (batch_idx, &frames) in audio_frames_list.iter().enumerate() {
            for frame_idx in 0..frames.min(max_frames) {
                mask_data[batch_idx * max_frames + frame_idx] = 1;
            }
        }

        Tensor::from_slice(&mask_data, (batch_size, max_frames), device)?.to_dtype(DType::F32)
    }

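    /// For every sequence containing the audio placeholder token, extract
    /// log-mel features for its audio clips, compute the per-clip embed sizes,
    /// pad all feature tensors to a common length, and (when there is more
    /// than one clip) build the corresponding attention mask.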
    fn process_audio_for_sequences(
        &self,
        input_seqs: &mut [&mut Sequence],
        device: &Device,
    ) -> AudioProcessingResult {
        let has_audio_tokens = input_seqs
            .iter()
            .any(|seq| seq.get_toks().contains(&(AUDIO_SPECIAL_TOKEN_ID as u32)));

        if !has_audio_tokens {
            return Ok((None, None, None));
        }

        let mut audio_features_list = Vec::new();
        let mut audio_embed_sizes_list = Vec::new();
        let mut audio_frames_list = Vec::new();

        for seq in input_seqs.iter_mut() {
            let has_audio = seq.get_toks().contains(&(AUDIO_SPECIAL_TOKEN_ID as u32));

            if has_audio {
                if let Some(audios) = seq.take_audios() {
                    for audio in audios.into_iter() {
                        let samples = audio.to_mono();

                        let features = self.extract_audio_features(&samples, audio.sample_rate)?;
                        let audio_frames = features.len() * self.audio_feat_stride;

                        let embed_size = self.compute_audio_embed_size(
                            audio_frames,
                            self.audio_compression_rate,
                            self.audio_downsample_rate,
                        );

                        let features_len = features.len();
                        let features_flat: Vec<f32> = features.into_iter().flatten().collect();
                        let features_tensor = Tensor::from_slice(
                            &features_flat,
                            (features_len, AUDIO_FEATURE_SIZE),
                            device,
                        )?;

                        audio_features_list.push(features_tensor);
                        audio_embed_sizes_list.push(embed_size);
                        audio_frames_list.push(audio_frames);
                    }
                } else {
                    candle_core::bail!("No audios in `process_audio_for_sequences`");
                };
            }
        }

        if audio_features_list.is_empty() {
            return Ok((None, None, None));
        }

        let max_len = audio_features_list
            .iter()
            .map(|t| t.dim(0).unwrap_or(0))
            .max()
            .unwrap_or(0);

        let mut padded_features = Vec::new();
        for features in audio_features_list {
            let seq_len = features.dim(0)?;
            if seq_len < max_len {
                let padding =
                    Tensor::zeros((max_len - seq_len, AUDIO_FEATURE_SIZE), DType::F32, device)?;
                let padded = Tensor::cat(&[features, padding], 0)?;
                padded_features.push(padded);
            } else {
                padded_features.push(features);
            }
        }

        let input_audio_embeds = Tensor::stack(&padded_features, 0)?;

        let audio_attention_mask = if audio_frames_list.len() > 1 {
            Some(self.create_audio_attention_mask(&audio_frames_list, device)?)
        } else {
            None
        };

        Ok((
            Some(input_audio_embeds),
            Some(audio_embed_sizes_list),
            audio_attention_mask,
        ))
    }
}

impl Phi4MMInputsProcessor {
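    /// Pad an image on the given sides with a solid color, then copy the
    /// original image into the top-left corner of the enlarged canvas.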
    fn pad_image(
        image: &DynamicImage,
        top: u32,
        bottom: u32,
        left: u32,
        right: u32,
        pad_color: Rgba<u8>,
    ) -> DynamicImage {
        let new_width = image.width() + left + right;
        let new_height = image.height() + top + bottom;

        let mut new_image = DynamicImage::new_rgb8(new_width, new_height);
        for x in 0..new_width {
            for y in 0..new_height {
                new_image.put_pixel(x, y, pad_color);
            }
        }

        new_image
            .copy_from(image, 0, 0)
            .expect("Failed to copy image");

        new_image
    }

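    /// Enumerate all crop-grid shapes whose total crop count lies in
    /// [min_num, max_num], sorted by the number of crops.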
    fn compute_target_ratios(min_num: u32, max_num: u32) -> Vec<(u32, u32)> {
        let mut ratios: HashSet<(u32, u32)> = HashSet::new();
        for n in min_num..=max_num {
            for i in 1..=n {
                for j in 1..=n {
                    if i * j >= min_num && i * j <= max_num {
                        ratios.insert((i, j));
                    }
                }
            }
        }
        let mut sorted_ratios: Vec<(u32, u32)> = ratios.into_iter().collect();
        sorted_ratios.sort_by_key(|&(i, j)| i * j);
        sorted_ratios
    }

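    /// Pick the candidate grid whose aspect ratio is closest to the image's;
    /// ties are broken in favor of the later candidate when the image area
    /// exceeds half of that grid's total crop area.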
    fn find_closest_aspect_ratio(
        aspect_ratio: f64,
        target_ratios: Vec<(u32, u32)>,
        width: u32,
        height: u32,
        image_size: usize,
    ) -> (u32, u32) {
        let mut best_ratio_diff = f64::INFINITY;
        let mut best_ratio = (1, 1);
        let area = width * height;
        for ratio in target_ratios {
            let target_aspect_ratio = ratio.0 as f64 / ratio.1 as f64;
            let ratio_diff = (aspect_ratio - target_aspect_ratio).abs();
            if ratio_diff < best_ratio_diff {
                best_ratio_diff = ratio_diff;
                best_ratio = ratio;
            } else if ratio_diff == best_ratio_diff
                && area as f64
                    > 0.5 * image_size as f64 * image_size as f64 * ratio.0 as f64 * ratio.1 as f64
            {
                best_ratio = ratio;
            }
        }
        best_ratio
    }

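    /// Dynamic-HD preprocessing: choose a crop grid for the image (bounded by
    /// `max_num` crops of `image_size` x `image_size`), resize while keeping
    /// the aspect ratio, pad the remainder with white, and return the padded
    /// image together with a patch-level attention mask that zeroes out the
    /// padding.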
    fn dynamic_preprocess(
        &self,
        mut image: DynamicImage,
        min_num: usize,
        max_num: usize,
        image_size: usize,
        mask_size: usize,
        device: &Device,
    ) -> Result<(DynamicImage, Tensor)> {
        let (orig_w, orig_h) = image.dimensions();

        let w_crop_num = (orig_w as f64 / image_size as f64).ceil();
        let h_crop_num = (orig_h as f64 / image_size as f64).ceil();
        let (target_aspect_ratio, target_width, target_height) =
            if w_crop_num * h_crop_num > max_num as f64 {
                let aspect_ratio = orig_w as f64 / orig_h as f64;
                let target_ratios = Self::compute_target_ratios(min_num as u32, max_num as u32);

                let target_aspect_ratio = Self::find_closest_aspect_ratio(
                    aspect_ratio,
                    target_ratios,
                    orig_w,
                    orig_h,
                    image_size,
                );

                let target_width = image_size * target_aspect_ratio.0 as usize;
                let target_height = image_size * target_aspect_ratio.1 as usize;

                (
                    (target_aspect_ratio.0 as f64, target_aspect_ratio.1 as f64),
                    target_width,
                    target_height,
                )
            } else {
                let target_width = (image_size as f64 * w_crop_num) as usize;
                let target_height = (image_size as f64 * h_crop_num) as usize;
                let target_aspect_ratio = (w_crop_num, h_crop_num);

                (target_aspect_ratio, target_width, target_height)
            };

        let ratio_width = target_width as f64 / orig_w as f64;
        let ratio_height = target_height as f64 / orig_h as f64;
        let (new_size, padding_width, padding_height) = if ratio_width < ratio_height {
            (
                (target_width, (orig_h as f64 * ratio_width) as usize),
                0_usize,
                target_height - (orig_h as f64 * ratio_width) as usize,
            )
        } else {
            (
                ((orig_w as f64 * ratio_height) as usize, target_height),
                target_width - (orig_w as f64 * ratio_height) as usize,
                0_usize,
            )
        };

        if new_size.1.min(target_height) < 10 || new_size.0.min(target_width) < 10 {
            candle_core::bail!(
                "Image aspect ratio too extreme; resulting size below minimum threshold",
            );
        }

        let mut attention_mask = Tensor::ones(
            (
                (mask_size as f64 * target_aspect_ratio.1) as usize,
                (mask_size as f64 * target_aspect_ratio.0) as usize,
            ),
            DType::U32,
            device,
        )?;
        if padding_width >= 14 {
            attention_mask = attention_mask.slice_assign(
                &[&.., &(attention_mask.dim(1)? - padding_width / 14..)],
                &Tensor::zeros(
                    (attention_mask.dim(0)?, padding_width / 14),
                    DType::U32,
                    device,
                )?,
            )?;
        }
        if padding_height >= 14 {
            attention_mask = attention_mask.slice_assign(
                &[&(attention_mask.dim(0)? - padding_height / 14..), &..],
                &Tensor::zeros(
                    (padding_height / 14, attention_mask.dim(1)?),
                    DType::U32,
                    device,
                )?,
            )?;
        }

        let mask_sum: u32 = attention_mask.sum_all()?.to_scalar::<u32>()?;
        if mask_sum == 0 {
            candle_core::bail!("dynamic_preprocess produced an attention mask with zero sum");
        }

        image = image.resize_exact(new_size.0 as u32, new_size.1 as u32, FilterType::Nearest);
        image = Self::pad_image(
            &image,
            0,
            padding_height as u32,
            0,
            padding_width as u32,
            Rgba([255u8, 255, 255, 255]),
        );

        Ok((image, attention_mask))
    }
}

impl ImagePreProcessor for Phi4MMInputsProcessor {
    #[allow(clippy::excessive_precision)]
    const DEFAULT_MEAN: [f64; 3] = [0.48145466, 0.4578275, 0.40821073];
    #[allow(clippy::excessive_precision)]
    const DEFAULT_STD: [f64; 3] = [0.26862954, 0.26130258, 0.27577711];

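    /// Preprocess a batch of images for Phi-4 multimodal: images are resized
    /// to a common size, normalized, split into dynamic-HD crops of
    /// `DYHD_BASE_RESOLUTION` (448), prepended with a downscaled global view,
    /// and returned together with per-crop attention masks and the number of
    /// image placeholder tokens each image will occupy.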
    fn preprocess(
        &self,
        mut images: Vec<DynamicImage>,
        videos: Vec<Vec<DynamicImage>>,
        config: &PreProcessorConfig,
        device: &Device,
        (_, _): (usize, usize),
    ) -> Result<PreprocessedImages> {
        assert!(!images.is_empty());
        assert!(videos.is_empty());

        let mut max_size = None;
        for image in images.iter() {
            if max_size.is_none() {
                max_size = Some((image.dimensions().0 as usize, image.dimensions().1 as usize))
            } else if max_size.is_some_and(|(x, _)| image.dimensions().0 as usize > x) {
                max_size = Some((image.dimensions().0 as usize, max_size.unwrap().1));
            } else if max_size.is_some_and(|(_, y)| image.dimensions().1 as usize > y) {
                max_size = Some((max_size.unwrap().0, image.dimensions().1 as usize));
            }
        }
        let (max_w, max_h) = max_size.unwrap();
        for image in images.iter_mut() {
            *image = image.resize_exact(max_w as u32, max_h as u32, FilterType::Nearest);
        }

        let mut image_sizes = Vec::new();
        let mut padded_images = Vec::new();
        let mut padded_masks = Vec::new();
        let mut num_img_tokens = Vec::new();
        for mut image in images {
            if config.do_convert_rgb.unwrap_or(true) {
                image = DynamicImage::ImageRgb8(image.to_rgb8());
            }

            let transforms = Transforms {
                input: &ToTensor,
                inner_transforms: &[&Normalize {
                    mean: vec![0.5, 0.5, 0.5],
                    std: vec![0.5, 0.5, 0.5],
                }],
            };

            let dyhd_base_resolution = DYHD_BASE_RESOLUTION;
            let base_resolution = dyhd_base_resolution;
            let mask_resolution = base_resolution / 14;
            let min_num = 1;

            let (elem, attention_mask) = self.dynamic_preprocess(
                image,
                min_num,
                config.dynamic_hd.unwrap(),
                base_resolution,
                mask_resolution,
                device,
            )?;

            let hd_image = elem.apply(transforms, device)?;
            let (img_h, img_w) = (hd_image.dim(1)?, hd_image.dim(2)?);
            let (mask_h, mask_w) = (attention_mask.dim(0)?, attention_mask.dim(1)?);

            let global_image = hd_image
                .unsqueeze(0)?
                .interpolate2d(base_resolution, base_resolution)?;
            let global_attention_mask =
                Tensor::ones((1, mask_resolution, mask_resolution), DType::U32, device)?;

            let hd_image_reshape = hd_image
                .reshape((
                    1,
                    3,
                    (img_h as f32 / base_resolution as f32) as usize,
                    base_resolution,
                    (img_w as f32 / base_resolution as f32) as usize,
                    base_resolution,
                ))?
                .permute((0, 2, 4, 1, 3, 5))?
                .reshape(((), 3, base_resolution, base_resolution))?;

            let attention_mask_reshape = attention_mask
                .reshape((
                    1,
                    (mask_h as f32 / mask_resolution as f32) as usize,
                    mask_resolution,
                    (mask_w as f32 / mask_resolution as f32) as usize,
                    mask_resolution,
                ))?
                .permute((0, 1, 3, 2, 4))?
                .reshape(((), mask_resolution, mask_resolution))?;

            let downsample_attention_mask = {
                let h_indices =
                    Tensor::arange_step(0, attention_mask_reshape.dim(1)? as u32, 2, device)?;
                let w_indices =
                    Tensor::arange_step(0, attention_mask_reshape.dim(2)? as u32, 2, device)?;
                let selected = attention_mask_reshape
                    .index_select(&h_indices, 1)?
                    .index_select(&w_indices, 2)?;

                let mask = selected
                    .reshape((
                        1,
                        mask_h / mask_resolution,
                        mask_w / mask_resolution,
                        mask_resolution / 2 + mask_resolution % 2,
                        mask_resolution / 2 + mask_resolution % 2,
                    ))?
                    .permute((0, 1, 3, 2, 4))?;
                mask.reshape((mask.dim(1)? * mask.dim(2)?, mask.dim(3)? * mask.dim(4)?))?
            };

            let img_tokens = 256
                + 1
                + downsample_attention_mask.sum_all()?.to_scalar::<u32>()? as usize
                + downsample_attention_mask
                    .i((.., 0))?
                    .sum_all()?
                    .to_scalar::<u32>()? as usize
                + 16;

            let hd_image_reshape = Tensor::cat(&[global_image, hd_image_reshape], 0)?;
            let hd_mask_reshape = Tensor::cat(&[global_attention_mask, attention_mask_reshape], 0)?;

            image_sizes.push((img_h as u32, img_w as u32));
            padded_images.push(hd_image_reshape);
            padded_masks.push(hd_mask_reshape);
            num_img_tokens.push(img_tokens);
        }
        Ok(PreprocessedImages {
            pixel_values: Tensor::stack(&padded_images, 0)?,
            pixel_attention_mask: Some(Tensor::stack(&padded_masks, 0)?),
            image_sizes: None,
            num_img_tokens: Some(num_img_tokens),
            aspect_ratio_ids: None,
            aspect_ratio_mask: None,
            num_tiles: None,
            image_grid_thw: None,
            video_grid_thw: None,
            rows: None,
            cols: None,
            pixel_values_list: None,
            tgt_sizes: None,
            image_sizes_all: Some(image_sizes),
            num_crops: None,
        })
    }
}