mistralrs_core/vision_models/phi3/
phi3_inputs_processor.rs

#![allow(clippy::cast_precision_loss, clippy::cast_possible_truncation)]

use std::{any::Any, num::NonZeroUsize, sync::Arc};

use candle_core::{Device, Result, Tensor};
use image::{imageops::FilterType, DynamicImage, GenericImage, GenericImageView, Rgba};
use itertools::Itertools;
use mistralrs_vision::{ApplyTransforms, Normalize, ToTensor, Transforms};
use regex_automata::meta::Regex;
use tokenizers::Tokenizer;
use tracing::warn;

use crate::{
    device_map::DeviceMapper,
    pipeline::{
        text_models_inputs_processor::{
            self, get_completion_input, get_prompt_input, PagedAttentionMeta,
        },
        InputProcessorOutput, InputsProcessor, InputsProcessorType, MessagesAction, Processor,
        ProcessorCreator,
    },
    sequence::Sequence,
};

use crate::vision_models::{
    image_processor::{ImagePreProcessor, PreprocessedImages},
    phi3::Phi3VisionSpecificArgs,
    preprocessor_config::PreProcessorConfig,
    processor_config::ProcessorConfig,
    ModelInputs,
};
/// Inputs processor for Phi 3 vision: splits prompts on `<|image_N|>` tags and
/// interleaves image-token placeholders with the surrounding text tokens.
pub struct Phi3InputsProcessor {
    image_tag_splitter: Regex,
}
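
/// Processor for Phi 3 vision models: owns the shared [`Phi3InputsProcessor`] and
/// exposes it through the [`Processor`] trait. A minimal construction sketch (the
/// `preprocessor_config` value here is hypothetical, not defined in this file):
///
/// ```ignore
/// let processor = Phi3Processor::new_processor(None, preprocessor_config);
/// let inputs_processor = processor.inputs_processor();
/// ```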
pub struct Phi3Processor {
    inputs_processor: Arc<Phi3InputsProcessor>,
}

impl ProcessorCreator for Phi3Processor {
    fn new_processor(
        _: Option<ProcessorConfig>,
        _: PreProcessorConfig,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(Self {
            inputs_processor: Arc::new(Phi3InputsProcessor {
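                // Matches Phi 3 image placeholder tags such as "<|image_1|>".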
                image_tag_splitter: Regex::new(r"<\|image_\d+\|>")
                    .expect("Failed to compile split regex."),
            }),
        })
    }
}

impl Processor for Phi3Processor {
    fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
        self.inputs_processor.clone()
    }
    fn get_special_tokens(&self) -> &[&'static str] {
        &[]
    }
    fn template_action(&self) -> MessagesAction {
        MessagesAction::FlattenOnlyText
    }
}

impl InputsProcessor for Phi3InputsProcessor {
    fn get_type(&self) -> InputsProcessorType {
        InputsProcessorType::Vision
    }
    fn process_inputs(
        &self,
        tokenizer: Option<Arc<Tokenizer>>,
        input_seqs: &mut [&mut Sequence],
        is_prompt: bool,
        is_xlora: bool,
        device: &Device,
        no_kv_cache: bool,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        other_config: Option<Arc<dyn Any>>,
        mut paged_attn_metadata: Option<PagedAttentionMeta<'_>>,
        prompt_chunksize: Option<NonZeroUsize>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Box<dyn Iterator<Item = anyhow::Result<InputProcessorOutput>>> {
        if is_xlora {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Cannot make inputs for X-LoRA vision model.",
            ))));
        }
        if no_kv_cache {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Vision model must have kv cache.",
            ))));
        }
        // TODO(EricLBuehler): support this? Would require some handling of image tokens.
        if prompt_chunksize.is_some() {
            warn!("`prompt_chunksize` is set. Phi 3 does not support prompt batching.");
        }
        let Some(tokenizer) = tokenizer else {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Phi3InputsProcessor requires a specified tokenizer.",
            ))));
        };

        let config = other_config
            .clone()
            .expect("Need a PreProcessorConfig.");
        let config: &PreProcessorConfig = config.downcast_ref().expect("Downcast failed.");

        let has_images = input_seqs.iter().all(|seq| seq.has_images());

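        // If every sequence carries images, run the vision path below; otherwise
        // delegate the whole batch to the plain-text inputs processor.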
        let (pixel_values, image_sizes, num_img_tokens, n_images) = if has_images {
            let mut pixel_values_accum = Vec::new();
            let mut image_sizes_accum = Vec::new();
            let mut num_img_tokens_accum = Vec::new();
            let mut n_images = Vec::new();
            for seq in input_seqs.iter_mut() {
                let imgs = seq
                    .take_images()
                    .expect("Need to have images by this point.");
                let imgs_len = imgs.len();
                n_images.push(imgs_len);
                let PreprocessedImages {
                    pixel_values,
                    pixel_attention_mask: _,
                    image_sizes,
                    num_img_tokens,
                    aspect_ratio_ids: _,
                    aspect_ratio_mask: _,
                    num_tiles: _,
                    image_grid_thw: _,
                    video_grid_thw: _,
                    rows: _,
                    cols: _,
                    pixel_values_list: _,
                    tgt_sizes: _,
                    image_sizes_all: _,
                    num_crops: _,
                } = self
                    .preprocess(
                        imgs,
                        vec![],
                        config,
                        device,
                        (usize::MAX, usize::MAX), // Don't use it here...
                    )
                    .expect("Preprocessor failed");
                let image_sizes = image_sizes.unwrap();
                pixel_values_accum.push(pixel_values);
                image_sizes_accum.push(image_sizes);
                num_img_tokens_accum.push(num_img_tokens.unwrap());
            }
            (
                Some(Tensor::cat(&pixel_values_accum, 0).unwrap()),
                Some(image_sizes_accum),
                Some(num_img_tokens_accum),
                n_images,
            )
        } else {
            return Box::new(
                text_models_inputs_processor::TextInputsProcessor
                    .process_inputs(
                        Some(tokenizer),
                        input_seqs,
                        is_prompt,
                        is_xlora,
                        device,
                        no_kv_cache,
                        last_n_context_len,
                        return_raw_logits,
                        other_config,
                        paged_attn_metadata,
                        None, // TODO
                        mapper,
                    )
                    .map(|metadata| {
                        let InputProcessorOutput {
                            inputs,
                            seq_indices,
                        } = metadata?;

                        let text_models_inputs_processor::ModelInputs {
                            input_ids,
                            input_ids_full: _,
                            seqlen_offsets,
                            seqlen_offsets_full: _,
                            context_lens,
                            position_ids,
                            paged_attn_meta,
                            flash_meta,
                            flash_meta_full: _,
                        } = *inputs
                            .downcast::<text_models_inputs_processor::ModelInputs>()
                            .expect("Downcast failed.");

                        let inputs: Box<dyn Any> = Box::new(ModelInputs {
                            input_ids,
                            seqlen_offsets,
                            context_lens,
                            position_ids,
                            pixel_values: None,
                            model_specific_args: Box::new(Phi3VisionSpecificArgs {
                                image_sizes: None,
                            }),
                            paged_attn_meta,
                            flash_meta,
                        });
                        Ok(InputProcessorOutput {
                            inputs,
                            seq_indices,
                        })
                    }),
            );
        };

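        // Detokenize each prompt so the `<|image_N|>` tags can be located as text,
        // then re-encode the text between the tags and splice in image placeholders.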
        let mut toks = Vec::new();
        let detokenized = tokenizer
            .decode_batch(
                &input_seqs
                    .iter()
                    .map(|seq| seq.get_toks())
                    .collect::<Vec<_>>(),
                false,
            )
            .expect("Decode failed");

        for (detokenized, (seq, (num_img_tokens, n_images))) in detokenized.into_iter().zip(
            input_seqs
                .iter_mut()
                .zip(num_img_tokens.unwrap().into_iter().zip(n_images)),
        ) {
            let splits = self
                .image_tag_splitter
                .split(&detokenized)
                .map(|span| &detokenized[span.range()])
                .collect::<Vec<_>>();
            let prompt_chunks = tokenizer
                .encode_batch(splits, true)
                .expect("Encode failed")
                .into_iter()
                .map(|enc| enc.get_ids().to_vec())
                .collect::<Vec<_>>();

            let image_tags = self.image_tag_splitter.find_iter(&detokenized);
            let image_ids = image_tags
                .into_iter()
                .map(|s| {
                    let s = &detokenized[s.range()];
                    s.split('|')
                        .nth(1)
                        .unwrap()
                        .split('_')
                        .nth(1)
                        .unwrap()
                        .parse::<u32>()
                        .expect("Failed to parse image id to u32")
                })
                .collect::<Vec<_>>();
            let unique_image_ids = image_ids
                .iter()
                .copied()
                .unique()
                .sorted()
                .collect::<Vec<_>>();
            // `image_ids` must start at 1 and be consecutive, e.g. [1, 2, 3]; [1, 4, 5] is invalid.
            if unique_image_ids != (1u32..unique_image_ids.len() as u32 + 1).collect::<Vec<_>>() {
                return Box::new(std::iter::once(Err(anyhow::Error::msg(
                    "`image_ids` must start from 1, and must be continuous, e.g. [1, 2, 3], cannot be [1, 4, 5].",
                ))));
            }
            // The number of distinct image ids must equal the number of images in the sequence.
            if unique_image_ids.len() != n_images {
                return Box::new(std::iter::once(Err(anyhow::Error::msg(
                    "Total images must be the same as the number of image tags.",
                ))));
            }

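            // Each image contributes a run of negative placeholder ids (-id), one per
            // expected image token; the model later locates image embeddings at these slots.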
            // `TryInto` + `unwrap_or` guards the id == 0 case, where `id - 1` would be
            // negative: it falls back to the last `num_img_tokens` entry instead of panicking.
            let image_ids_pad = image_ids
                .iter()
                .map(|id| {
                    [-(*id as i64)].repeat(
                        num_img_tokens[TryInto::<usize>::try_into(*id as isize - 1)
                            .unwrap_or(num_img_tokens.len() - 1)],
                    )
                })
                .collect::<Vec<_>>();

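            // Interleave text chunks with their image placeholder runs:
            // [text_0, img_1 tokens, text_1, img_2 tokens, ...].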
            let mut input_ids: Vec<i64> = Vec::new();
            for item in prompt_chunks
                .iter()
                .map(|x| x.iter().map(|x| *x as i64).collect::<Vec<_>>())
                .interleave(image_ids_pad)
            {
                input_ids.extend(item);
            }

            // NOTE(EricLBuehler): Casting to u32 is fine, we don't care about the other toks
            seq.set_toks_and_reallocate(
                input_ids
                    .iter()
                    .map(|x| if *x < 0 { 0u32 } else { *x as u32 })
                    .collect::<Vec<_>>(),
                paged_attn_metadata.as_mut(),
            );

            toks.push(input_ids);
        }

        let iter = if is_prompt {
            get_prompt_input(
                toks,
                input_seqs,
                device,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                None, // TODO: evaluate if it is possible to batch this
                mapper,
            )
        } else {
            get_completion_input(
                toks,
                input_seqs,
                device,
                no_kv_cache,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                None, // TODO: evaluate if it is possible to batch this
                mapper,
            )
        };

        Box::new(iter.into_iter().map(move |metadata| {
            let text_models_inputs_processor::InnerInputProcessorOutput {
                inputs:
                    text_models_inputs_processor::InputMetadata {
                        input,
                        positions,
                        context_lens,
                        position_ids,
                        paged_attn_meta,
                        flash_meta,
                    },
                seq_indices,
            } = metadata?;
            let inputs: Box<dyn Any> = Box::new(ModelInputs {
                input_ids: input,
                seqlen_offsets: positions,
                context_lens,
                position_ids,
                pixel_values: pixel_values.clone(),
                model_specific_args: Box::new(Phi3VisionSpecificArgs {
                    image_sizes: image_sizes.clone(),
                }),
                paged_attn_meta,
                flash_meta,
            });
            Ok(InputProcessorOutput {
                inputs,
                seq_indices,
            })
        }))
    }
}

impl Phi3InputsProcessor {
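    /// Pad `image` with `pad_color` by the given amounts on each side, copying the
    /// original content in at offset (`left`, `top`).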
    fn pad_image(
        image: &DynamicImage,
        top: u32,
        bottom: u32,
        left: u32,
        right: u32,
        pad_color: Rgba<u8>,
    ) -> DynamicImage {
        // Calculate the new dimensions
        let new_width = image.width() + left + right;
        let new_height = image.height() + top + bottom;

        // Create a new image with the new dimensions and fill it with the pad color
        let mut new_image = DynamicImage::new_rgb8(new_width, new_height);
        for x in 0..new_width {
            for y in 0..new_height {
                new_image.put_pixel(x, y, pad_color);
            }
        }

        // Copy the original image into the padded canvas at the requested offset
        new_image
            .copy_from(image, left, top)
            .expect("Failed to copy image");

        new_image
    }

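    /// Pad the image height with white up to the next multiple of 336 px, split
    /// roughly evenly between top and bottom; the width is left unchanged.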
    fn padding_336(img: &DynamicImage) -> DynamicImage {
        let (_width, height) = img.dimensions();
        let tar = ((height as f64 / 336.0).ceil() * 336.0) as u32;
        let top_padding = ((tar as f64 - height as f64 + 1.) / 2.) as u32;
        let bottom_padding = tar - height - top_padding;
        let left_padding = 0u32;
        let right_padding = 0u32;
        Self::pad_image(
            img,
            top_padding,
            bottom_padding,
            left_padding,
            right_padding,
            Rgba([255u8, 255, 255, 255]),
        )
    }

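    /// Phi 3 "HD transform": rotate portrait images to landscape, resize so the
    /// image spans at most `hd_num` 336x336 tiles while keeping the aspect ratio,
    /// pad the height to a multiple of 336, then undo the rotation.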
    fn hd_transform(img: &DynamicImage, hd_num: usize) -> DynamicImage {
        let (mut width, mut height) = img.dimensions();
        let mut transposed = false;

        let img = if width < height {
            let img = img.rotate90();
            transposed = true;
            width = img.width();
            height = img.height();
            img
        } else {
            // NOTE: Don't love the clone.
            img.clone()
        };

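        // Find the largest integer `scale` (tile columns of 336 px) such that
        // scale * ceil(scale / ratio) tiles still fit within the `hd_num` budget.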
        let ratio = width as f64 / height as f64;
        let mut scale = 1.0;
        while (scale * (scale / ratio).ceil()) <= hd_num as f64 {
            scale += 1.0;
        }
        scale -= 1.0;

        let new_width = (scale * 336.0) as u32;
        let new_height = (new_width as f64 / ratio) as u32;

        let resized_img = img.resize_exact(new_width, new_height, FilterType::Nearest);
        let padded_img = Self::padding_336(&resized_img);

        if transposed {
            return padded_img.rotate270();
        }

        padded_img
    }
}

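/// Zero-pad the crop dimension of `image` (shape `(crops, 3, h, w)`) up to `max_crops`
/// so that every image in a batch carries the same number of crops.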
fn pad_to_max_num_crops_tensor(image: &Tensor, max_crops: usize) -> Result<Tensor> {
    let (b, _, h, w) = image.dims4()?;
    if b < max_crops {
        let pad = Tensor::zeros((max_crops - b, 3, h, w), image.dtype(), image.device())?;
        Tensor::cat(&[image, &pad], 0)
    } else {
        Ok(image.clone())
    }
}

impl ImagePreProcessor for Phi3InputsProcessor {
    #[allow(clippy::excessive_precision)]
    const DEFAULT_MEAN: [f64; 3] = [0.48145466, 0.4578275, 0.40821073];
    #[allow(clippy::excessive_precision)]
    const DEFAULT_STD: [f64; 3] = [0.26862954, 0.26130258, 0.27577711];

    fn preprocess(
        &self,
        mut images: Vec<DynamicImage>,
        videos: Vec<Vec<DynamicImage>>,
        config: &PreProcessorConfig,
        device: &Device,
        (_, _): (usize, usize),
    ) -> Result<PreprocessedImages> {
        // Only called when there is at least one image; videos are unsupported here.
        assert!(!images.is_empty());
        assert!(videos.is_empty());

        let mut image_sizes = Vec::new();
        let mut padded_images = Vec::new();
        let mut num_img_tokens = Vec::new();
        // If >1 images, resize them all to the largest size, potentially destroying aspect ratio
        let mut max_size = None;
        for image in images.iter() {
            // `dimensions()` is (width, height); take the elementwise maximum over all images.
            let (w, h) = image.dimensions();
            max_size = match max_size {
                None => Some((w as usize, h as usize)),
                Some((max_w, max_h)) => Some((max_w.max(w as usize), max_h.max(h as usize))),
            };
        }
        let (max_w, max_h) = max_size.unwrap();
        for image in images.iter_mut() {
            *image = image.resize_exact(max_w as u32, max_h as u32, FilterType::Nearest);
        }

        for image in images.iter_mut() {
            // Convert to rgb, default to true
            if config.do_convert_rgb.unwrap_or(true) {
                *image = DynamicImage::ImageRgb8(image.to_rgb8());
            }

            let hd_image = Self::hd_transform(image, config.num_crops.expect("Need `num_crops`"));

            // Both hd and global have a normalization
            // Transforms for the HD image
            let transforms_hd = Transforms {
                input: &ToTensor,
                inner_transforms: &[&Normalize {
                    mean: config.image_mean.unwrap_or(Self::DEFAULT_MEAN).to_vec(),
                    std: config.image_std.unwrap_or(Self::DEFAULT_STD).to_vec(),
                }],
            };

            // (3,h,w)
            let hd_image = hd_image.apply(transforms_hd, device)?;

            // Resize with bicubic interpolation
            // (3,336,336)
            let global_image = hd_image.unsqueeze(0)?.interpolate2d(336, 336)?;

            let (_, h, w) = hd_image.dims3()?;
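            // Token budget per image, reading off the formula below: 144 tokens for
            // each 336x336 crop plus 144 for the global crop, 12 tokens per crop row
            // (plus one extra row), and 1 trailing token.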
            let num_image_tokens = ((h as f32 / 336. * w as f32 / 336. + 1.) * 144.
                + ((h as f32 / 336.) + 1.) * 12.
                + 1.) as usize;

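            // Split the normalized HD image into 336x336 crops:
            // (3, h, w) -> (1, 3, h/336, 336, w/336, 336) -> (num_crops, 3, 336, 336),
            // then prepend the global crop.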
            let hd_image_reshape = hd_image
                .reshape((
                    1,
                    3,
                    (h as f32 / 336.) as usize,
                    336,
                    (w as f32 / 336.) as usize,
                    336,
                ))?
                .permute((0, 2, 4, 1, 3, 5))?
                .reshape(((), 3, 336, 336))?;
            let hd_image_reshape = Tensor::cat(&[global_image, hd_image_reshape], 0)?;
            let image_transformed = pad_to_max_num_crops_tensor(
                &hd_image_reshape,
                config.num_crops.expect("Need `num_crops`") + 1,
            )?;
            image_sizes.push((h, w));
            padded_images.push(image_transformed);
            num_img_tokens.push(num_image_tokens);
        }
        if padded_images.len() > 1 {
            candle_core::bail!("Can only process one image per batch");
        }
        let image_sizes = image_sizes[0];

        Ok(PreprocessedImages {
            pixel_values: Tensor::stack(&padded_images, 0)?,
            image_sizes: Some((image_sizes.0, image_sizes.1)),
            pixel_attention_mask: None,
            num_img_tokens: Some(num_img_tokens),
            aspect_ratio_ids: None,
            aspect_ratio_mask: None,
            num_tiles: None,
            image_grid_thw: None,
            video_grid_thw: None,
            rows: None,
            cols: None,
            pixel_values_list: None,
            tgt_sizes: None,
            image_sizes_all: None,
            num_crops: None,
        })
    }
}