mistralrs_core/vision_models/phi4/inputs_processor.rs

1#![allow(clippy::cast_precision_loss, clippy::cast_possible_truncation)]
2
3use std::{any::Any, collections::HashSet, num::NonZeroUsize, sync::Arc};
4
5use candle_core::{DType, Device, IndexOp, Result, Tensor};
6use image::{imageops::FilterType, DynamicImage, GenericImage, GenericImageView, Rgba};
7use mistralrs_vision::{ApplyTransforms, Normalize, ToTensor, Transforms};
8use regex::Regex;
9use tokenizers::Tokenizer;
10use tracing::warn;
11
12use apodize::hanning_iter;
13use rubato::{
14    Resampler, SincFixedIn, SincInterpolationParameters, SincInterpolationType, WindowFunction,
15};
16use rustfft::{num_complex::Complex32, FftPlanner};
17
18use crate::{
19    device_map::DeviceMapper,
20    pipeline::{
21        text_models_inputs_processor::{
22            self, get_completion_input, get_prompt_input, PagedAttentionMeta,
23        },
24        InputProcessorOutput, InputsProcessor, InputsProcessorType, MessagesAction, Processor,
25        ProcessorCreator,
26    },
27    sequence::Sequence,
28};
29
30use crate::vision_models::{
31    image_processor::{ImagePreProcessor, PreprocessedImages},
32    phi4::Phi4MMVisionSpecificArgs,
33    preprocessor_config::PreProcessorConfig,
34    processor_config::ProcessorConfig,
35    ModelInputs,
36};
37
38use super::audio_embedding::AUDIO_SPECIAL_TOKEN_ID;
39use super::image_embedding::IMAGE_SPECIAL_TOKEN_ID;
40
41const COMPATIBLE_IMAGE_SPECIAL_TOKEN_PATTERN: &str = r"<\|image_\d+\|>";
42const COMPATIBLE_AUDIO_SPECIAL_TOKEN_PATTERN: &str = r"<\|audio_\d+\|>";
43const IMAGE_SPECIAL_TOKEN: &str = "<|endoftext10|>";
44const AUDIO_SPECIAL_TOKEN: &str = "<|endoftext11|>";
45pub(crate) const DYHD_BASE_RESOLUTION: usize = 448;
46
47const AUDIO_FEATURE_SIZE: usize = 80; // mel bins
48
49type AudioProcessingResult = Result<(Option<Tensor>, Option<Vec<usize>>, Option<Tensor>)>;
50
51// Input processor
52pub struct Phi4MMInputsProcessor {
53    audio_compression_rate: usize,
54    audio_downsample_rate: usize,
55    audio_feat_stride: usize,
56    eightk_method: String, // "fillzero" or "resample"
57}
58
59// Processor
60pub struct Phi4MMProcessor {
61    inputs_processor: Arc<Phi4MMInputsProcessor>,
62}
63
64impl ProcessorCreator for Phi4MMProcessor {
65    fn new_processor(
66        _: Option<ProcessorConfig>,
67        pre_processor_config: PreProcessorConfig,
68    ) -> Arc<dyn Processor + Send + Sync> {
69        Arc::new(Self {
70            inputs_processor: Arc::new(Phi4MMInputsProcessor {
71                audio_compression_rate: pre_processor_config
72                    .audio_compression_rate
73                    .expect("audio_compression_rate"),
74                audio_downsample_rate: pre_processor_config
75                    .audio_downsample_rate
76                    .expect("audio_downsample_rate"),
77                audio_feat_stride: pre_processor_config
78                    .audio_feat_stride
79                    .expect("audio_feat_stride"),
80                eightk_method: "fillzero".to_string(), // Default to fillzero
81            }),
82        })
83    }
84}
85
86impl Processor for Phi4MMProcessor {
87    fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
88        self.inputs_processor.clone()
89    }
90    fn get_special_tokens(&self) -> &[&'static str] {
91        &[]
92    }
93    fn template_action(&self) -> MessagesAction {
94        MessagesAction::FlattenOnlyText
95    }
96}
97
98impl InputsProcessor for Phi4MMInputsProcessor {
99    fn get_type(&self) -> InputsProcessorType {
100        InputsProcessorType::Vision
101    }
102    fn process_inputs(
103        &self,
104        tokenizer: Option<Arc<Tokenizer>>,
105        input_seqs: &mut [&mut Sequence],
106        is_prompt: bool,
107        is_xlora: bool,
108        device: &Device,
109        no_kv_cache: bool,
110        last_n_context_len: Option<(usize, usize)>,
111        return_raw_logits: bool,
112        other_config: Option<Arc<dyn Any>>,
113        mut paged_attn_metadata: Option<PagedAttentionMeta>,
114        prompt_chunksize: Option<NonZeroUsize>,
115        mapper: Option<&dyn DeviceMapper>,
116    ) -> Box<dyn Iterator<Item = anyhow::Result<InputProcessorOutput>>> {
117        if is_xlora {
118            return Box::new(std::iter::once(Err(anyhow::Error::msg(
119                "Cannot make inputs for X-LoRA vision model.",
120            ))));
121        }
122        if no_kv_cache {
123            return Box::new(std::iter::once(Err(anyhow::Error::msg(
124                "Vision model must have kv cache.",
125            ))));
126        }
127        // TODO(EricLBuehler): support this? Would require some handling of image tokens.
128        if prompt_chunksize.is_some() {
            warn!("`prompt_chunksize` is set. Phi 4 multimodal does not support prompt batching.");
130        }
131        let Some(tokenizer) = tokenizer else {
132            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Phi4MMInputsProcessor requires a specified tokenizer.",
134            ))));
135        };
136
137        let config = other_config
138            .clone()
139            .expect("Need a PreProcessorConfig config.");
140        let config: &PreProcessorConfig = config.downcast_ref().expect("Downcast failed.");
141
142        let has_audios = input_seqs.iter().all(|seq| seq.has_audios());
143        let has_images = input_seqs.iter().all(|seq| seq.has_images());
144
145        let (pixel_values, pixel_attention_mask, image_sizes, num_img_tokens) = if has_images {
146            let mut pixel_values_accum = Vec::new();
147            let mut pixel_attention_masks_accum = Vec::new();
148            let mut image_sizes_accum = Vec::new();
149            let mut num_img_tokens_accum = Vec::new();
150            for seq in input_seqs.iter_mut() {
151                let imgs = seq
152                    .take_images()
153                    .expect("Need to have images by this point.");
154                let PreprocessedImages {
155                    pixel_values,
156                    pixel_attention_mask,
157                    image_sizes: _,
158                    num_img_tokens,
159                    aspect_ratio_ids: _,
160                    aspect_ratio_mask: _,
161                    num_tiles: _,
162                    image_grid_thw: _,
163                    video_grid_thw: _,
164                    rows: _,
165                    cols: _,
166                    pixel_values_list: _,
167                    tgt_sizes: _,
168                    image_sizes_all,
169                    num_crops: _,
170                } = self
171                    .preprocess(
172                        imgs,
173                        vec![],
174                        config,
175                        device,
176                        (usize::MAX, usize::MAX), // Don't use it here...
177                    )
178                    .expect("Preprocessor failed");
179                let image_sizes = image_sizes_all.unwrap();
180                let pixel_attention_mask = pixel_attention_mask.unwrap();
181                pixel_values_accum.push(pixel_values);
182                pixel_attention_masks_accum.push(pixel_attention_mask);
                // Using extend on purpose: flatten the per-image sizes across all sequences in the batch
184                image_sizes_accum.extend(image_sizes);
185                num_img_tokens_accum.push(num_img_tokens.unwrap());
186            }
187            (
188                Some(Tensor::cat(&pixel_values_accum, 0).unwrap()),
189                Some(Tensor::cat(&pixel_attention_masks_accum, 0).unwrap()),
190                Some(image_sizes_accum),
191                Some(num_img_tokens_accum),
192            )
193        } else if has_audios {
194            (None, None, None, Some(vec![vec![]; input_seqs.len()]))
195        } else {
196            return Box::new(
197                text_models_inputs_processor::TextInputsProcessor
198                    .process_inputs(
199                        Some(tokenizer),
200                        input_seqs,
201                        is_prompt,
202                        is_xlora,
203                        device,
204                        no_kv_cache,
205                        last_n_context_len,
206                        return_raw_logits,
207                        other_config,
208                        paged_attn_metadata,
209                        None, // TODO
210                        mapper,
211                    )
212                    .map(|metadata| {
213                        let InputProcessorOutput {
214                            inputs,
215                            seq_indices,
216                        } = metadata?;
217
218                        let text_models_inputs_processor::ModelInputs {
219                            input_ids,
220                            input_ids_full: _,
221                            seqlen_offsets,
222                            seqlen_offsets_full: _,
223                            context_lens,
224                            position_ids,
225                            paged_attn_meta,
226                            flash_meta,
227                            flash_meta_full: _,
228                        } = *inputs
229                            .downcast::<text_models_inputs_processor::ModelInputs>()
230                            .expect("Downcast failed.");
231
232                        let inputs: Box<dyn Any> = Box::new(ModelInputs {
233                            input_ids,
234                            seqlen_offsets,
235                            context_lens,
236                            position_ids,
237                            pixel_values: None,
238                            model_specific_args: Box::new(Phi4MMVisionSpecificArgs {
239                                input_image_embeds: None,
240                                image_attention_mask: None,
241                                image_sizes: None,
242                                input_audio_embeds: None,
243                                audio_embed_sizes: None,
244                                audio_attention_mask: None,
245                            }),
246                            paged_attn_meta,
247                            flash_meta,
248                        });
249                        Ok(InputProcessorOutput {
250                            inputs,
251                            seq_indices,
252                        })
253                    }),
254            );
255        };
256
257        let detokenized = tokenizer
258            .decode_batch(
259                &input_seqs
260                    .iter()
261                    .map(|seq| seq.get_toks())
262                    .collect::<Vec<_>>(),
263                false,
264            )
265            .expect("Decode failed");
266
267        let img_token_pattern = Regex::new(COMPATIBLE_IMAGE_SPECIAL_TOKEN_PATTERN).unwrap();
268        let audio_token_pattern = Regex::new(COMPATIBLE_AUDIO_SPECIAL_TOKEN_PATTERN).unwrap();
269
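        // Rewrite the user-facing, numbered placeholders into the single special
        // tokens that exist in the Phi-4 multimodal vocabulary, e.g. (hypothetical
        // prompt) "What is in <|image_1|>? <|audio_1|>" becomes
        // "What is in <|endoftext10|>? <|endoftext11|>". The rewritten prompt is
        // re-tokenized and written back exactly once; `has_changed_prompt` guards
        // against repeating this on later calls.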
270        for (mut detokenized, seq) in detokenized.into_iter().zip(input_seqs.iter_mut()) {
271            detokenized = img_token_pattern
272                .replace_all(&detokenized, IMAGE_SPECIAL_TOKEN)
273                .to_string();
274            detokenized = audio_token_pattern
275                .replace_all(&detokenized, AUDIO_SPECIAL_TOKEN)
276                .to_string();
277
278            let has_changed_prompt = seq.multimodal.has_changed_prompt;
279            if !has_changed_prompt {
280                seq.set_toks_and_reallocate(
281                    tokenizer
282                        .encode_fast(detokenized.clone(), false)
283                        .expect("Encode failed")
284                        .get_ids()
285                        .to_vec(),
286                    paged_attn_metadata.as_mut(),
287                );
288
289                seq.set_initial_prompt(detokenized);
290            }
291        }
292
293        let (input_audio_embeds, audio_embed_sizes, audio_attention_mask) =
294            match self.process_audio_for_sequences(input_seqs, device) {
295                Ok(result) => result,
296                Err(e) => return Box::new(std::iter::once(Err(anyhow::Error::new(e)))),
297            };
298
299        let mut toks = Vec::new();
300
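        // Each image/audio placeholder currently occupies a single position. The
        // loop below repeats it `token_count` times so the sequence length matches
        // the number of image/audio embeddings that will later be scattered into
        // those positions. E.g. (illustrative ids) [5, IMG, 9] with
        // num_img_tokens = [3] becomes [5, IMG, IMG, IMG, 9].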
301        for (seq, num_img_tokens) in input_seqs.iter_mut().zip(num_img_tokens.unwrap()) {
302            let has_changed_prompt = seq.multimodal.has_changed_prompt;
303
304            let mut i = 0;
305            let mut image_token_count_iter = num_img_tokens.iter();
306            let audio_sizes_tmp = audio_embed_sizes.clone().unwrap_or(vec![]);
307            let mut audio_embed_sizes = audio_sizes_tmp.iter();
308            while i < seq.get_toks().len() {
309                let token_id = seq.get_toks()[i];
310                let token_count = if token_id == IMAGE_SPECIAL_TOKEN_ID as u32 {
311                    image_token_count_iter.next().unwrap()
312                } else if token_id == AUDIO_SPECIAL_TOKEN_ID as u32 {
313                    audio_embed_sizes.next().unwrap()
314                } else {
315                    i += 1;
316                    continue;
317                };
318
319                let mut new_ids = seq.get_toks()[..i].to_vec();
320                new_ids.extend(vec![token_id; *token_count]);
321                new_ids.extend(seq.get_toks()[i + 1..].to_vec());
322                if !has_changed_prompt {
323                    seq.set_toks_and_reallocate(new_ids, paged_attn_metadata.as_mut());
324                }
325                i += token_count;
326            }
327            if !has_changed_prompt {
328                seq.multimodal.has_changed_prompt = true;
329            }
330            toks.push(seq.get_toks().to_vec());
331        }
332
333        let iter = if is_prompt {
334            get_prompt_input(
335                toks.iter().map(Vec::as_slice).collect(),
336                input_seqs,
337                device,
338                last_n_context_len,
339                return_raw_logits,
340                paged_attn_metadata.as_mut(),
341                None, // TODO: evaluate if it is possible to batch this
342                mapper,
343            )
344        } else {
345            get_completion_input(
346                toks.iter().map(Vec::as_slice).collect(),
347                input_seqs,
348                device,
349                no_kv_cache,
350                last_n_context_len,
351                return_raw_logits,
352                paged_attn_metadata.as_mut(),
353                None, // TODO: evaluate if it is possible to batch this
354                mapper,
355            )
356        };
357
358        Box::new(iter.into_iter().map(move |metadata| {
359            let pixel_values = pixel_values.clone();
360            let pixel_attention_mask = pixel_attention_mask.clone();
361            let text_models_inputs_processor::InnerInputProcessorOutput {
362                inputs:
363                    text_models_inputs_processor::InputMetadata {
364                        input,
365                        positions,
366                        context_lens,
367                        position_ids,
368                        paged_attn_meta,
369                        flash_meta,
370                    },
371                seq_indices,
372            } = metadata?;
373            let inputs: Box<dyn Any> = Box::new(ModelInputs {
374                input_ids: input,
375                seqlen_offsets: positions,
376                context_lens,
377                position_ids,
378                pixel_values: pixel_values.clone(),
379                model_specific_args: Box::new(Phi4MMVisionSpecificArgs {
380                    input_image_embeds: pixel_values,
381                    image_attention_mask: pixel_attention_mask,
382                    image_sizes: image_sizes.clone(),
383                    input_audio_embeds: input_audio_embeds.clone(),
384                    audio_embed_sizes: audio_embed_sizes.clone(),
385                    audio_attention_mask: audio_attention_mask.clone(),
386                }),
387                paged_attn_meta,
388                flash_meta,
389            });
390            Ok(InputProcessorOutput {
391                inputs,
392                seq_indices,
393            })
394        }))
395    }
396}
397
398impl Phi4MMInputsProcessor {
399    fn extract_audio_features(
400        &self,
401        audio_data: &[f32],
402        sample_rate: u32,
403    ) -> Result<Vec<Vec<f32>>> {
404        // Resample audio to supported rates using rubato
405        let (resampled_audio, final_sample_rate) =
406            self.resample_audio_with_rubato(audio_data, sample_rate)?;
407
408        // Extract mel spectrogram using rustfft and custom mel filterbank
409        let mel_features =
410            self.extract_mel_spectrogram_rustfft(&resampled_audio, final_sample_rate)?;
411
412        Ok(mel_features)
413    }
414
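    /// Bring the signal to a sample rate the audio front end supports, using
    /// rubato's sinc resampler.
    ///
    /// Rates above 16 kHz are resampled down to 16 kHz and rates strictly between
    /// 8 kHz and 16 kHz down to 8 kHz; 16 kHz input is passed through unchanged,
    /// 8 kHz input is passed through unless `eightk_method` is "resample" (in
    /// which case it is upsampled to 16 kHz), and anything below 8 kHz is rejected
    /// with an error.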
415    fn resample_audio_with_rubato(&self, wav: &[f32], fs: u32) -> Result<(Vec<f32>, u32)> {
416        let target_fs = if fs > 16000 {
417            16000
418        } else if fs > 8000 && fs < 16000 {
419            8000
420        } else if fs < 8000 {
421            return Err(candle_core::Error::Msg(format!(
422                "Unsupported sample rate: {fs}"
423            )));
        } else if fs == 8000 && self.eightk_method == "resample" {
            // Fall through so the dedicated 8 kHz -> 16 kHz upsampling branch below
            // runs; returning early here would make that branch unreachable.
            16000
        } else {
            return Ok((wav.to_vec(), fs)); // Already at 8 kHz or 16 kHz; no resampling needed
        };
427
428        if fs == target_fs {
429            return Ok((wav.to_vec(), fs));
430        }
431
432        // Handle 8kHz upsampling case
433        if fs == 8000 && self.eightk_method == "resample" {
434            // Upsample to 16kHz using rubato
435            let params = SincInterpolationParameters {
436                sinc_len: 256,
437                f_cutoff: 0.95,
438                interpolation: SincInterpolationType::Linear,
439                oversampling_factor: 256,
440                window: WindowFunction::BlackmanHarris2,
441            };
442
443            let mut resampler = SincFixedIn::<f32>::new(
444                2.0, // resample ratio (16000/8000)
445                2.0,
446                params,
447                wav.len(),
448                1, // mono
449            )
450            .map_err(|e| candle_core::Error::Msg(format!("Resampler creation failed: {e}")))?;
451
452            let input = vec![wav.to_vec()];
453            let output = resampler
454                .process(&input, None)
455                .map_err(|e| candle_core::Error::Msg(format!("Resampling failed: {e}")))?;
456
457            return Ok((output[0].clone(), 16000));
458        }
459
460        // Regular downsampling
461        let resample_ratio = target_fs as f64 / fs as f64;
462
463        let params = SincInterpolationParameters {
464            sinc_len: 256,
465            f_cutoff: 0.95,
466            interpolation: SincInterpolationType::Linear,
467            oversampling_factor: 256,
468            window: WindowFunction::BlackmanHarris2,
469        };
470
471        let mut resampler = SincFixedIn::<f32>::new(
472            resample_ratio,
473            2.0,
474            params,
475            wav.len(),
476            1, // mono
477        )
478        .map_err(|e| candle_core::Error::Msg(format!("Resampler creation failed: {e}")))?;
479
480        let input = vec![wav.to_vec()];
481        let output = resampler
482            .process(&input, None)
483            .map_err(|e| candle_core::Error::Msg(format!("Resampling failed: {e}")))?;
484
485        Ok((output[0].clone(), target_fs))
486    }
487
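    /// Compute an 80-bin log-mel spectrogram with a short-time FFT.
    ///
    /// The STFT parameters follow the sample rate: 16 kHz uses `n_fft = 512`,
    /// `win_length = 400`, `hop_length = 160`; 8 kHz uses 256/200/80. As a sanity
    /// check on the framing below, one second of 16 kHz audio gives
    /// `(16000 - 400) / 160 + 1 = 98` frames, each reduced to an 80-dimensional
    /// log-mel vector. For 8 kHz input with the "fillzero" method, every frame is
    /// afterwards extended with zeros to line up with the 16 kHz feature layout.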
488    fn extract_mel_spectrogram_rustfft(&self, wav: &[f32], fs: u32) -> Result<Vec<Vec<f32>>> {
489        // Set parameters based on sample rate
490        let (n_fft, win_length, hop_length) = if fs == 8000 {
491            (256, 200, 80)
492        } else if fs == 16000 {
493            (512, 400, 160)
494        } else {
495            return Err(candle_core::Error::Msg(format!(
496                "Unsupported sample rate: {fs}"
497            )));
498        };
499
500        // Apply preemphasis first
501        let preemphasized = self.apply_preemphasis(wav, 0.97);
502
503        // Create FFT planner
504        let mut planner = FftPlanner::<f32>::new();
505        let fft = planner.plan_fft_forward(n_fft);
506
507        // Create Hanning window
508        let window: Vec<f64> = hanning_iter(win_length).collect();
509
510        // Create mel filterbank
511        let mel_filters = self.create_mel_filterbank(AUDIO_FEATURE_SIZE, n_fft, fs as f32)?;
512
513        // Extract frames and apply STFT
514        let n_batch = (preemphasized.len() - win_length) / hop_length + 1;
515        let mut mel_features = Vec::new();
516
517        for i in 0..n_batch {
518            let start = i * hop_length;
519            let end = start + win_length;
520            if end > preemphasized.len() {
521                break;
522            }
523
524            // Apply window and convert to complex
525            let mut windowed: Vec<Complex32> = preemphasized[start..end]
526                .iter()
527                .zip(window.iter())
528                .map(|(s, w)| Complex32::new(s * *w as f32, 0.0))
529                .collect();
530
531            // Pad to n_fft length
532            windowed.resize(n_fft, Complex32::new(0.0, 0.0));
533
534            // Apply FFT
535            fft.process(&mut windowed);
536
537            // Take power spectrum of positive frequencies
538            let power_spectrum: Vec<f32> = windowed[0..n_fft / 2 + 1]
539                .iter()
540                .map(|c| c.norm_sqr())
541                .collect();
542
543            // Apply mel filterbank
544            let mut mel_frame = vec![0.0; AUDIO_FEATURE_SIZE];
545            for (mel_idx, filter) in mel_filters.iter().enumerate() {
546                let mut sum = 0.0;
547                for (freq_idx, &coeff) in filter.iter().enumerate() {
548                    if freq_idx < power_spectrum.len() {
549                        sum += power_spectrum[freq_idx] * coeff;
550                    }
551                }
552                mel_frame[mel_idx] = (sum.max(1.0)).ln(); // Apply log with clipping
553            }
554
555            mel_features.push(mel_frame);
556        }
557
558        // Handle 8kHz case with fillzero method
559        if fs == 8000 && self.eightk_method == "fillzero" {
560            for frame in &mut mel_features {
561                // Extend each frame with zeros to match 16kHz structure
562                let original_len = frame.len();
563                frame.extend(vec![0.0; original_len]);
564            }
565        }
566
567        Ok(mel_features)
568    }
569
570    fn apply_preemphasis(&self, wav: &[f32], preemphasis: f32) -> Vec<f32> {
571        if wav.is_empty() {
572            return vec![];
573        }
574
575        let mut preemphasized = Vec::with_capacity(wav.len());
576
577        // First sample: y[0] = x[0] * 32768
578        preemphasized.push(wav[0] * 32768.0);
579
580        // Remaining samples: y[n] = (x[n] - preemphasis * x[n-1]) * 32768
581        for i in 1..wav.len() {
582            let filtered = (wav[i] - preemphasis * wav[i - 1]) * 32768.0;
583            preemphasized.push(filtered);
584        }
585
586        preemphasized
587    }
588
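    /// Build `n_mels` triangular filters over the `n_fft / 2 + 1` positive-frequency bins.
    ///
    /// Conversion uses the natural-log mel scale, `mel = 1127 * ln(1 + f / 700)`.
    /// Filter `m` rises from `bin_centers[m]` to a peak of 1.0 at `bin_centers[m + 1]`
    /// and falls back to zero at `bin_centers[m + 2]`; the filters are not
    /// area-normalized.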
589    fn create_mel_filterbank(
590        &self,
591        n_mels: usize,
592        n_fft: usize,
593        sample_rate: f32,
594    ) -> Result<Vec<Vec<f32>>> {
595        let bank_width = n_fft / 2 + 1;
596        let fmax = sample_rate / 2.0;
597        let fmin = 0.0;
598
599        // Mel scale conversion functions
600        let hz_to_mel = |f: f32| 1127.0 * (1.0 + f / 700.0).ln();
601        let mel_to_hz = |mel: f32| 700.0 * (mel / 1127.0).exp() - 700.0;
602
603        let mel_low = hz_to_mel(fmin);
604        let mel_high = hz_to_mel(fmax);
605
606        // Create mel centers
607        let mel_centers: Vec<f32> = (0..=n_mels + 1)
608            .map(|i| mel_low + (mel_high - mel_low) * i as f32 / (n_mels + 1) as f32)
609            .collect();
610
611        let hz_centers: Vec<f32> = mel_centers.iter().map(|&mel| mel_to_hz(mel)).collect();
612
613        // Convert to bin numbers
614        let bin_centers: Vec<usize> = hz_centers
615            .iter()
616            .map(|&f| ((f * n_fft as f32 / sample_rate) + 0.5) as usize)
617            .collect();
618
619        // Create triangular filters
620        let mut filters = Vec::new();
621        for m in 0..n_mels {
622            let mut filter = vec![0.0; bank_width];
623
624            let left_bin = bin_centers[m];
625            let center_bin = bin_centers[m + 1];
626            let right_bin = bin_centers[m + 2];
627
628            // Left slope
629            for (bin, filter) in filter
630                .iter_mut()
631                .enumerate()
632                .take(center_bin)
633                .skip(left_bin)
634            {
635                if bin < bank_width {
636                    *filter = (bin - left_bin) as f32 / (center_bin - left_bin) as f32;
637                }
638            }
639
640            // Right slope
641            for (bin, filter) in filter
642                .iter_mut()
643                .enumerate()
644                .take(right_bin)
645                .skip(center_bin)
646            {
647                if bin < bank_width {
648                    *filter = (right_bin - bin) as f32 / (right_bin - center_bin) as f32;
649                }
650            }
651
652            filters.push(filter);
653        }
654
655        Ok(filters)
656    }
657
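    /// Number of audio embedding positions produced for `audio_frames` mel frames,
    /// i.e. `ceil(ceil(audio_frames / compression_rate) / downsample_rate)`.
    ///
    /// For example, 1000 frames with a compression rate of 8 and a downsample rate
    /// of 1 (illustrative values, not the model's configured ones) give
    /// `ceil(1000 / 8) = 125`, then `ceil(125 / 1) = 125` positions.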
658    fn compute_audio_embed_size(
659        &self,
660        audio_frames: usize,
661        compression_rate: usize,
662        downsample_rate: usize,
663    ) -> usize {
664        // First compression
665        let integer = audio_frames / compression_rate;
666        let remainder = audio_frames % compression_rate;
667        let result = if remainder == 0 { integer } else { integer + 1 };
668
669        // Second compression (qformer)
670        let integer = result / downsample_rate;
671        let remainder = result % downsample_rate;
672        if remainder == 0 {
673            integer
674        } else {
675            integer + 1
676        }
677    }
678
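    /// Build a `(batch, max_frames)` mask with 1.0 marking valid audio frames.
    ///
    /// For `audio_frames_list = [3, 5]` the result (shown as 0/1 before the f32
    /// cast) is `[[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]]`.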
679    fn create_audio_attention_mask(
680        &self,
681        audio_frames_list: &[usize],
682        device: &Device,
683    ) -> Result<Tensor> {
684        let max_frames = *audio_frames_list.iter().max().unwrap_or(&0);
685        let batch_size = audio_frames_list.len();
686
687        let mut mask_data = vec![0u8; batch_size * max_frames];
688        for (batch_idx, &frames) in audio_frames_list.iter().enumerate() {
689            for frame_idx in 0..frames.min(max_frames) {
690                mask_data[batch_idx * max_frames + frame_idx] = 1;
691            }
692        }
693
694        Tensor::from_slice(&mask_data, (batch_size, max_frames), device)?.to_dtype(DType::F32)
695    }
696
697    fn process_audio_for_sequences(
698        &self,
699        input_seqs: &mut [&mut Sequence],
700        device: &Device,
701    ) -> AudioProcessingResult {
702        // Check if any sequence has audio tokens
703        let has_audio_tokens = input_seqs
704            .iter()
705            .any(|seq| seq.get_toks().contains(&(AUDIO_SPECIAL_TOKEN_ID as u32)));
706
707        if !has_audio_tokens {
708            return Ok((None, None, None));
709        }
710
711        let mut audio_features_list = Vec::new();
712        let mut audio_embed_sizes_list = Vec::new();
713        let mut audio_frames_list = Vec::new();
714
715        // Process audio for each sequence that needs it
716        for seq in input_seqs.iter_mut() {
717            let has_audio = seq.get_toks().contains(&(AUDIO_SPECIAL_TOKEN_ID as u32));
718
719            if has_audio {
720                if let Some(audios) = seq.take_audios() {
721                    for audio in audios.into_iter() {
722                        // Convert multi-channel audio to mono by averaging channels
723                        let samples = audio.to_mono();
724
725                        // Extract features
726                        let features = self.extract_audio_features(&samples, audio.sample_rate)?;
727                        let audio_frames = features.len() * self.audio_feat_stride;
728
729                        let embed_size = self.compute_audio_embed_size(
730                            audio_frames,
731                            self.audio_compression_rate,
732                            self.audio_downsample_rate,
733                        );
734
735                        // Convert to tensor
736                        let features_len = features.len();
737                        let features_flat: Vec<f32> = features.into_iter().flatten().collect();
738                        let features_tensor = Tensor::from_slice(
739                            &features_flat,
740                            (features_len, AUDIO_FEATURE_SIZE),
741                            device,
742                        )?;
743
744                        audio_features_list.push(features_tensor);
745                        audio_embed_sizes_list.push(embed_size);
746                        audio_frames_list.push(audio_frames);
747                    }
748                } else {
749                    candle_core::bail!("No audios in `process_audio_for_sequences`");
750                };
751            }
752        }
753
754        if audio_features_list.is_empty() {
755            return Ok((None, None, None));
756        }
757
758        // Pad sequences to same length
759        let max_len = audio_features_list
760            .iter()
761            .map(|t| t.dim(0).unwrap_or(0))
762            .max()
763            .unwrap_or(0);
764
765        let mut padded_features = Vec::new();
766        for features in audio_features_list {
767            let seq_len = features.dim(0)?;
768            if seq_len < max_len {
769                let padding =
770                    Tensor::zeros((max_len - seq_len, AUDIO_FEATURE_SIZE), DType::F32, device)?;
771                let padded = Tensor::cat(&[features, padding], 0)?;
772                padded_features.push(padded);
773            } else {
774                padded_features.push(features);
775            }
776        }
777
778        // Stack into batch tensor
779        let input_audio_embeds = Tensor::stack(&padded_features, 0)?;
780
781        // Create attention mask if multiple sequences
782        let audio_attention_mask = if audio_frames_list.len() > 1 {
783            Some(self.create_audio_attention_mask(&audio_frames_list, device)?)
784        } else {
785            None
786        };
787
788        Ok((
789            Some(input_audio_embeds),
790            Some(audio_embed_sizes_list),
791            audio_attention_mask,
792        ))
793    }
794}
795
796impl Phi4MMInputsProcessor {
797    fn pad_image(
798        image: &DynamicImage,
799        top: u32,
800        bottom: u32,
801        left: u32,
802        right: u32,
803        pad_color: Rgba<u8>,
804    ) -> DynamicImage {
805        // Calculate the new dimensions
806        let new_width = image.width() + left + right;
807        let new_height = image.height() + top + bottom;
808
809        // Create a new image with the new dimensions and fill it with the pad color
810        let mut new_image = DynamicImage::new_rgb8(new_width, new_height);
811        for x in 0..new_width {
812            for y in 0..new_height {
813                new_image.put_pixel(x, y, pad_color);
814            }
815        }
816
        // Paste the original image into the top-left corner; the padding therefore ends up on the right/bottom
818        new_image
819            .copy_from(image, 0, 0)
820            .expect("Failed to copy image");
821
822        new_image
823    }
824
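    /// Enumerate all (width, height) crop-grid sizes whose total crop count lies
    /// in `[min_num, max_num]`, sorted by that count.
    ///
    /// For example, `compute_target_ratios(1, 4)` yields
    /// (1, 1), (1, 2), (2, 1), (1, 3), (3, 1), (1, 4), (2, 2), (4, 1); grids with
    /// equal products appear in unspecified relative order, since they are
    /// collected from a `HashSet` before the stable sort by `i * j`.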
825    fn compute_target_ratios(min_num: u32, max_num: u32) -> Vec<(u32, u32)> {
826        let mut ratios: HashSet<(u32, u32)> = HashSet::new();
827        for n in min_num..=max_num {
828            for i in 1..=n {
829                for j in 1..=n {
830                    if i * j >= min_num && i * j <= max_num {
831                        ratios.insert((i, j));
832                    }
833                }
834            }
835        }
836        let mut sorted_ratios: Vec<(u32, u32)> = ratios.into_iter().collect();
837        sorted_ratios.sort_by_key(|&(i, j)| i * j);
838        sorted_ratios
839    }
840
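    /// Pick the grid from `target_ratios` whose aspect ratio is closest to the
    /// image's. On an exact tie, a later candidate replaces the current best only
    /// if the original image area exceeds half of that candidate's tiled area,
    /// i.e. `0.5 * image_size^2 * w * h`.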
841    fn find_closest_aspect_ratio(
842        aspect_ratio: f64,
843        target_ratios: Vec<(u32, u32)>,
844        width: u32,
845        height: u32,
846        image_size: usize,
847    ) -> (u32, u32) {
848        let mut best_ratio_diff = f64::INFINITY;
849        let mut best_ratio = (1, 1);
850        let area = width * height;
851        for ratio in target_ratios {
852            let target_aspect_ratio = ratio.0 as f64 / ratio.1 as f64;
853            let ratio_diff = (aspect_ratio - target_aspect_ratio).abs();
854            if ratio_diff < best_ratio_diff {
855                best_ratio_diff = ratio_diff;
856                best_ratio = ratio;
857            } else if ratio_diff == best_ratio_diff
858                && area as f64
859                    > 0.5 * image_size as f64 * image_size as f64 * ratio.0 as f64 * ratio.1 as f64
860            {
861                best_ratio = ratio;
862            }
863        }
864        best_ratio
865    }
866
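    /// Split an image into an `image_size`-pixel crop grid: resize it (keeping
    /// aspect ratio), pad the bottom/right with white, and build a patch-level
    /// attention mask in which fully padded 14-pixel patches are zeroed.
    ///
    /// Worked example (hypothetical input): a 1000x600 image with
    /// `image_size = 448` needs `ceil(1000/448) x ceil(600/448) = 3 x 2` crops
    /// (assuming `max_num >= 6`), so the target canvas is 1344x896. The image is
    /// resized to 1344x806 and padded by 90 rows; with `mask_size = 32` the mask
    /// is 64 rows x 96 cols, with the last `90 / 14 = 6` rows zeroed.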
867    fn dynamic_preprocess(
868        &self,
869        mut image: DynamicImage,
870        min_num: usize,
871        max_num: usize,
872        image_size: usize,
873        mask_size: usize,
874        device: &Device,
875    ) -> Result<(DynamicImage, Tensor)> {
876        let (orig_w, orig_h) = image.dimensions();
877
878        let w_crop_num = (orig_w as f64 / image_size as f64).ceil();
879        let h_crop_num = (orig_h as f64 / image_size as f64).ceil();
880        let (target_aspect_ratio, target_width, target_height) =
881            if w_crop_num * h_crop_num > max_num as f64 {
882                let aspect_ratio = orig_w as f64 / orig_h as f64;
883                let target_ratios = Self::compute_target_ratios(min_num as u32, max_num as u32);
884
885                let target_aspect_ratio = Self::find_closest_aspect_ratio(
886                    aspect_ratio,
887                    target_ratios,
888                    orig_w,
889                    orig_h,
890                    image_size,
891                );
892
893                let target_width = image_size * target_aspect_ratio.0 as usize;
894                let target_height = image_size * target_aspect_ratio.1 as usize;
895
896                (
897                    (target_aspect_ratio.0 as f64, target_aspect_ratio.1 as f64),
898                    target_width,
899                    target_height,
900                )
901            } else {
902                let target_width = (image_size as f64 * w_crop_num) as usize;
903                let target_height = (image_size as f64 * h_crop_num) as usize;
904                let target_aspect_ratio = (w_crop_num, h_crop_num);
905
906                (target_aspect_ratio, target_width, target_height)
907            };
908
909        let ratio_width = target_width as f64 / orig_w as f64;
910        let ratio_height = target_height as f64 / orig_h as f64;
911        let (new_size, padding_width, padding_height) = if ratio_width < ratio_height {
912            (
913                (target_width, (orig_h as f64 * ratio_width) as usize),
914                0_usize,
915                target_height - (orig_h as f64 * ratio_width) as usize,
916            )
917        } else {
918            (
919                ((orig_w as f64 * ratio_height) as usize, target_height),
920                target_width - (orig_w as f64 * ratio_height) as usize,
921                0_usize,
922            )
923        };
924
925        // Guard against extreme aspect ratios resulting in too-small dimensions
926        if new_size.1.min(target_height) < 10 || new_size.0.min(target_width) < 10 {
927            candle_core::bail!(
928                "Image aspect ratio too extreme; resulting size below minimum threshold",
929            );
930        }
931
932        let mut attention_mask = Tensor::ones(
933            (
934                (mask_size as f64 * target_aspect_ratio.1) as usize,
935                (mask_size as f64 * target_aspect_ratio.0) as usize,
936            ),
937            DType::U32,
938            device,
939        )?;
940        if padding_width >= 14 {
941            attention_mask = attention_mask.slice_assign(
942                &[&.., &(attention_mask.dim(1)? - padding_width / 14..)],
943                &Tensor::zeros(
944                    (attention_mask.dim(0)?, padding_width / 14),
945                    DType::U32,
946                    device,
947                )?,
948            )?;
949        }
950        if padding_height >= 14 {
951            attention_mask = attention_mask.slice_assign(
952                &[&(attention_mask.dim(0)? - padding_height / 14..), &..],
953                &Tensor::zeros(
954                    (padding_height / 14, attention_mask.dim(1)?),
955                    DType::U32,
956                    device,
957                )?,
958            )?;
959        }
960
961        // Ensure the attention mask is non-empty
962        let mask_sum: u32 = attention_mask.sum_all()?.to_scalar::<u32>()?;
963        if mask_sum == 0 {
964            candle_core::bail!("dynamic_preprocess produced an attention mask with zero sum",);
965        }
966
967        image = image.resize_exact(new_size.0 as u32, new_size.1 as u32, FilterType::Nearest);
968        image = Self::pad_image(
969            &image,
970            0,
971            padding_height as u32,
972            0,
973            padding_width as u32,
974            Rgba([255u8, 255, 255, 255]),
975        );
976
977        Ok((image, attention_mask))
978    }
979}
980
981impl ImagePreProcessor for Phi4MMInputsProcessor {
982    #[allow(clippy::excessive_precision)]
983    const DEFAULT_MEAN: [f64; 3] = [0.48145466, 0.4578275, 0.40821073];
984    #[allow(clippy::excessive_precision)]
985    const DEFAULT_STD: [f64; 3] = [0.26862954, 0.26130258, 0.27577711];
986
987    fn preprocess(
988        &self,
989        mut images: Vec<DynamicImage>,
990        videos: Vec<Vec<DynamicImage>>,
991        config: &PreProcessorConfig,
992        device: &Device,
993        (_, _): (usize, usize),
994    ) -> Result<PreprocessedImages> {
995        // If no images, will not call this.
996        assert!(!images.is_empty());
997        assert!(videos.is_empty());
998
999        // If >1 images, resize them all to the largest, potentially destroying aspect ratio
1000        let mut max_size = None;
1001        for image in images.iter() {
1002            if max_size.is_none() {
1003                max_size = Some((image.dimensions().0 as usize, image.dimensions().1 as usize))
1004            } else if max_size.is_some_and(|(x, _)| image.dimensions().0 as usize > x) {
1005                max_size = Some((image.dimensions().0 as usize, max_size.unwrap().1));
1006            } else if max_size.is_some_and(|(_, y)| image.dimensions().1 as usize > y) {
1007                max_size = Some((max_size.unwrap().0, image.dimensions().1 as usize));
1008            }
1009        }
1010        let (max_w, max_h) = max_size.unwrap();
1011        for image in images.iter_mut() {
1012            *image = image.resize_exact(max_w as u32, max_h as u32, FilterType::Nearest);
1013        }
1014
1015        let mut image_sizes = Vec::new();
1016        let mut padded_images = Vec::new();
1017        let mut padded_masks = Vec::new();
1018        let mut num_img_tokens = Vec::new();
1019        for mut image in images {
1020            // Convert to rgb, default to true
1021            if config.do_convert_rgb.unwrap_or(true) {
1022                image = DynamicImage::ImageRgb8(image.to_rgb8());
1023            }
1024
1025            let transforms = Transforms {
1026                input: &ToTensor,
1027                inner_transforms: &[&Normalize {
1028                    mean: vec![0.5, 0.5, 0.5],
1029                    std: vec![0.5, 0.5, 0.5],
1030                }],
1031            };
1032            // Dynamic HD
1033            let dyhd_base_resolution = DYHD_BASE_RESOLUTION;
1034            let base_resolution = dyhd_base_resolution;
            // The vision encoder uses 14-pixel patches, so the mask grid is base_resolution / 14 (32 x 32 at 448)
1036            let mask_resolution = base_resolution / 14;
1037            let min_num = 1;
1038
1039            let (elem, attention_mask) = self.dynamic_preprocess(
1040                image,
1041                min_num,
1042                config.dynamic_hd.unwrap(),
1043                base_resolution,
1044                mask_resolution,
1045                device,
1046            )?;
1047
1048            let hd_image = elem.apply(transforms, device)?;
1049            let (img_h, img_w) = (hd_image.dim(1)?, hd_image.dim(2)?);
1050            let (mask_h, mask_w) = (attention_mask.dim(0)?, attention_mask.dim(1)?);
1051
1052            // Resize with bicubic interpolation
1053            let global_image = hd_image
1054                .unsqueeze(0)?
1055                .interpolate2d(base_resolution, base_resolution)?;
1056            let global_attention_mask =
1057                Tensor::ones((1, mask_resolution, mask_resolution), DType::U32, device)?;
1058
1059            let hd_image_reshape = hd_image
1060                .reshape((
1061                    1,
1062                    3,
1063                    (img_h as f32 / base_resolution as f32) as usize,
1064                    base_resolution,
1065                    (img_w as f32 / base_resolution as f32) as usize,
1066                    base_resolution,
1067                ))?
1068                .permute((0, 2, 4, 1, 3, 5))?
1069                .reshape(((), 3, base_resolution, base_resolution))?;
1070
1071            let attention_mask_reshape = attention_mask
1072                .reshape((
1073                    1,
1074                    (mask_h as f32 / mask_resolution as f32) as usize,
1075                    mask_resolution,
1076                    (mask_w as f32 / mask_resolution as f32) as usize,
1077                    mask_resolution,
1078                ))?
1079                .permute((0, 1, 3, 2, 4))?
1080                .reshape(((), mask_resolution, mask_resolution))?;
1081
1082            let downsample_attention_mask = {
1083                let h_indices =
1084                    Tensor::arange_step(0, attention_mask_reshape.dim(1)? as u32, 2, device)?;
1085                let w_indices =
1086                    Tensor::arange_step(0, attention_mask_reshape.dim(2)? as u32, 2, device)?;
1087                let selected = attention_mask_reshape
1088                    .index_select(&h_indices, 1)?
1089                    .index_select(&w_indices, 2)?;
1090
1091                let mask = selected
1092                    .reshape((
1093                        1,
1094                        mask_h / mask_resolution,
1095                        mask_w / mask_resolution,
1096                        mask_resolution / 2 + mask_resolution % 2,
1097                        mask_resolution / 2 + mask_resolution % 2,
1098                    ))?
1099                    .permute((0, 1, 3, 2, 4))?;
1100                mask.reshape((mask.dim(1)? * mask.dim(2)?, mask.dim(3)? * mask.dim(4)?))?
1101            };
1102
1103            let img_tokens = 256
1104                + 1
1105                + downsample_attention_mask.sum_all()?.to_scalar::<u32>()? as usize
1106                + downsample_attention_mask
1107                    .i((.., 0))?
1108                    .sum_all()?
1109                    .to_scalar::<u32>()? as usize
1110                + 16;
1111
1112            let hd_image_reshape = Tensor::cat(&[global_image, hd_image_reshape], 0)?;
1113            let hd_mask_reshape = Tensor::cat(&[global_attention_mask, attention_mask_reshape], 0)?;
1114
1115            image_sizes.push((img_h as u32, img_w as u32));
1116            padded_images.push(hd_image_reshape);
1117            padded_masks.push(hd_mask_reshape);
1118            num_img_tokens.push(img_tokens);
1119        }
1120        Ok(PreprocessedImages {
1121            pixel_values: Tensor::stack(&padded_images, 0)?,
1122            pixel_attention_mask: Some(Tensor::stack(&padded_masks, 0)?),
1123            image_sizes: None,
1124            num_img_tokens: Some(num_img_tokens),
1125            aspect_ratio_ids: None,
1126            aspect_ratio_mask: None,
1127            num_tiles: None,
1128            image_grid_thw: None,
1129            video_grid_thw: None,
1130            rows: None,
1131            cols: None,
1132            pixel_values_list: None,
1133            tgt_sizes: None,
1134            image_sizes_all: Some(image_sizes),
1135            num_crops: None,
1136        })
1137    }
1138}
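
// A few self-contained smoke tests for the pure audio helpers above. These are
// editorial examples: the rates and signals are illustrative stand-ins, not
// values taken from any shipped Phi-4 configuration.
#[cfg(test)]
mod tests {
    use super::*;

    fn test_processor() -> Phi4MMInputsProcessor {
        Phi4MMInputsProcessor {
            audio_compression_rate: 8,
            audio_downsample_rate: 1,
            audio_feat_stride: 1,
            eightk_method: "fillzero".to_string(),
        }
    }

    #[test]
    fn audio_embed_size_rounds_up_at_each_stage() {
        let p = test_processor();
        // ceil(1000 / 8) = 125, ceil(125 / 1) = 125
        assert_eq!(p.compute_audio_embed_size(1000, 8, 1), 125);
        // ceil(1001 / 8) = 126, ceil(126 / 1) = 126
        assert_eq!(p.compute_audio_embed_size(1001, 8, 1), 126);
        // ceil(1001 / 8) = 126, ceil(126 / 2) = 63
        assert_eq!(p.compute_audio_embed_size(1001, 8, 2), 63);
    }

    #[test]
    fn preemphasis_filters_and_rescales() {
        let p = test_processor();
        let out = p.apply_preemphasis(&[0.0, 0.5, 0.25], 0.97);
        // y[0] = x[0] * 32768, y[n] = (x[n] - 0.97 * x[n-1]) * 32768
        assert_eq!(out[0], 0.0);
        assert_eq!(out[1], 0.5 * 32768.0);
        assert!((out[2] - (0.25 - 0.97 * 0.5) * 32768.0).abs() < 1e-2);
    }

    #[test]
    fn audio_attention_mask_marks_valid_frames() -> Result<()> {
        let p = test_processor();
        let mask = p.create_audio_attention_mask(&[3, 5], &Device::Cpu)?;
        assert_eq!(mask.dims(), &[2, 5]);
        // Row sums equal the number of valid frames per sequence.
        assert_eq!(mask.i((0, ..))?.sum_all()?.to_scalar::<f32>()?, 3.0);
        assert_eq!(mask.i((1, ..))?.sum_all()?.to_scalar::<f32>()?, 5.0);
        Ok(())
    }
}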