mistralrs_core/vision_models/phi4/inputs_processor.rs

#![allow(clippy::cast_precision_loss, clippy::cast_possible_truncation)]

use std::{any::Any, collections::HashSet, num::NonZeroUsize, sync::Arc};

use candle_core::{DType, Device, IndexOp, Result, Tensor};
use image::{imageops::FilterType, DynamicImage, GenericImage, GenericImageView, Rgba};
use mistralrs_vision::{ApplyTransforms, Normalize, ToTensor, Transforms};
use regex::Regex;
use tokenizers::Tokenizer;
use tracing::warn;

use crate::{
    device_map::DeviceMapper,
    pipeline::{
        text_models_inputs_processor::{
            self, get_completion_input, get_prompt_input, PagedAttentionMeta,
        },
        InputProcessorOutput, InputsProcessor, InputsProcessorType, MessagesAction, Processor,
        ProcessorCreator,
    },
    sequence::Sequence,
};

use crate::vision_models::{
    image_processor::{ImagePreProcessor, PreprocessedImages},
    phi4::Phi4MMVisionSpecificArgs,
    preprocessor_config::PreProcessorConfig,
    processor_config::ProcessorConfig,
    ModelInputs,
};

use super::image_embedding::IMAGE_SPECIAL_TOKEN_ID;

const COMPATIBLE_IMAGE_SPECIAL_TOKEN_PATTERN: &str = r"<\|image_\d+\|>";
const IMAGE_SPECIAL_TOKEN: &str = "<|endoftext10|>";
pub(crate) const DYHD_BASE_RESOLUTION: usize = 448;

// Input processor
pub struct Phi4MMInputsProcessor;
// Processor
pub struct Phi4MMProcessor {
    inputs_processor: Arc<Phi4MMInputsProcessor>,
}

impl ProcessorCreator for Phi4MMProcessor {
    fn new_processor(
        _: Option<ProcessorConfig>,
        _: PreProcessorConfig,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(Self {
            inputs_processor: Arc::new(Phi4MMInputsProcessor),
        })
    }
}

impl Processor for Phi4MMProcessor {
    fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
        self.inputs_processor.clone()
    }
    fn get_special_tokens(&self) -> &[&'static str] {
        &[]
    }
    fn template_action(&self) -> MessagesAction {
        MessagesAction::FlattenOnlyText
    }
}

impl InputsProcessor for Phi4MMInputsProcessor {
    fn get_type(&self) -> InputsProcessorType {
        InputsProcessorType::Vision
    }
    fn process_inputs(
        &self,
        tokenizer: Option<Arc<Tokenizer>>,
        input_seqs: &mut [&mut Sequence],
        is_prompt: bool,
        is_xlora: bool,
        device: &Device,
        no_kv_cache: bool,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        other_config: Option<Arc<dyn Any>>,
        mut paged_attn_metadata: Option<PagedAttentionMeta<'_>>,
        prompt_chunksize: Option<NonZeroUsize>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Box<dyn Iterator<Item = anyhow::Result<InputProcessorOutput>>> {
        if is_xlora {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Cannot make inputs for X-LoRA vision model.",
            ))));
        }
        if no_kv_cache {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Vision model must have kv cache.",
            ))));
        }
        // TODO(EricLBuehler): support this? Would require some handling of image tokens.
        if prompt_chunksize.is_some() {
            warn!("`prompt_chunksize` is set. Phi 4 multimodal does not support prompt batching.");
        }
        let Some(tokenizer) = tokenizer else {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Phi4MMInputsProcessor requires a specified tokenizer.",
            ))));
        };

        let config = other_config
            .clone()
            .expect("Need a PreProcessorConfig.");
        let config: &PreProcessorConfig = config.downcast_ref().expect("Downcast failed.");

        let has_images = input_seqs.iter().all(|seq| seq.has_images());

        let (pixel_values, pixel_attention_mask, image_sizes, num_img_tokens) = if has_images {
            let mut pixel_values_accum = Vec::new();
            let mut pixel_attention_masks_accum = Vec::new();
            let mut image_sizes_accum = Vec::new();
            let mut num_img_tokens_accum = Vec::new();
            for seq in input_seqs.iter_mut() {
                let imgs = seq
                    .take_images()
                    .expect("Need to have images by this point.");
                let PreprocessedImages {
                    pixel_values,
                    pixel_attention_mask,
                    image_sizes: _,
                    num_img_tokens,
                    aspect_ratio_ids: _,
                    aspect_ratio_mask: _,
                    num_tiles: _,
                    image_grid_thw: _,
                    video_grid_thw: _,
                    rows: _,
                    cols: _,
                    pixel_values_list: _,
                    tgt_sizes: _,
                    image_sizes_all,
                    num_crops: _,
                } = self
                    .preprocess(
                        imgs,
                        vec![],
                        config,
                        device,
                        (usize::MAX, usize::MAX), // Don't use it here...
                    )
                    .expect("Preprocessor failed");
                let image_sizes = image_sizes_all.unwrap();
                let pixel_attention_mask = pixel_attention_mask.unwrap();
                pixel_values_accum.push(pixel_values);
                pixel_attention_masks_accum.push(pixel_attention_mask);
                // Using extend on purpose
                image_sizes_accum.extend(image_sizes);
                num_img_tokens_accum.push(num_img_tokens.unwrap());
            }
            (
                Some(Tensor::cat(&pixel_values_accum, 0).unwrap()),
                Some(Tensor::cat(&pixel_attention_masks_accum, 0).unwrap()),
                Some(image_sizes_accum),
                Some(num_img_tokens_accum),
            )
        } else {
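            // No images in any sequence: delegate to the plain-text inputs processor and
            // rewrap its output in `ModelInputs` with empty vision-specific arguments.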
            return Box::new(
                text_models_inputs_processor::TextInputsProcessor
                    .process_inputs(
                        Some(tokenizer),
                        input_seqs,
                        is_prompt,
                        is_xlora,
                        device,
                        no_kv_cache,
                        last_n_context_len,
                        return_raw_logits,
                        other_config,
                        paged_attn_metadata,
                        None, // TODO
                        mapper,
                    )
                    .map(|metadata| {
                        let InputProcessorOutput {
                            inputs,
                            seq_indices,
                        } = metadata?;

                        let text_models_inputs_processor::ModelInputs {
                            input_ids,
                            input_ids_full: _,
                            seqlen_offsets,
                            seqlen_offsets_full: _,
                            context_lens,
                            position_ids,
                            paged_attn_meta,
                            flash_meta,
                            flash_meta_full: _,
                        } = *inputs
                            .downcast::<text_models_inputs_processor::ModelInputs>()
                            .expect("Downcast failed.");

                        let inputs: Box<dyn Any> = Box::new(ModelInputs {
                            input_ids,
                            seqlen_offsets,
                            context_lens,
                            position_ids,
                            pixel_values: None,
                            model_specific_args: Box::new(Phi4MMVisionSpecificArgs {
                                image_sizes: None,
                                image_attention_mask: None,
                                input_image_embeds: None,
                            }),
                            paged_attn_meta,
                            flash_meta,
                        });
                        Ok(InputProcessorOutput {
                            inputs,
                            seq_indices,
                        })
                    }),
            );
        };

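        // Decode the tokenized prompts back to text so user-facing `<|image_N|>` placeholders
        // can be rewritten into the model's image special token before re-encoding.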
        let detokenized = tokenizer
            .decode_batch(
                &input_seqs
                    .iter()
                    .map(|seq| seq.get_toks())
                    .collect::<Vec<_>>(),
                false,
            )
            .expect("Decode failed");

        let img_token_pattern = Regex::new(COMPATIBLE_IMAGE_SPECIAL_TOKEN_PATTERN).unwrap();

        let mut toks = Vec::new();

        for (mut detokenized, (seq, num_img_tokens)) in detokenized
            .into_iter()
            .zip(input_seqs.iter_mut().zip(num_img_tokens.unwrap()))
        {
            detokenized = img_token_pattern
                .replace_all(&detokenized, IMAGE_SPECIAL_TOKEN)
                .to_string();

            seq.set_toks_and_reallocate(
                tokenizer
                    .encode_fast(detokenized.clone(), false)
                    .expect("Encode failed")
                    .get_ids()
                    .to_vec(),
                paged_attn_metadata.as_mut(),
            );

            seq.set_initial_prompt(detokenized);

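            // Expand each image special token in place into `token_count` copies, reserving
            // one prompt position per image embedding produced for that image.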
            let mut i = 0;
            let mut image_token_count_iter = num_img_tokens.iter();
            while i < seq.get_toks().len() {
                let token_id = seq.get_toks()[i];
                let token_count = if token_id == IMAGE_SPECIAL_TOKEN_ID as u32 {
                    image_token_count_iter.next().unwrap()
                } else {
                    i += 1;
                    continue;
                };

                let mut new_ids = seq.get_toks()[..i].to_vec();
                new_ids.extend(vec![token_id; *token_count]);
                new_ids.extend(seq.get_toks()[i + 1..].to_vec());
                seq.set_toks_and_reallocate(new_ids, paged_attn_metadata.as_mut());
                i += token_count;
            }
            toks.push(seq.get_toks().to_vec());
        }

        let iter = if is_prompt {
            get_prompt_input(
                toks,
                input_seqs,
                device,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                None, // TODO: evaluate if it is possible to batch this
                mapper,
            )
        } else {
            get_completion_input(
                toks,
                input_seqs,
                device,
                no_kv_cache,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                None, // TODO: evaluate if it is possible to batch this
                mapper,
            )
        };

        Box::new(iter.into_iter().map(move |metadata| {
            let pixel_values = pixel_values.clone();
            let pixel_attention_mask = pixel_attention_mask.clone();
            let text_models_inputs_processor::InnerInputProcessorOutput {
                inputs:
                    text_models_inputs_processor::InputMetadata {
                        input,
                        positions,
                        context_lens,
                        position_ids,
                        paged_attn_meta,
                        flash_meta,
                    },
                seq_indices,
            } = metadata?;
            let inputs: Box<dyn Any> = Box::new(ModelInputs {
                input_ids: input,
                seqlen_offsets: positions,
                context_lens,
                position_ids,
                pixel_values: pixel_values.clone(),
                model_specific_args: Box::new(Phi4MMVisionSpecificArgs {
                    image_sizes: image_sizes.clone(),
                    image_attention_mask: pixel_attention_mask,
                    input_image_embeds: pixel_values,
                }),
                paged_attn_meta,
                flash_meta,
            });
            Ok(InputProcessorOutput {
                inputs,
                seq_indices,
            })
        }))
    }
}

impl Phi4MMInputsProcessor {
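    /// Pad `image` by the given number of pixels on each side, filling with `pad_color`;
    /// the original content is copied in at offset (`left`, `top`).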
    fn pad_image(
        image: &DynamicImage,
        top: u32,
        bottom: u32,
        left: u32,
        right: u32,
        pad_color: Rgba<u8>,
    ) -> DynamicImage {
        // Calculate the new dimensions
        let new_width = image.width() + left + right;
        let new_height = image.height() + top + bottom;

        // Create a new image with the new dimensions and fill it with the pad color
        let mut new_image = DynamicImage::new_rgb8(new_width, new_height);
        for x in 0..new_width {
            for y in 0..new_height {
                new_image.put_pixel(x, y, pad_color);
            }
        }

        // Paste the original image at the (left, top) offset
        new_image
            .copy_from(image, left, top)
            .expect("Failed to copy image");

        new_image
    }

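    /// Enumerate all (width, height) tile grids whose tile count lies in
    /// `min_num..=max_num`, sorted by tile count. These are the candidate crop layouts
    /// considered by dynamic HD preprocessing.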
    fn compute_target_ratios(min_num: u32, max_num: u32) -> Vec<(u32, u32)> {
        let mut ratios: HashSet<(u32, u32)> = HashSet::new();
        for n in min_num..=max_num {
            for i in 1..=n {
                for j in 1..=n {
                    if i * j >= min_num && i * j <= max_num {
                        ratios.insert((i, j));
                    }
                }
            }
        }
        let mut sorted_ratios: Vec<(u32, u32)> = ratios.into_iter().collect();
        sorted_ratios.sort_by_key(|&(i, j)| i * j);
        sorted_ratios
    }

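    /// Pick the candidate tile grid whose aspect ratio is closest to the image's own;
    /// ties are broken toward the larger grid when the image covers more than half of it.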
    fn find_closest_aspect_ratio(
        aspect_ratio: f64,
        target_ratios: Vec<(u32, u32)>,
        width: u32,
        height: u32,
        image_size: usize,
    ) -> (u32, u32) {
        let mut best_ratio_diff = f64::INFINITY;
        let mut best_ratio = (1, 1);
        let area = width * height;
        for ratio in target_ratios {
            let target_aspect_ratio = ratio.0 as f64 / ratio.1 as f64;
            let ratio_diff = (aspect_ratio - target_aspect_ratio).abs();
            if ratio_diff < best_ratio_diff {
                best_ratio_diff = ratio_diff;
                best_ratio = ratio;
            } else if ratio_diff == best_ratio_diff
                && area as f64 > 0.5 * image_size as f64 * ratio.0 as f64 * ratio.1 as f64
            {
                best_ratio = ratio;
            }
        }
        best_ratio
    }

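    /// Dynamic HD preprocessing: pick a tile grid for the image, resize it to fit that grid
    /// while preserving aspect ratio, pad the remainder with white, and build an attention
    /// mask (one entry per 14-pixel patch) that zeroes out the padded region.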
    fn dynamic_preprocess(
        &self,
        mut image: DynamicImage,
        min_num: usize,
        max_num: usize,
        image_size: usize,
        mask_size: usize,
        device: &Device,
    ) -> Result<(DynamicImage, Tensor)> {
        let (orig_w, orig_h) = image.dimensions();

        let w_crop_num = (orig_w as f64 / image_size as f64).ceil();
        let h_crop_num = (orig_h as f64 / image_size as f64).ceil();
        let (target_aspect_ratio, target_width, target_height) =
            if w_crop_num * h_crop_num > max_num as f64 {
                let aspect_ratio = orig_w as f64 / orig_h as f64;
                let target_ratios = Self::compute_target_ratios(min_num as u32, max_num as u32);

                let target_aspect_ratio = Self::find_closest_aspect_ratio(
                    aspect_ratio,
                    target_ratios,
                    orig_w,
                    orig_h,
                    image_size,
                );

                let target_width = image_size * target_aspect_ratio.0 as usize;
                let target_height = image_size * target_aspect_ratio.1 as usize;

                (
                    (target_aspect_ratio.0 as f64, target_aspect_ratio.1 as f64),
                    target_width,
                    target_height,
                )
            } else {
                let target_width = (image_size as f64 * w_crop_num) as usize;
                let target_height = (image_size as f64 * h_crop_num) as usize;
                let target_aspect_ratio = (w_crop_num, h_crop_num);

                (target_aspect_ratio, target_width, target_height)
            };

        let ratio_width = target_width as f64 / orig_w as f64;
        let ratio_height = target_height as f64 / orig_h as f64;
        let (new_size, padding_width, padding_height) = if ratio_width < ratio_height {
            (
                (target_width, (orig_h as f64 * ratio_width) as usize),
                0_usize,
                target_height - (orig_h as f64 * ratio_width) as usize,
            )
        } else {
            (
                ((orig_w as f64 * ratio_height) as usize, target_height),
                target_width - (orig_w as f64 * ratio_height) as usize,
                0_usize,
            )
        };

        let mut attention_mask = Tensor::ones(
            (
                (mask_size as f64 * target_aspect_ratio.1) as usize,
                (mask_size as f64 * target_aspect_ratio.0) as usize,
            ),
            DType::U32,
            device,
        )?;
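        // Zero out the mask over the padded region, at 14-pixel patch granularity.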
        if padding_width >= 14 {
            attention_mask = attention_mask.slice_assign(
                &[&.., &(attention_mask.dim(1)? - padding_width / 14..)],
                &Tensor::zeros(
                    (attention_mask.dim(0)?, padding_width / 14),
                    DType::U32,
                    device,
                )?,
            )?;
        }
        if padding_height >= 14 {
            attention_mask = attention_mask.slice_assign(
                &[&(attention_mask.dim(0)? - padding_height / 14..), &..],
                &Tensor::zeros(
                    (padding_height / 14, attention_mask.dim(1)?),
                    DType::U32,
                    device,
                )?,
            )?;
        }

        image = image.resize_exact(new_size.0 as u32, new_size.1 as u32, FilterType::Nearest);
        image = Self::pad_image(
            &image,
            0,
            padding_height as u32,
            padding_width as u32,
            0,
            Rgba([255u8, 255, 255, 255]),
        );

        Ok((image, attention_mask))
    }
}

impl ImagePreProcessor for Phi4MMInputsProcessor {
    #[allow(clippy::excessive_precision)]
    const DEFAULT_MEAN: [f64; 3] = [0.48145466, 0.4578275, 0.40821073];
    #[allow(clippy::excessive_precision)]
    const DEFAULT_STD: [f64; 3] = [0.26862954, 0.26130258, 0.27577711];

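    /// Preprocess a batch of images for Phi 4 multimodal: resize all images to a common
    /// size, run dynamic HD cropping on each, prepend a downscaled global view, and count
    /// the number of prompt tokens each image will occupy.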
    fn preprocess(
        &self,
        mut images: Vec<DynamicImage>,
        videos: Vec<Vec<DynamicImage>>,
        config: &PreProcessorConfig,
        device: &Device,
        (_, _): (usize, usize),
    ) -> Result<PreprocessedImages> {
        // If no images, will not call this.
        assert!(!images.is_empty());
        assert!(videos.is_empty());

        // If >1 images, resize them all to the largest, potentially destroying aspect ratio
        let mut max_size: Option<(usize, usize)> = None;
        for image in images.iter() {
            let (w, h) = image.dimensions();
            let (max_w, max_h) = max_size.unwrap_or((0, 0));
            max_size = Some((max_w.max(w as usize), max_h.max(h as usize)));
        }
        let (max_w, max_h) = max_size.unwrap();
        for image in images.iter_mut() {
            *image = image.resize_exact(max_w as u32, max_h as u32, FilterType::Nearest);
        }

        let mut image_sizes = Vec::new();
        let mut padded_images = Vec::new();
        let mut padded_masks = Vec::new();
        let mut num_img_tokens = Vec::new();
        for mut image in images {
            // Convert to rgb, default to true
            if config.do_convert_rgb.unwrap_or(true) {
                image = DynamicImage::ImageRgb8(image.to_rgb8());
            }

            let transforms = Transforms {
                input: &ToTensor,
                inner_transforms: &[&Normalize {
                    mean: vec![0.5, 0.5, 0.5],
                    std: vec![0.5, 0.5, 0.5],
                }],
            };
            // Dynamic HD: crops are squares of `base_resolution` pixels, and the mask has
            // one entry per 14-pixel patch (448 / 14 = 32 per side).
            let base_resolution = DYHD_BASE_RESOLUTION;
            let mask_resolution = base_resolution / 14;
            let min_num = 1;

            let (elem, attention_mask) = self.dynamic_preprocess(
                image,
                min_num,
                config.dynamic_hd.unwrap(),
                base_resolution,
                mask_resolution,
                device,
            )?;

            let hd_image = elem.apply(transforms, device)?;
            let (img_h, img_w) = (hd_image.dim(1)?, hd_image.dim(2)?);
            let (mask_h, mask_w) = (attention_mask.dim(0)?, attention_mask.dim(1)?);

            // Global view: resize the full image down to a single base-resolution crop
            let global_image = hd_image
                .unsqueeze(0)?
                .interpolate2d(base_resolution, base_resolution)?;
            let global_attention_mask =
                Tensor::ones((1, mask_resolution, mask_resolution), DType::U32, device)?;

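            // Split the HD image into base-resolution crops:
            // (1, 3, H, W) -> (num_crops, 3, base_resolution, base_resolution),
            // and reshape the mask into per-crop patch masks the same way.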
            let hd_image_reshape = hd_image
                .reshape((
                    1,
                    3,
                    (img_h as f32 / base_resolution as f32) as usize,
                    base_resolution,
                    (img_w as f32 / base_resolution as f32) as usize,
                    base_resolution,
                ))?
                .permute((0, 2, 4, 1, 3, 5))?
                .reshape(((), 3, base_resolution, base_resolution))?;

            let attention_mask_reshape = attention_mask
                .reshape((
                    1,
                    (mask_h as f32 / mask_resolution as f32) as usize,
                    mask_resolution,
                    (mask_w as f32 / mask_resolution as f32) as usize,
                    mask_resolution,
                ))?
                .permute((0, 1, 3, 2, 4))?
                .reshape(((), mask_resolution, mask_resolution))?;

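            // Token counting uses the mask downsampled 2x in each direction (every other
            // patch row/column), mirroring the 2x2 patch compression applied downstream in
            // the image embedding.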
            let downsample_attention_mask = {
                let h_indices =
                    Tensor::arange_step(0, attention_mask_reshape.dim(1)? as u32, 2, device)?;
                let w_indices =
                    Tensor::arange_step(0, attention_mask_reshape.dim(2)? as u32, 2, device)?;
                let selected = attention_mask_reshape
                    .index_select(&h_indices, 1)?
                    .index_select(&w_indices, 2)?;

                let mask = selected
                    .reshape((
                        1,
                        mask_h / mask_resolution,
                        mask_w / mask_resolution,
                        mask_resolution / 2 + mask_resolution % 2,
                        mask_resolution / 2 + mask_resolution % 2,
                    ))?
                    .permute((0, 1, 3, 2, 4))?;
                mask.reshape((mask.dim(1)? * mask.dim(2)?, mask.dim(3)? * mask.dim(4)?))?
            };

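            // Prompt positions this image occupies: 256 (+1 separator) for the 16x16
            // downsampled global view, one per unmasked downsampled patch over all crops,
            // plus what appear to be row separators (the column-0 sum for the crops, and
            // 16 more for the global view).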
            let img_tokens = 256
                + 1
                + downsample_attention_mask.sum_all()?.to_scalar::<u32>()? as usize
                + downsample_attention_mask
                    .i((.., 0))?
                    .sum_all()?
                    .to_scalar::<u32>()? as usize
                + 16;

            let hd_image_reshape = Tensor::cat(&[global_image, hd_image_reshape], 0)?;
            let hd_mask_reshape = Tensor::cat(&[global_attention_mask, attention_mask_reshape], 0)?;

            image_sizes.push((img_h as u32, img_w as u32));
            padded_images.push(hd_image_reshape);
            padded_masks.push(hd_mask_reshape);
            num_img_tokens.push(img_tokens);
        }
        Ok(PreprocessedImages {
            pixel_values: Tensor::stack(&padded_images, 0)?,
            pixel_attention_mask: Some(Tensor::stack(&padded_masks, 0)?),
            image_sizes: None,
            num_img_tokens: Some(num_img_tokens),
            aspect_ratio_ids: None,
            aspect_ratio_mask: None,
            num_tiles: None,
            image_grid_thw: None,
            video_grid_thw: None,
            rows: None,
            cols: None,
            pixel_values_list: None,
            tgt_sizes: None,
            image_sizes_all: Some(image_sizes),
            num_crops: None,
        })
    }
}