mistralrs_core/vision_models/phi4/inputs_processor.rs

#![allow(clippy::cast_precision_loss, clippy::cast_possible_truncation)]

use std::{any::Any, collections::HashSet, sync::Arc};

use candle_core::{DType, Device, IndexOp, Result, Tensor};
use image::{imageops::FilterType, DynamicImage, GenericImage, GenericImageView, Rgba};
use mistralrs_vision::{ApplyTransforms, Normalize, ToTensor, Transforms};
use regex::Regex;
use tokenizers::Tokenizer;

use apodize::hanning_iter;
use rubato::{
    Resampler, SincFixedIn, SincInterpolationParameters, SincInterpolationType, WindowFunction,
};
use rustfft::{num_complex::Complex32, FftPlanner};

use crate::{
    device_map::DeviceMapper,
    pipeline::{
        text_models_inputs_processor::{
            self, get_completion_input, get_prompt_input, PagedAttentionMeta,
        },
        InputProcessorOutput, InputsProcessor, InputsProcessorType, MessagesAction, Processor,
        ProcessorCreator,
    },
    sequence::Sequence,
};

use crate::vision_models::{
    image_processor::{ImagePreProcessor, PreprocessedImages},
    phi4::Phi4MMVisionSpecificArgs,
    preprocessor_config::PreProcessorConfig,
    processor_config::ProcessorConfig,
    ModelInputs,
};

use super::audio_embedding::AUDIO_SPECIAL_TOKEN_ID;
use super::image_embedding::IMAGE_SPECIAL_TOKEN_ID;

const COMPATIBLE_IMAGE_SPECIAL_TOKEN_PATTERN: &str = r"<\|image_\d+\|>";
const COMPATIBLE_AUDIO_SPECIAL_TOKEN_PATTERN: &str = r"<\|audio_\d+\|>";
const IMAGE_SPECIAL_TOKEN: &str = "<|endoftext10|>";
const AUDIO_SPECIAL_TOKEN: &str = "<|endoftext11|>";
pub(crate) const DYHD_BASE_RESOLUTION: usize = 448;

const AUDIO_FEATURE_SIZE: usize = 80; // mel bins

type AudioProcessingResult = Result<(Option<Tensor>, Option<Vec<usize>>, Option<Tensor>)>;

// Input processor
pub struct Phi4MMInputsProcessor {
    audio_compression_rate: usize,
    audio_downsample_rate: usize,
    audio_feat_stride: usize,
    eightk_method: String, // "fillzero" or "resample"
}

// Processor
pub struct Phi4MMProcessor {
    inputs_processor: Arc<Phi4MMInputsProcessor>,
}

impl ProcessorCreator for Phi4MMProcessor {
    fn new_processor(
        _: Option<ProcessorConfig>,
        pre_processor_config: PreProcessorConfig,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(Self {
            inputs_processor: Arc::new(Phi4MMInputsProcessor {
                audio_compression_rate: pre_processor_config
                    .audio_compression_rate
                    .expect("audio_compression_rate"),
                audio_downsample_rate: pre_processor_config
                    .audio_downsample_rate
                    .expect("audio_downsample_rate"),
                audio_feat_stride: pre_processor_config
                    .audio_feat_stride
                    .expect("audio_feat_stride"),
                eightk_method: "fillzero".to_string(), // Default to fillzero
            }),
        })
    }
}

impl Processor for Phi4MMProcessor {
    fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
        self.inputs_processor.clone()
    }
    fn get_special_tokens(&self) -> &[&'static str] {
        &[]
    }
    fn template_action(&self) -> MessagesAction {
        MessagesAction::FlattenOnlyText
    }
}

impl InputsProcessor for Phi4MMInputsProcessor {
    fn get_type(&self) -> InputsProcessorType {
        InputsProcessorType::Vision
    }
    fn process_inputs(
        &self,
        tokenizer: Option<Arc<Tokenizer>>,
        input_seqs: &mut [&mut Sequence],
        is_prompt: bool,
        is_xlora: bool,
        device: &Device,
        no_kv_cache: bool,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        other_config: Option<Arc<dyn Any>>,
        mut paged_attn_metadata: Option<PagedAttentionMeta>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> anyhow::Result<InputProcessorOutput> {
        if is_xlora {
            return Err(anyhow::Error::msg(
                "Cannot make inputs for X-LoRA vision model.",
            ));
        }
        if no_kv_cache {
            return Err(anyhow::Error::msg("Vision model must have kv cache."));
        }
        let Some(tokenizer) = tokenizer else {
            return Err(anyhow::Error::msg(
                "Phi4MMInputProcessor requires a specified tokenizer.",
            ));
        };

        let config = other_config
            .clone()
            .expect("Need a PreProcessorConfig config.");
        let config: &PreProcessorConfig = config.downcast_ref().expect("Downcast failed.");

        let has_audios = input_seqs.iter().all(|seq| seq.has_audios());
        let has_images = input_seqs.iter().all(|seq| seq.has_images());

        let (pixel_values, pixel_attention_mask, image_sizes, num_img_tokens) = if has_images {
            let mut pixel_values_accum = Vec::new();
            let mut pixel_attention_masks_accum = Vec::new();
            let mut image_sizes_accum = Vec::new();
            let mut num_img_tokens_accum = Vec::new();
            for seq in input_seqs.iter_mut() {
                let imgs = seq
                    .take_images()
                    .expect("Need to have images by this point.");
                let PreprocessedImages {
                    pixel_values,
                    pixel_attention_mask,
                    image_sizes: _,
                    num_img_tokens,
                    aspect_ratio_ids: _,
                    aspect_ratio_mask: _,
                    num_tiles: _,
                    image_grid_thw: _,
                    video_grid_thw: _,
                    rows: _,
                    cols: _,
                    pixel_values_list: _,
                    tgt_sizes: _,
                    image_sizes_all,
                    num_crops: _,
                } = self
                    .preprocess(
                        imgs,
                        vec![],
                        config,
                        device,
                        (usize::MAX, usize::MAX), // Don't use it here...
                    )
                    .expect("Preprocessor failed");
                let image_sizes = image_sizes_all.unwrap();
                let pixel_attention_mask = pixel_attention_mask.unwrap();
                pixel_values_accum.push(pixel_values);
                pixel_attention_masks_accum.push(pixel_attention_mask);
                // Using extend on purpose
                image_sizes_accum.extend(image_sizes);
                num_img_tokens_accum.push(num_img_tokens.unwrap());
            }
            (
                Some(Tensor::cat(&pixel_values_accum, 0).unwrap()),
                Some(Tensor::cat(&pixel_attention_masks_accum, 0).unwrap()),
                Some(image_sizes_accum),
                Some(num_img_tokens_accum),
            )
        } else if has_audios {
            (None, None, None, Some(vec![vec![]; input_seqs.len()]))
        } else {
            return text_models_inputs_processor::TextInputsProcessor
                .process_inputs(
                    Some(tokenizer),
                    input_seqs,
                    is_prompt,
                    is_xlora,
                    device,
                    no_kv_cache,
                    last_n_context_len,
                    return_raw_logits,
                    other_config,
                    paged_attn_metadata,
                    mapper,
                )
                .map(|metadata| {
                    let InputProcessorOutput {
                        inputs,
                        seq_indices,
                    } = metadata;

                    let text_models_inputs_processor::ModelInputs {
                        input_ids,
                        input_ids_full: _,
                        seqlen_offsets,
                        seqlen_offsets_full: _,
                        context_lens,
                        position_ids,
                        paged_attn_meta,
                        flash_meta,
                        flash_meta_full: _,
                    } = *inputs
                        .downcast::<text_models_inputs_processor::ModelInputs>()
                        .expect("Downcast failed.");

                    let inputs: Box<dyn Any> = Box::new(ModelInputs {
                        input_ids,
                        seqlen_offsets,
                        context_lens,
                        position_ids,
                        pixel_values: None,
                        model_specific_args: Box::new(Phi4MMVisionSpecificArgs {
                            input_image_embeds: None,
                            image_attention_mask: None,
                            image_sizes: None,
                            input_audio_embeds: None,
                            audio_embed_sizes: None,
                            audio_attention_mask: None,
                        }),
                        paged_attn_meta,
                        flash_meta,
                    });
                    InputProcessorOutput {
                        inputs,
                        seq_indices,
                    }
                });
        };

        let detokenized = tokenizer
            .decode_batch(
                &input_seqs
                    .iter()
                    .map(|seq| seq.get_toks())
                    .collect::<Vec<_>>(),
                false,
            )
            .expect("Decode failed");

        let img_token_pattern = Regex::new(COMPATIBLE_IMAGE_SPECIAL_TOKEN_PATTERN).unwrap();
        let audio_token_pattern = Regex::new(COMPATIBLE_AUDIO_SPECIAL_TOKEN_PATTERN).unwrap();

        for (mut detokenized, seq) in detokenized.into_iter().zip(input_seqs.iter_mut()) {
            detokenized = img_token_pattern
                .replace_all(&detokenized, IMAGE_SPECIAL_TOKEN)
                .to_string();
            detokenized = audio_token_pattern
                .replace_all(&detokenized, AUDIO_SPECIAL_TOKEN)
                .to_string();

            let has_changed_prompt = seq.multimodal.has_changed_prompt;
            if !has_changed_prompt {
                seq.set_toks_and_reallocate(
                    tokenizer
                        .encode_fast(detokenized.clone(), false)
                        .expect("Encode failed")
                        .get_ids()
                        .to_vec(),
                    paged_attn_metadata.as_mut(),
                );

                seq.set_initial_prompt(detokenized);
            }
        }

        let (input_audio_embeds, audio_embed_sizes, audio_attention_mask) =
            match self.process_audio_for_sequences(input_seqs, device) {
                Ok(result) => result,
                Err(e) => return Err(anyhow::Error::new(e)),
            };

        let mut toks = Vec::new();

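        // Expand each multimodal placeholder in place: a single image/audio special token in the
        // prompt is repeated `n` times, where `n` is the number of embedding positions the
        // corresponding image crops or audio clip will occupy.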
        for (seq, num_img_tokens) in input_seqs.iter_mut().zip(num_img_tokens.unwrap()) {
            let has_changed_prompt = seq.multimodal.has_changed_prompt;

            let mut i = 0;
            let mut image_token_count_iter = num_img_tokens.iter();
            let audio_sizes_tmp = audio_embed_sizes.clone().unwrap_or(vec![]);
            let mut audio_embed_sizes = audio_sizes_tmp.iter();
            while i < seq.get_toks().len() {
                let token_id = seq.get_toks()[i];
                let token_count = if token_id == IMAGE_SPECIAL_TOKEN_ID as u32 {
                    image_token_count_iter.next().unwrap()
                } else if token_id == AUDIO_SPECIAL_TOKEN_ID as u32 {
                    audio_embed_sizes.next().unwrap()
                } else {
                    i += 1;
                    continue;
                };

                let mut new_ids = seq.get_toks()[..i].to_vec();
                new_ids.extend(vec![token_id; *token_count]);
                new_ids.extend(seq.get_toks()[i + 1..].to_vec());
                if !has_changed_prompt {
                    seq.set_toks_and_reallocate(new_ids, paged_attn_metadata.as_mut());
                }
                i += token_count;
            }
            if !has_changed_prompt {
                seq.multimodal.has_changed_prompt = true;
            }
            toks.push(seq.get_toks().to_vec());
        }

        let result = if is_prompt {
            get_prompt_input(
                toks.iter().map(Vec::as_slice).collect(),
                input_seqs,
                device,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                mapper,
            )
        } else {
            get_completion_input(
                toks.iter().map(Vec::as_slice).collect(),
                input_seqs,
                device,
                no_kv_cache,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                mapper,
            )
        };

        result.map(move |metadata| {
            let pixel_values = pixel_values.clone();
            let pixel_attention_mask = pixel_attention_mask.clone();
            let text_models_inputs_processor::InnerInputProcessorOutput {
                inputs:
                    text_models_inputs_processor::InputMetadata {
                        input,
                        positions,
                        context_lens,
                        position_ids,
                        paged_attn_meta,
                        flash_meta,
                    },
                seq_indices,
            } = metadata;
            let inputs: Box<dyn Any> = Box::new(ModelInputs {
                input_ids: input,
                seqlen_offsets: positions,
                context_lens,
                position_ids,
                pixel_values: pixel_values.clone(),
                model_specific_args: Box::new(Phi4MMVisionSpecificArgs {
                    input_image_embeds: pixel_values,
                    image_attention_mask: pixel_attention_mask,
                    image_sizes: image_sizes.clone(),
                    input_audio_embeds: input_audio_embeds.clone(),
                    audio_embed_sizes: audio_embed_sizes.clone(),
                    audio_attention_mask: audio_attention_mask.clone(),
                }),
                paged_attn_meta,
                flash_meta,
            });
            InputProcessorOutput {
                inputs,
                seq_indices,
            }
        })
    }
}

impl Phi4MMInputsProcessor {
    fn extract_audio_features(
        &self,
        audio_data: &[f32],
        sample_rate: u32,
    ) -> Result<Vec<Vec<f32>>> {
        // Resample audio to supported rates using rubato
        let (resampled_audio, final_sample_rate) =
            self.resample_audio_with_rubato(audio_data, sample_rate)?;

        // Extract mel spectrogram using rustfft and custom mel filterbank
        let mel_features =
            self.extract_mel_spectrogram_rustfft(&resampled_audio, final_sample_rate)?;

        Ok(mel_features)
    }

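    /// Resample to one of the supported rates: anything above 16 kHz is brought down to 16 kHz,
    /// rates strictly between 8 kHz and 16 kHz are brought down to 8 kHz, rates below 8 kHz are
    /// rejected, and 8 kHz/16 kHz input is passed through unchanged.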
    fn resample_audio_with_rubato(&self, wav: &[f32], fs: u32) -> Result<(Vec<f32>, u32)> {
        let target_fs = if fs > 16000 {
            16000
        } else if fs > 8000 && fs < 16000 {
            8000
        } else if fs < 8000 {
            return Err(candle_core::Error::Msg(format!(
                "Unsupported sample rate: {fs}"
            )));
        } else {
            return Ok((wav.to_vec(), fs)); // No resampling needed
        };

        if fs == target_fs {
            return Ok((wav.to_vec(), fs));
        }

        // Handle 8kHz upsampling case
        if fs == 8000 && self.eightk_method == "resample" {
            // Upsample to 16kHz using rubato
            let params = SincInterpolationParameters {
                sinc_len: 256,
                f_cutoff: 0.95,
                interpolation: SincInterpolationType::Linear,
                oversampling_factor: 256,
                window: WindowFunction::BlackmanHarris2,
            };

            let mut resampler = SincFixedIn::<f32>::new(
                2.0, // resample ratio (16000/8000)
                2.0,
                params,
                wav.len(),
                1, // mono
            )
            .map_err(|e| candle_core::Error::Msg(format!("Resampler creation failed: {e}")))?;

            let input = vec![wav.to_vec()];
            let output = resampler
                .process(&input, None)
                .map_err(|e| candle_core::Error::Msg(format!("Resampling failed: {e}")))?;

            return Ok((output[0].clone(), 16000));
        }

        // Regular downsampling
        let resample_ratio = target_fs as f64 / fs as f64;

        let params = SincInterpolationParameters {
            sinc_len: 256,
            f_cutoff: 0.95,
            interpolation: SincInterpolationType::Linear,
            oversampling_factor: 256,
            window: WindowFunction::BlackmanHarris2,
        };

        let mut resampler = SincFixedIn::<f32>::new(
            resample_ratio,
            2.0,
            params,
            wav.len(),
            1, // mono
        )
        .map_err(|e| candle_core::Error::Msg(format!("Resampler creation failed: {e}")))?;

        let input = vec![wav.to_vec()];
        let output = resampler
            .process(&input, None)
            .map_err(|e| candle_core::Error::Msg(format!("Resampling failed: {e}")))?;

        Ok((output[0].clone(), target_fs))
    }

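    /// Log-mel spectrogram via a short-time FFT: 25 ms Hann windows with a 10 ms hop
    /// (400/160 samples at 16 kHz, 200/80 at 8 kHz), an 80-bin mel filterbank, and a log
    /// applied to the clipped (>= 1.0) filterbank energies.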
    fn extract_mel_spectrogram_rustfft(&self, wav: &[f32], fs: u32) -> Result<Vec<Vec<f32>>> {
        // Set parameters based on sample rate
        let (n_fft, win_length, hop_length) = if fs == 8000 {
            (256, 200, 80)
        } else if fs == 16000 {
            (512, 400, 160)
        } else {
            return Err(candle_core::Error::Msg(format!(
                "Unsupported sample rate: {fs}"
            )));
        };

        // Apply preemphasis first
        let preemphasized = self.apply_preemphasis(wav, 0.97);

        // Create FFT planner
        let mut planner = FftPlanner::<f32>::new();
        let fft = planner.plan_fft_forward(n_fft);

        // Create Hanning window
        let window: Vec<f64> = hanning_iter(win_length).collect();

        // Create mel filterbank
        let mel_filters = self.create_mel_filterbank(AUDIO_FEATURE_SIZE, n_fft, fs as f32)?;

        // Extract frames and apply STFT (saturating guard avoids underflow for clips shorter
        // than one window)
        let n_batch = preemphasized.len().saturating_sub(win_length) / hop_length + 1;
        let mut mel_features = Vec::new();

        for i in 0..n_batch {
            let start = i * hop_length;
            let end = start + win_length;
            if end > preemphasized.len() {
                break;
            }

            // Apply window and convert to complex
            let mut windowed: Vec<Complex32> = preemphasized[start..end]
                .iter()
                .zip(window.iter())
                .map(|(s, w)| Complex32::new(s * *w as f32, 0.0))
                .collect();

            // Pad to n_fft length
            windowed.resize(n_fft, Complex32::new(0.0, 0.0));

            // Apply FFT
            fft.process(&mut windowed);

            // Take power spectrum of positive frequencies
            let power_spectrum: Vec<f32> = windowed[0..n_fft / 2 + 1]
                .iter()
                .map(|c| c.norm_sqr())
                .collect();

            // Apply mel filterbank
            let mut mel_frame = vec![0.0; AUDIO_FEATURE_SIZE];
            for (mel_idx, filter) in mel_filters.iter().enumerate() {
                let mut sum = 0.0;
                for (freq_idx, &coeff) in filter.iter().enumerate() {
                    if freq_idx < power_spectrum.len() {
                        sum += power_spectrum[freq_idx] * coeff;
                    }
                }
                mel_frame[mel_idx] = (sum.max(1.0)).ln(); // Apply log with clipping
            }

            mel_features.push(mel_frame);
        }

        // Handle 8kHz case with fillzero method
        if fs == 8000 && self.eightk_method == "fillzero" {
            for frame in &mut mel_features {
                // Extend each frame with zeros to match 16kHz structure
                let original_len = frame.len();
                frame.extend(vec![0.0; original_len]);
            }
        }

        Ok(mel_features)
    }

    fn apply_preemphasis(&self, wav: &[f32], preemphasis: f32) -> Vec<f32> {
        if wav.is_empty() {
            return vec![];
        }

        let mut preemphasized = Vec::with_capacity(wav.len());

        // First sample: y[0] = x[0] * 32768
        preemphasized.push(wav[0] * 32768.0);

        // Remaining samples: y[n] = (x[n] - preemphasis * x[n-1]) * 32768
        for i in 1..wav.len() {
            let filtered = (wav[i] - preemphasis * wav[i - 1]) * 32768.0;
            preemphasized.push(filtered);
        }

        preemphasized
    }

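    /// Builds a triangular mel filterbank over the `n_fft / 2 + 1` FFT bins, using the
    /// HTK-style mel scale `mel = 1127 * ln(1 + f / 700)` and spanning 0 Hz to Nyquist.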
    fn create_mel_filterbank(
        &self,
        n_mels: usize,
        n_fft: usize,
        sample_rate: f32,
    ) -> Result<Vec<Vec<f32>>> {
        let bank_width = n_fft / 2 + 1;
        let fmax = sample_rate / 2.0;
        let fmin = 0.0;

        // Mel scale conversion functions
        let hz_to_mel = |f: f32| 1127.0 * (1.0 + f / 700.0).ln();
        let mel_to_hz = |mel: f32| 700.0 * (mel / 1127.0).exp() - 700.0;

        let mel_low = hz_to_mel(fmin);
        let mel_high = hz_to_mel(fmax);

        // Create mel centers
        let mel_centers: Vec<f32> = (0..=n_mels + 1)
            .map(|i| mel_low + (mel_high - mel_low) * i as f32 / (n_mels + 1) as f32)
            .collect();

        let hz_centers: Vec<f32> = mel_centers.iter().map(|&mel| mel_to_hz(mel)).collect();

        // Convert to bin numbers
        let bin_centers: Vec<usize> = hz_centers
            .iter()
            .map(|&f| ((f * n_fft as f32 / sample_rate) + 0.5) as usize)
            .collect();

        // Create triangular filters
        let mut filters = Vec::new();
        for m in 0..n_mels {
            let mut filter = vec![0.0; bank_width];

            let left_bin = bin_centers[m];
            let center_bin = bin_centers[m + 1];
            let right_bin = bin_centers[m + 2];

            // Left slope
            for (bin, filter) in filter
                .iter_mut()
                .enumerate()
                .take(center_bin)
                .skip(left_bin)
            {
                if bin < bank_width {
                    *filter = (bin - left_bin) as f32 / (center_bin - left_bin) as f32;
                }
            }

            // Right slope
            for (bin, filter) in filter
                .iter_mut()
                .enumerate()
                .take(right_bin)
                .skip(center_bin)
            {
                if bin < bank_width {
                    *filter = (right_bin - bin) as f32 / (right_bin - center_bin) as f32;
                }
            }

            filters.push(filter);
        }

        Ok(filters)
    }

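    /// Number of audio placeholder tokens a clip expands to: the frame count is ceil-divided by
    /// the encoder compression rate and then again by the qformer downsample rate.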
    fn compute_audio_embed_size(
        &self,
        audio_frames: usize,
        compression_rate: usize,
        downsample_rate: usize,
    ) -> usize {
        // First compression
        let integer = audio_frames / compression_rate;
        let remainder = audio_frames % compression_rate;
        let result = if remainder == 0 { integer } else { integer + 1 };

        // Second compression (qformer)
        let integer = result / downsample_rate;
        let remainder = result % downsample_rate;
        if remainder == 0 {
            integer
        } else {
            integer + 1
        }
    }

    fn create_audio_attention_mask(
        &self,
        audio_frames_list: &[usize],
        device: &Device,
    ) -> Result<Tensor> {
        let max_frames = *audio_frames_list.iter().max().unwrap_or(&0);
        let batch_size = audio_frames_list.len();

        let mut mask_data = vec![0u8; batch_size * max_frames];
        for (batch_idx, &frames) in audio_frames_list.iter().enumerate() {
            for frame_idx in 0..frames.min(max_frames) {
                mask_data[batch_idx * max_frames + frame_idx] = 1;
            }
        }

        Tensor::from_slice(&mask_data, (batch_size, max_frames), device)?.to_dtype(DType::F32)
    }

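    /// Extracts log-mel features for every audio clip in the batch, pads them to a common
    /// length, and returns the stacked features, per-clip embed sizes, and (when the batch holds
    /// more than one clip) a frame-level attention mask.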
    fn process_audio_for_sequences(
        &self,
        input_seqs: &mut [&mut Sequence],
        device: &Device,
    ) -> AudioProcessingResult {
        // Check if any sequence has audio tokens
        let has_audio_tokens = input_seqs
            .iter()
            .any(|seq| seq.get_toks().contains(&(AUDIO_SPECIAL_TOKEN_ID as u32)));

        if !has_audio_tokens {
            return Ok((None, None, None));
        }

        let mut audio_features_list = Vec::new();
        let mut audio_embed_sizes_list = Vec::new();
        let mut audio_frames_list = Vec::new();

        // Process audio for each sequence that needs it
        for seq in input_seqs.iter_mut() {
            let has_audio = seq.get_toks().contains(&(AUDIO_SPECIAL_TOKEN_ID as u32));

            if has_audio {
                if let Some(audios) = seq.take_audios() {
                    for audio in audios.into_iter() {
                        // Convert multi-channel audio to mono by averaging channels
                        let samples = audio.to_mono();

                        // Extract features
                        let features = self.extract_audio_features(&samples, audio.sample_rate)?;
                        let audio_frames = features.len() * self.audio_feat_stride;

                        let embed_size = self.compute_audio_embed_size(
                            audio_frames,
                            self.audio_compression_rate,
                            self.audio_downsample_rate,
                        );

                        // Convert to tensor
                        let features_len = features.len();
                        let features_flat: Vec<f32> = features.into_iter().flatten().collect();
                        let features_tensor = Tensor::from_slice(
                            &features_flat,
                            (features_len, AUDIO_FEATURE_SIZE),
                            device,
                        )?;

                        audio_features_list.push(features_tensor);
                        audio_embed_sizes_list.push(embed_size);
                        audio_frames_list.push(audio_frames);
                    }
                } else {
                    candle_core::bail!("No audios in `process_audio_for_sequences`");
                };
            }
        }

        if audio_features_list.is_empty() {
            return Ok((None, None, None));
        }

        // Pad sequences to same length
        let max_len = audio_features_list
            .iter()
            .map(|t| t.dim(0).unwrap_or(0))
            .max()
            .unwrap_or(0);

        let mut padded_features = Vec::new();
        for features in audio_features_list {
            let seq_len = features.dim(0)?;
            if seq_len < max_len {
                let padding =
                    Tensor::zeros((max_len - seq_len, AUDIO_FEATURE_SIZE), DType::F32, device)?;
                let padded = Tensor::cat(&[features, padding], 0)?;
                padded_features.push(padded);
            } else {
                padded_features.push(features);
            }
        }

        // Stack into batch tensor
        let input_audio_embeds = Tensor::stack(&padded_features, 0)?;

        // Create attention mask if multiple sequences
        let audio_attention_mask = if audio_frames_list.len() > 1 {
            Some(self.create_audio_attention_mask(&audio_frames_list, device)?)
        } else {
            None
        };

        Ok((
            Some(input_audio_embeds),
            Some(audio_embed_sizes_list),
            audio_attention_mask,
        ))
    }
}

impl Phi4MMInputsProcessor {
    fn pad_image(
        image: &DynamicImage,
        top: u32,
        bottom: u32,
        left: u32,
        right: u32,
        pad_color: Rgba<u8>,
    ) -> DynamicImage {
        // Calculate the new dimensions
        let new_width = image.width() + left + right;
        let new_height = image.height() + top + bottom;

        // Create a new image with the new dimensions and fill it with the pad color
        let mut new_image = DynamicImage::new_rgb8(new_width, new_height);
        for x in 0..new_width {
            for y in 0..new_height {
                new_image.put_pixel(x, y, pad_color);
            }
        }

        // Paste the original image at the requested offset, leaving the padding on the
        // remaining sides (callers pass top = left = 0, so padding goes to the bottom/right)
        new_image
            .copy_from(image, left, top)
            .expect("Failed to copy image");

        new_image
    }

    fn compute_target_ratios(min_num: u32, max_num: u32) -> Vec<(u32, u32)> {
        let mut ratios: HashSet<(u32, u32)> = HashSet::new();
        for n in min_num..=max_num {
            for i in 1..=n {
                for j in 1..=n {
                    if i * j >= min_num && i * j <= max_num {
                        ratios.insert((i, j));
                    }
                }
            }
        }
        let mut sorted_ratios: Vec<(u32, u32)> = ratios.into_iter().collect();
        sorted_ratios.sort_by_key(|&(i, j)| i * j);
        sorted_ratios
    }

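    /// Picks the crop grid whose aspect ratio is closest to the image's; on ties, a grid is
    /// preferred when the original image area exceeds half of that grid's total pixel area.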
    fn find_closest_aspect_ratio(
        aspect_ratio: f64,
        target_ratios: Vec<(u32, u32)>,
        width: u32,
        height: u32,
        image_size: usize,
    ) -> (u32, u32) {
        let mut best_ratio_diff = f64::INFINITY;
        let mut best_ratio = (1, 1);
        let area = width * height;
        for ratio in target_ratios {
            let target_aspect_ratio = ratio.0 as f64 / ratio.1 as f64;
            let ratio_diff = (aspect_ratio - target_aspect_ratio).abs();
            if ratio_diff < best_ratio_diff {
                best_ratio_diff = ratio_diff;
                best_ratio = ratio;
            } else if ratio_diff == best_ratio_diff
                && area as f64
                    > 0.5 * image_size as f64 * image_size as f64 * ratio.0 as f64 * ratio.1 as f64
            {
                best_ratio = ratio;
            }
        }
        best_ratio
    }

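    /// Dynamic-HD preprocessing: choose a crop grid, resize the image preserving aspect ratio,
    /// pad the remainder with white, and build a patch-level (14 px patches) attention mask
    /// that zeroes out the padded region.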
    fn dynamic_preprocess(
        &self,
        mut image: DynamicImage,
        min_num: usize,
        max_num: usize,
        image_size: usize,
        mask_size: usize,
        device: &Device,
    ) -> Result<(DynamicImage, Tensor)> {
        let (orig_w, orig_h) = image.dimensions();

        let w_crop_num = (orig_w as f64 / image_size as f64).ceil();
        let h_crop_num = (orig_h as f64 / image_size as f64).ceil();
        let (target_aspect_ratio, target_width, target_height) =
            if w_crop_num * h_crop_num > max_num as f64 {
                let aspect_ratio = orig_w as f64 / orig_h as f64;
                let target_ratios = Self::compute_target_ratios(min_num as u32, max_num as u32);

                let target_aspect_ratio = Self::find_closest_aspect_ratio(
                    aspect_ratio,
                    target_ratios,
                    orig_w,
                    orig_h,
                    image_size,
                );

                let target_width = image_size * target_aspect_ratio.0 as usize;
                let target_height = image_size * target_aspect_ratio.1 as usize;

                (
                    (target_aspect_ratio.0 as f64, target_aspect_ratio.1 as f64),
                    target_width,
                    target_height,
                )
            } else {
                let target_width = (image_size as f64 * w_crop_num) as usize;
                let target_height = (image_size as f64 * h_crop_num) as usize;
                let target_aspect_ratio = (w_crop_num, h_crop_num);

                (target_aspect_ratio, target_width, target_height)
            };

        let ratio_width = target_width as f64 / orig_w as f64;
        let ratio_height = target_height as f64 / orig_h as f64;
        let (new_size, padding_width, padding_height) = if ratio_width < ratio_height {
            (
                (target_width, (orig_h as f64 * ratio_width) as usize),
                0_usize,
                target_height - (orig_h as f64 * ratio_width) as usize,
            )
        } else {
            (
                ((orig_w as f64 * ratio_height) as usize, target_height),
                target_width - (orig_w as f64 * ratio_height) as usize,
                0_usize,
            )
        };

        // Guard against extreme aspect ratios resulting in too-small dimensions
        if new_size.1.min(target_height) < 10 || new_size.0.min(target_width) < 10 {
            candle_core::bail!(
                "Image aspect ratio too extreme; resulting size below minimum threshold",
            );
        }

        let mut attention_mask = Tensor::ones(
            (
                (mask_size as f64 * target_aspect_ratio.1) as usize,
                (mask_size as f64 * target_aspect_ratio.0) as usize,
            ),
            DType::U32,
            device,
        )?;
        if padding_width >= 14 {
            attention_mask = attention_mask.slice_assign(
                &[
                    0..attention_mask.dim(0)?,
                    (attention_mask.dim(1)? - padding_width / 14)..attention_mask.dim(1)?,
                ],
                &Tensor::zeros(
                    (attention_mask.dim(0)?, padding_width / 14),
                    DType::U32,
                    device,
                )?,
            )?;
        }
        if padding_height >= 14 {
            attention_mask = attention_mask.slice_assign(
                &[
                    (attention_mask.dim(0)? - padding_height / 14)..attention_mask.dim(0)?,
                    0..attention_mask.dim(1)?,
                ],
                &Tensor::zeros(
                    (padding_height / 14, attention_mask.dim(1)?),
                    DType::U32,
                    device,
                )?,
            )?;
        }

        // Ensure the attention mask is non-empty
        let mask_sum: u32 = attention_mask.sum_all()?.to_scalar::<u32>()?;
        if mask_sum == 0 {
            candle_core::bail!("dynamic_preprocess produced an attention mask with zero sum",);
        }

        image = image.resize_exact(new_size.0 as u32, new_size.1 as u32, FilterType::Nearest);
        image = Self::pad_image(
            &image,
            0,
            padding_height as u32,
            0,
            padding_width as u32,
            Rgba([255u8, 255, 255, 255]),
        );

        Ok((image, attention_mask))
    }
}

impl ImagePreProcessor for Phi4MMInputsProcessor {
    #[allow(clippy::excessive_precision)]
    const DEFAULT_MEAN: [f64; 3] = [0.48145466, 0.4578275, 0.40821073];
    #[allow(clippy::excessive_precision)]
    const DEFAULT_STD: [f64; 3] = [0.26862954, 0.26130258, 0.27577711];

    fn preprocess(
        &self,
        mut images: Vec<DynamicImage>,
        videos: Vec<Vec<DynamicImage>>,
        config: &PreProcessorConfig,
        device: &Device,
        (_, _): (usize, usize),
    ) -> Result<PreprocessedImages> {
        // If no images, will not call this.
        assert!(!images.is_empty());
        assert!(videos.is_empty());

        // If >1 images, resize them all to the largest, potentially destroying aspect ratio
        let mut max_size = None;
        for image in images.iter() {
            if max_size.is_none() {
                max_size = Some((image.dimensions().0 as usize, image.dimensions().1 as usize))
            } else if max_size.is_some_and(|(x, _)| image.dimensions().0 as usize > x) {
                max_size = Some((image.dimensions().0 as usize, max_size.unwrap().1));
            } else if max_size.is_some_and(|(_, y)| image.dimensions().1 as usize > y) {
                max_size = Some((max_size.unwrap().0, image.dimensions().1 as usize));
            }
        }
        let (max_w, max_h) = max_size.unwrap();
        for image in images.iter_mut() {
            *image = image.resize_exact(max_w as u32, max_h as u32, FilterType::Nearest);
        }

        let mut image_sizes = Vec::new();
        let mut padded_images = Vec::new();
        let mut padded_masks = Vec::new();
        let mut num_img_tokens = Vec::new();
        for mut image in images {
            // Convert to rgb, default to true
            if config.do_convert_rgb.unwrap_or(true) {
                image = DynamicImage::ImageRgb8(image.to_rgb8());
            }

            let transforms = Transforms {
                input: &ToTensor,
                inner_transforms: &[&Normalize {
                    mean: vec![0.5, 0.5, 0.5],
                    std: vec![0.5, 0.5, 0.5],
                }],
            };
            // Dynamic HD
            let dyhd_base_resolution = DYHD_BASE_RESOLUTION;
            let base_resolution = dyhd_base_resolution;
            // Mask resolution is the base tile measured in 14 px patches (448 / 14 = 32)
            let mask_resolution = base_resolution / 14;
            let min_num = 1;

            let (elem, attention_mask) = self.dynamic_preprocess(
                image,
                min_num,
                config.dynamic_hd.unwrap(),
                base_resolution,
                mask_resolution,
                device,
            )?;

            let hd_image = elem.apply(transforms, device)?;
            let (img_h, img_w) = (hd_image.dim(1)?, hd_image.dim(2)?);
            let (mask_h, mask_w) = (attention_mask.dim(0)?, attention_mask.dim(1)?);

            // Resize with bicubic interpolation
            let global_image = hd_image
                .unsqueeze(0)?
                .interpolate2d(base_resolution, base_resolution)?;
            let global_attention_mask =
                Tensor::ones((1, mask_resolution, mask_resolution), DType::U32, device)?;

            let hd_image_reshape = hd_image
                .reshape((
                    1,
                    3,
                    (img_h as f32 / base_resolution as f32) as usize,
                    base_resolution,
                    (img_w as f32 / base_resolution as f32) as usize,
                    base_resolution,
                ))?
                .permute((0, 2, 4, 1, 3, 5))?
                .reshape(((), 3, base_resolution, base_resolution))?;

            let attention_mask_reshape = attention_mask
                .reshape((
                    1,
                    (mask_h as f32 / mask_resolution as f32) as usize,
                    mask_resolution,
                    (mask_w as f32 / mask_resolution as f32) as usize,
                    mask_resolution,
                ))?
                .permute((0, 1, 3, 2, 4))?
                .reshape(((), mask_resolution, mask_resolution))?;

            let downsample_attention_mask = {
                let h_indices =
                    Tensor::arange_step(0, attention_mask_reshape.dim(1)? as u32, 2, device)?;
                let w_indices =
                    Tensor::arange_step(0, attention_mask_reshape.dim(2)? as u32, 2, device)?;
                let selected = attention_mask_reshape
                    .index_select(&h_indices, 1)?
                    .index_select(&w_indices, 2)?;

                let mask = selected
                    .reshape((
                        1,
                        mask_h / mask_resolution,
                        mask_w / mask_resolution,
                        mask_resolution / 2 + mask_resolution % 2,
                        mask_resolution / 2 + mask_resolution % 2,
                    ))?
                    .permute((0, 1, 3, 2, 4))?;
                mask.reshape((mask.dim(1)? * mask.dim(2)?, mask.dim(3)? * mask.dim(4)?))?
            };

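            // Per-image token budget as computed by this processor: 256 positions for the
            // 16x16 downsampled global view, one separator, one position per unmasked patch in
            // the downsampled HD mask, per-row break positions (the mask's first-column sum),
            // and 16 break positions for the global view.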
            let img_tokens = 256
                + 1
                + downsample_attention_mask.sum_all()?.to_scalar::<u32>()? as usize
                + downsample_attention_mask
                    .i((.., 0))?
                    .sum_all()?
                    .to_scalar::<u32>()? as usize
                + 16;

            let hd_image_reshape = Tensor::cat(&[global_image, hd_image_reshape], 0)?;
            let hd_mask_reshape = Tensor::cat(&[global_attention_mask, attention_mask_reshape], 0)?;

            image_sizes.push((img_h as u32, img_w as u32));
            padded_images.push(hd_image_reshape);
            padded_masks.push(hd_mask_reshape);
            num_img_tokens.push(img_tokens);
        }
        Ok(PreprocessedImages {
            pixel_values: Tensor::stack(&padded_images, 0)?,
            pixel_attention_mask: Some(Tensor::stack(&padded_masks, 0)?),
            image_sizes: None,
            num_img_tokens: Some(num_img_tokens),
            aspect_ratio_ids: None,
            aspect_ratio_mask: None,
            num_tiles: None,
            image_grid_thw: None,
            video_grid_thw: None,
            rows: None,
            cols: None,
            pixel_values_list: None,
            tgt_sizes: None,
            image_sizes_all: Some(image_sizes),
            num_crops: None,
        })
    }
}