mistralrs_core/vision_models/phi4/inputs_processor.rs

#![allow(clippy::cast_precision_loss, clippy::cast_possible_truncation)]

use std::{any::Any, collections::HashSet, num::NonZeroUsize, sync::Arc};

use candle_core::{DType, Device, IndexOp, Result, Tensor};
use image::{imageops::FilterType, DynamicImage, GenericImage, GenericImageView, Rgba};
use mistralrs_vision::{ApplyTransforms, Normalize, ToTensor, Transforms};
use regex::Regex;
use tokenizers::Tokenizer;
use tracing::warn;

use crate::{
    device_map::DeviceMapper,
    pipeline::{
        text_models_inputs_processor::{
            self, get_completion_input, get_prompt_input, PagedAttentionMeta,
        },
        InputProcessorOutput, InputsProcessor, InputsProcessorType, MessagesAction, Processor,
        ProcessorCreator,
    },
    sequence::Sequence,
};

use crate::vision_models::{
    image_processor::{ImagePreProcessor, PreprocessedImages},
    phi4::Phi4MMVisionSpecificArgs,
    preprocessor_config::PreProcessorConfig,
    processor_config::ProcessorConfig,
    ModelInputs,
};

use super::image_embedding::IMAGE_SPECIAL_TOKEN_ID;

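// User-facing `<|image_N|>` placeholders in the prompt are rewritten to this single
// compatible special token (`<|endoftext10|>`); the id it encodes to is expected to
// match `IMAGE_SPECIAL_TOKEN_ID` used by the image embedding layer.
// `DYHD_BASE_RESOLUTION` is the side length, in pixels, of each dynamic-HD crop.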
const COMPATIBLE_IMAGE_SPECIAL_TOKEN_PATTERN: &str = r"<\|image_\d+\|>";
const IMAGE_SPECIAL_TOKEN: &str = "<|endoftext10|>";
pub(crate) const DYHD_BASE_RESOLUTION: usize = 448;

// Input processor
pub struct Phi4MMInputsProcessor;
// Processor
pub struct Phi4MMProcessor {
    inputs_processor: Arc<Phi4MMInputsProcessor>,
}

impl ProcessorCreator for Phi4MMProcessor {
    fn new_processor(
        _: Option<ProcessorConfig>,
        _: PreProcessorConfig,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(Self {
            inputs_processor: Arc::new(Phi4MMInputsProcessor),
        })
    }
}

impl Processor for Phi4MMProcessor {
    fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
        self.inputs_processor.clone()
    }
    fn get_special_tokens(&self) -> &[&'static str] {
        &[]
    }
    fn template_action(&self) -> MessagesAction {
        MessagesAction::FlattenOnlyText
    }
}

impl InputsProcessor for Phi4MMInputsProcessor {
    fn get_type(&self) -> InputsProcessorType {
        InputsProcessorType::Vision
    }
    fn process_inputs(
        &self,
        tokenizer: Option<Arc<Tokenizer>>,
        input_seqs: &mut [&mut Sequence],
        is_prompt: bool,
        is_xlora: bool,
        device: &Device,
        no_kv_cache: bool,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        other_config: Option<Arc<dyn Any>>,
        mut paged_attn_metadata: Option<PagedAttentionMeta<'_>>,
        prompt_chunksize: Option<NonZeroUsize>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Box<dyn Iterator<Item = anyhow::Result<InputProcessorOutput>>> {
        if is_xlora {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Cannot make inputs for X-LoRA vision model.",
            ))));
        }
        if no_kv_cache {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Vision model must have kv cache.",
            ))));
        }
        // TODO(EricLBuehler): support this? Would require some handling of image tokens.
        if prompt_chunksize.is_some() {
            warn!("`prompt_chunksize` is set. Phi 4 multimodal does not support prompt batching.");
        }
        let Some(tokenizer) = tokenizer else {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Phi4MMInputsProcessor requires a specified tokenizer.",
            ))));
        };

        let config = other_config
            .clone()
            .expect("Need a PreProcessorConfig config.");
        let config: &PreProcessorConfig = config.downcast_ref().expect("Downcast failed.");

        let has_images = input_seqs.iter().all(|seq| seq.has_images());

        let (pixel_values, pixel_attention_mask, image_sizes, num_img_tokens) = if has_images {
            let mut pixel_values_accum = Vec::new();
            let mut pixel_attention_masks_accum = Vec::new();
            let mut image_sizes_accum = Vec::new();
            let mut num_img_tokens_accum = Vec::new();
            for seq in input_seqs.iter_mut() {
                let imgs = seq
                    .take_images()
                    .expect("Need to have images by this point.");
                let PreprocessedImages {
                    pixel_values,
                    pixel_attention_mask,
                    image_sizes: _,
                    num_img_tokens,
                    aspect_ratio_ids: _,
                    aspect_ratio_mask: _,
                    num_tiles: _,
                    image_grid_thw: _,
                    video_grid_thw: _,
                    rows: _,
                    cols: _,
                    pixel_values_list: _,
                    tgt_sizes: _,
                    image_sizes_all,
                    num_crops: _,
                } = self
                    .preprocess(
                        imgs,
                        vec![],
                        config,
                        device,
                        (usize::MAX, usize::MAX), // Don't use it here...
                    )
                    .expect("Preprocessor failed");
                let image_sizes = image_sizes_all.unwrap();
                let pixel_attention_mask = pixel_attention_mask.unwrap();
                pixel_values_accum.push(pixel_values);
                pixel_attention_masks_accum.push(pixel_attention_mask);
                // Using extend on purpose
                image_sizes_accum.extend(image_sizes);
                num_img_tokens_accum.push(num_img_tokens.unwrap());
            }
            (
                Some(Tensor::cat(&pixel_values_accum, 0).unwrap()),
                Some(Tensor::cat(&pixel_attention_masks_accum, 0).unwrap()),
                Some(image_sizes_accum),
                Some(num_img_tokens_accum),
            )
        } else {
            return Box::new(
                text_models_inputs_processor::TextInputsProcessor
                    .process_inputs(
                        Some(tokenizer),
                        input_seqs,
                        is_prompt,
                        is_xlora,
                        device,
                        no_kv_cache,
                        last_n_context_len,
                        return_raw_logits,
                        other_config,
                        paged_attn_metadata,
                        None, // TODO
                        mapper,
                    )
                    .map(|metadata| {
                        let InputProcessorOutput {
                            inputs,
                            seq_indices,
                        } = metadata?;

                        let text_models_inputs_processor::ModelInputs {
                            input_ids,
                            input_ids_full: _,
                            seqlen_offsets,
                            seqlen_offsets_full: _,
                            context_lens,
                            position_ids,
                            paged_attn_meta,
                            flash_meta,
                            flash_meta_full: _,
                        } = *inputs
                            .downcast::<text_models_inputs_processor::ModelInputs>()
                            .expect("Downcast failed.");

                        let inputs: Box<dyn Any> = Box::new(ModelInputs {
                            input_ids,
                            seqlen_offsets,
                            context_lens,
                            position_ids,
                            pixel_values: None,
                            model_specific_args: Box::new(Phi4MMVisionSpecificArgs {
                                image_sizes: None,
                                image_attention_mask: None,
                                input_image_embeds: None,
                            }),
                            paged_attn_meta,
                            flash_meta,
                        });
                        Ok(InputProcessorOutput {
                            inputs,
                            seq_indices,
                        })
                    }),
            );
        };

        let detokenized = tokenizer
            .decode_batch(
                &input_seqs
                    .iter()
                    .map(|seq| seq.get_toks())
                    .collect::<Vec<_>>(),
                false,
            )
            .expect("Decode failed");

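        // Normalize the user-facing `<|image_N|>` placeholders to the single compatible
        // image special token before re-tokenizing each prompt.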
        let img_token_pattern = Regex::new(COMPATIBLE_IMAGE_SPECIAL_TOKEN_PATTERN).unwrap();

        let mut toks = Vec::new();

        for (mut detokenized, (seq, num_img_tokens)) in detokenized
            .into_iter()
            .zip(input_seqs.iter_mut().zip(num_img_tokens.unwrap()))
        {
            detokenized = img_token_pattern
                .replace_all(&detokenized, IMAGE_SPECIAL_TOKEN)
                .to_string();

            let has_changed_prompt = seq.has_changed_prompt;
            if !has_changed_prompt {
                seq.set_toks_and_reallocate(
                    tokenizer
                        .encode_fast(detokenized.clone(), false)
                        .expect("Encode failed")
                        .get_ids()
                        .to_vec(),
                    paged_attn_metadata.as_mut(),
                );

                seq.set_initial_prompt(detokenized);
            }

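            // Expand each single image special token into `num_img_tokens` copies so the
            // image embeddings can later be scattered into exactly those positions.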
            let mut i = 0;
            let mut image_token_count_iter = num_img_tokens.iter();
            while i < seq.get_toks().len() {
                let token_id = seq.get_toks()[i];
                let token_count = if token_id == IMAGE_SPECIAL_TOKEN_ID as u32 {
                    image_token_count_iter.next().unwrap()
                } else {
                    i += 1;
                    continue;
                };

                let mut new_ids = seq.get_toks()[..i].to_vec();
                new_ids.extend(vec![token_id; *token_count]);
                new_ids.extend(seq.get_toks()[i + 1..].to_vec());
                if !has_changed_prompt {
                    seq.set_toks_and_reallocate(new_ids, paged_attn_metadata.as_mut());
                }
                i += token_count;
            }
            if !has_changed_prompt {
                seq.has_changed_prompt = true;
            }
            toks.push(seq.get_toks().to_vec());
        }

        let iter = if is_prompt {
            get_prompt_input(
                toks,
                input_seqs,
                device,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                None, // TODO: evaluate if it is possible to batch this
                mapper,
            )
        } else {
            get_completion_input(
                toks,
                input_seqs,
                device,
                no_kv_cache,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                None, // TODO: evaluate if it is possible to batch this
                mapper,
            )
        };

        Box::new(iter.into_iter().map(move |metadata| {
            let pixel_values = pixel_values.clone();
            let pixel_attention_mask = pixel_attention_mask.clone();
            let text_models_inputs_processor::InnerInputProcessorOutput {
                inputs:
                    text_models_inputs_processor::InputMetadata {
                        input,
                        positions,
                        context_lens,
                        position_ids,
                        paged_attn_meta,
                        flash_meta,
                    },
                seq_indices,
            } = metadata?;
            let inputs: Box<dyn Any> = Box::new(ModelInputs {
                input_ids: input,
                seqlen_offsets: positions,
                context_lens,
                position_ids,
                pixel_values: pixel_values.clone(),
                model_specific_args: Box::new(Phi4MMVisionSpecificArgs {
                    image_sizes: image_sizes.clone(),
                    image_attention_mask: pixel_attention_mask,
                    input_image_embeds: pixel_values,
                }),
                paged_attn_meta,
                flash_meta,
            });
            Ok(InputProcessorOutput {
                inputs,
                seq_indices,
            })
        }))
    }
}

impl Phi4MMInputsProcessor {
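    /// Pad `image` by the given number of pixels on each side, filling the border with
    /// `pad_color` and copying the original image in at the `(left, top)` offset.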
    fn pad_image(
        image: &DynamicImage,
        top: u32,
        bottom: u32,
        left: u32,
        right: u32,
        pad_color: Rgba<u8>,
    ) -> DynamicImage {
        // Calculate the new dimensions
        let new_width = image.width() + left + right;
        let new_height = image.height() + top + bottom;

        // Create a new image with the new dimensions and fill it with the pad color
        let mut new_image = DynamicImage::new_rgb8(new_width, new_height);
        for x in 0..new_width {
            for y in 0..new_height {
                new_image.put_pixel(x, y, pad_color);
            }
        }

        // Paste the original image into the new canvas at the (left, top) offset
        new_image
            .copy_from(image, left, top)
            .expect("Failed to copy image");

        new_image
    }

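    /// Enumerate every (width, height) crop grid whose tile count lies in
    /// `[min_num, max_num]`, sorted by total number of tiles.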
    fn compute_target_ratios(min_num: u32, max_num: u32) -> Vec<(u32, u32)> {
        let mut ratios: HashSet<(u32, u32)> = HashSet::new();
        for n in min_num..=max_num {
            for i in 1..=n {
                for j in 1..=n {
                    if i * j >= min_num && i * j <= max_num {
                        ratios.insert((i, j));
                    }
                }
            }
        }
        let mut sorted_ratios: Vec<(u32, u32)> = ratios.into_iter().collect();
        sorted_ratios.sort_by_key(|&(i, j)| i * j);
        sorted_ratios
    }

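    /// Pick the candidate crop grid whose aspect ratio is closest to the image's aspect
    /// ratio; on ties, prefer the later (larger) grid when the image area is large enough.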
    fn find_closest_aspect_ratio(
        aspect_ratio: f64,
        target_ratios: Vec<(u32, u32)>,
        width: u32,
        height: u32,
        image_size: usize,
    ) -> (u32, u32) {
        let mut best_ratio_diff = f64::INFINITY;
        let mut best_ratio = (1, 1);
        let area = width * height;
        for ratio in target_ratios {
            let target_aspect_ratio = ratio.0 as f64 / ratio.1 as f64;
            let ratio_diff = (aspect_ratio - target_aspect_ratio).abs();
            if ratio_diff < best_ratio_diff {
                best_ratio_diff = ratio_diff;
                best_ratio = ratio;
            } else if ratio_diff == best_ratio_diff
                && area as f64 > 0.5 * image_size as f64 * ratio.0 as f64 * ratio.1 as f64
            {
                best_ratio = ratio;
            }
        }
        best_ratio
    }

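    /// Resize and pad `image` so that both sides become multiples of `image_size`
    /// (using a crop grid of at most `max_num` tiles), returning the padded image and a
    /// patch-level attention mask in which 14x14 patches covering padding are zeroed.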
    fn dynamic_preprocess(
        &self,
        mut image: DynamicImage,
        min_num: usize,
        max_num: usize,
        image_size: usize,
        mask_size: usize,
        device: &Device,
    ) -> Result<(DynamicImage, Tensor)> {
        let (orig_w, orig_h) = image.dimensions();

        let w_crop_num = (orig_w as f64 / image_size as f64).ceil();
        let h_crop_num = (orig_h as f64 / image_size as f64).ceil();
        let (target_aspect_ratio, target_width, target_height) =
            if w_crop_num * h_crop_num > max_num as f64 {
                let aspect_ratio = orig_w as f64 / orig_h as f64;
                let target_ratios = Self::compute_target_ratios(min_num as u32, max_num as u32);

                let target_aspect_ratio = Self::find_closest_aspect_ratio(
                    aspect_ratio,
                    target_ratios,
                    orig_w,
                    orig_h,
                    image_size,
                );

                let target_width = image_size * target_aspect_ratio.0 as usize;
                let target_height = image_size * target_aspect_ratio.1 as usize;

                (
                    (target_aspect_ratio.0 as f64, target_aspect_ratio.1 as f64),
                    target_width,
                    target_height,
                )
            } else {
                let target_width = (image_size as f64 * w_crop_num) as usize;
                let target_height = (image_size as f64 * h_crop_num) as usize;
                let target_aspect_ratio = (w_crop_num, h_crop_num);

                (target_aspect_ratio, target_width, target_height)
            };

        let ratio_width = target_width as f64 / orig_w as f64;
        let ratio_height = target_height as f64 / orig_h as f64;
        let (new_size, padding_width, padding_height) = if ratio_width < ratio_height {
            (
                (target_width, (orig_h as f64 * ratio_width) as usize),
                0_usize,
                target_height - (orig_h as f64 * ratio_width) as usize,
            )
        } else {
            (
                ((orig_w as f64 * ratio_height) as usize, target_height),
                target_width - (orig_w as f64 * ratio_height) as usize,
                0_usize,
            )
        };

        let mut attention_mask = Tensor::ones(
            (
                (mask_size as f64 * target_aspect_ratio.1) as usize,
                (mask_size as f64 * target_aspect_ratio.0) as usize,
            ),
            DType::U32,
            device,
        )?;
        if padding_width >= 14 {
            attention_mask = attention_mask.slice_assign(
                &[&.., &(attention_mask.dim(1)? - padding_width / 14..)],
                &Tensor::zeros(
                    (attention_mask.dim(0)?, padding_width / 14),
                    DType::U32,
                    device,
                )?,
            )?;
        }
        if padding_height >= 14 {
            attention_mask = attention_mask.slice_assign(
                &[&(attention_mask.dim(0)? - padding_height / 14..), &..],
                &Tensor::zeros(
                    (padding_height / 14, attention_mask.dim(1)?),
                    DType::U32,
                    device,
                )?,
            )?;
        }

        image = image.resize_exact(new_size.0 as u32, new_size.1 as u32, FilterType::Nearest);
        // Pad on the right/bottom so that the padded region lines up with the zeroed
        // patches of the attention mask above.
        image = Self::pad_image(
            &image,
            0,
            padding_height as u32,
            0,
            padding_width as u32,
            Rgba([255u8, 255, 255, 255]),
        );

        Ok((image, attention_mask))
    }
}

impl ImagePreProcessor for Phi4MMInputsProcessor {
    #[allow(clippy::excessive_precision)]
    const DEFAULT_MEAN: [f64; 3] = [0.48145466, 0.4578275, 0.40821073];
    #[allow(clippy::excessive_precision)]
    const DEFAULT_STD: [f64; 3] = [0.26862954, 0.26130258, 0.27577711];

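    /// Dynamic-HD preprocessing: all images are resized to the batch maximum, each is
    /// split into a grid of 448x448 local crops plus one downscaled global crop, pixel
    /// values are normalized, and the per-image placeholder token count is derived from
    /// the downsampled attention mask.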
    fn preprocess(
        &self,
        mut images: Vec<DynamicImage>,
        videos: Vec<Vec<DynamicImage>>,
        config: &PreProcessorConfig,
        device: &Device,
        (_, _): (usize, usize),
    ) -> Result<PreprocessedImages> {
        // If no images, will not call this.
        assert!(!images.is_empty());
        assert!(videos.is_empty());

        // If >1 images, resize them all to the largest size, potentially destroying aspect ratio
        let mut max_size = None;
        for image in images.iter() {
            if max_size.is_none() {
                max_size = Some((image.dimensions().0 as usize, image.dimensions().1 as usize))
            } else if max_size.is_some_and(|(x, _)| image.dimensions().0 as usize > x) {
                max_size = Some((image.dimensions().0 as usize, max_size.unwrap().1));
            } else if max_size.is_some_and(|(_, y)| image.dimensions().1 as usize > y) {
                max_size = Some((max_size.unwrap().0, image.dimensions().1 as usize));
            }
        }
        // `max_size` is (width, height), matching `DynamicImage::dimensions`.
        let (max_w, max_h) = max_size.unwrap();
        for image in images.iter_mut() {
            *image = image.resize_exact(max_w as u32, max_h as u32, FilterType::Nearest);
        }

        let mut image_sizes = Vec::new();
        let mut padded_images = Vec::new();
        let mut padded_masks = Vec::new();
        let mut num_img_tokens = Vec::new();
        for mut image in images {
            // Convert to rgb, default to true
            if config.do_convert_rgb.unwrap_or(true) {
                image = DynamicImage::ImageRgb8(image.to_rgb8());
            }

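            // Note: this processor normalizes with mean/std 0.5 rather than the CLIP
            // statistics in `DEFAULT_MEAN`/`DEFAULT_STD` above, which are not used here.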
            let transforms = Transforms {
                input: &ToTensor,
                inner_transforms: &[&Normalize {
                    mean: vec![0.5, 0.5, 0.5],
                    std: vec![0.5, 0.5, 0.5],
                }],
            };
            // Dynamic HD
            let base_resolution = DYHD_BASE_RESOLUTION;
            // Covers both the 384 and 448 base resolutions (patch size 14)
            let mask_resolution = base_resolution / 14;
            let min_num = 1;

            let (elem, attention_mask) = self.dynamic_preprocess(
                image,
                min_num,
                config.dynamic_hd.unwrap(),
                base_resolution,
                mask_resolution,
                device,
            )?;

            let hd_image = elem.apply(transforms, device)?;
            let (img_h, img_w) = (hd_image.dim(1)?, hd_image.dim(2)?);
            let (mask_h, mask_w) = (attention_mask.dim(0)?, attention_mask.dim(1)?);

            // Resize with bicubic interpolation
            let global_image = hd_image
                .unsqueeze(0)?
                .interpolate2d(base_resolution, base_resolution)?;
            let global_attention_mask =
                Tensor::ones((1, mask_resolution, mask_resolution), DType::U32, device)?;

            let hd_image_reshape = hd_image
                .reshape((
                    1,
                    3,
                    (img_h as f32 / base_resolution as f32) as usize,
                    base_resolution,
                    (img_w as f32 / base_resolution as f32) as usize,
                    base_resolution,
                ))?
                .permute((0, 2, 4, 1, 3, 5))?
                .reshape(((), 3, base_resolution, base_resolution))?;

            let attention_mask_reshape = attention_mask
                .reshape((
                    1,
                    (mask_h as f32 / mask_resolution as f32) as usize,
                    mask_resolution,
                    (mask_w as f32 / mask_resolution as f32) as usize,
                    mask_resolution,
                ))?
                .permute((0, 1, 3, 2, 4))?
                .reshape(((), mask_resolution, mask_resolution))?;

            let downsample_attention_mask = {
                let h_indices =
                    Tensor::arange_step(0, attention_mask_reshape.dim(1)? as u32, 2, device)?;
                let w_indices =
                    Tensor::arange_step(0, attention_mask_reshape.dim(2)? as u32, 2, device)?;
                let selected = attention_mask_reshape
                    .index_select(&h_indices, 1)?
                    .index_select(&w_indices, 2)?;

                let mask = selected
                    .reshape((
                        1,
                        mask_h / mask_resolution,
                        mask_w / mask_resolution,
                        mask_resolution / 2 + mask_resolution % 2,
                        mask_resolution / 2 + mask_resolution % 2,
                    ))?
                    .permute((0, 1, 3, 2, 4))?;
                mask.reshape((mask.dim(1)? * mask.dim(2)?, mask.dim(3)? * mask.dim(4)?))?
            };

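            // Per-image token count: 256 features + 16 row separators + 1 global/local
            // separator for the global crop, plus one token per kept patch and one per
            // kept row of the downsampled local-crop mask (this appears to mirror the
            // reference Phi-4 multimodal processor's formula).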
            let img_tokens = 256
                + 1
                + downsample_attention_mask.sum_all()?.to_scalar::<u32>()? as usize
                + downsample_attention_mask
                    .i((.., 0))?
                    .sum_all()?
                    .to_scalar::<u32>()? as usize
                + 16;

            let hd_image_reshape = Tensor::cat(&[global_image, hd_image_reshape], 0)?;
            let hd_mask_reshape = Tensor::cat(&[global_attention_mask, attention_mask_reshape], 0)?;

            image_sizes.push((img_h as u32, img_w as u32));
            padded_images.push(hd_image_reshape);
            padded_masks.push(hd_mask_reshape);
            num_img_tokens.push(img_tokens);
        }
        Ok(PreprocessedImages {
            pixel_values: Tensor::stack(&padded_images, 0)?,
            pixel_attention_mask: Some(Tensor::stack(&padded_masks, 0)?),
            image_sizes: None,
            num_img_tokens: Some(num_img_tokens),
            aspect_ratio_ids: None,
            aspect_ratio_mask: None,
            num_tiles: None,
            image_grid_thw: None,
            video_grid_thw: None,
            rows: None,
            cols: None,
            pixel_values_list: None,
            tgt_sizes: None,
            image_sizes_all: Some(image_sizes),
            num_crops: None,
        })
    }
}