mistralrs_core/vision_models/phi3/phi3_inputs_processor.rs

#![allow(clippy::cast_precision_loss, clippy::cast_possible_truncation)]

use std::{any::Any, num::NonZeroUsize, sync::Arc};

use candle_core::{Device, Result, Tensor};
use image::{imageops::FilterType, DynamicImage, GenericImage, GenericImageView, Rgba};
use itertools::Itertools;
use mistralrs_vision::{ApplyTransforms, Normalize, ToTensor, Transforms};
use regex_automata::meta::Regex;
use tokenizers::Tokenizer;
use tracing::warn;

use crate::{
    device_map::DeviceMapper,
    pipeline::{
        text_models_inputs_processor::{
            self, get_completion_input, get_prompt_input, PagedAttentionMeta,
        },
        InputProcessorOutput, InputsProcessor, InputsProcessorType, MessagesAction, Processor,
        ProcessorCreator,
    },
    sequence::Sequence,
};

use crate::vision_models::{
    image_processor::{ImagePreProcessor, PreprocessedImages},
    phi3::Phi3VisionSpecificArgs,
    preprocessor_config::PreProcessorConfig,
    processor_config::ProcessorConfig,
    ModelInputs,
};

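/// Input processor for Phi 3 vision prompts. Prompts embed tags of the form
/// `<|image_N|>` (matched by `image_tag_splitter`); e.g. `"<|image_1|>\nWhat is
/// shown?"` splits into the tag and the trailing text, and the tag is replaced
/// by image 1's placeholder tokens in `process_inputs`.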
// Input processor
pub struct Phi3InputsProcessor {
    image_tag_splitter: Regex,
}
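/// `Processor` implementation that hands out the shared `Phi3InputsProcessor`
/// and keeps only the text content of messages when applying the chat template.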
// Processor
pub struct Phi3Processor {
    inputs_processor: Arc<Phi3InputsProcessor>,
}

impl ProcessorCreator for Phi3Processor {
    fn new_processor(
        _: Option<ProcessorConfig>,
        _: PreProcessorConfig,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(Self {
            inputs_processor: Arc::new(Phi3InputsProcessor {
                image_tag_splitter: Regex::new(r"<\|image_\d+\|>")
                    .expect("Failed to compile split regex."),
            }),
        })
    }
}

impl Processor for Phi3Processor {
    fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
        self.inputs_processor.clone()
    }
    fn get_special_tokens(&self) -> &[&'static str] {
        &[]
    }
    fn template_action(&self) -> MessagesAction {
        MessagesAction::FlattenOnlyText
    }
}

impl InputsProcessor for Phi3InputsProcessor {
    fn get_type(&self) -> InputsProcessorType {
        InputsProcessorType::Vision
    }
    fn process_inputs(
        &self,
        tokenizer: Option<Arc<Tokenizer>>,
        input_seqs: &mut [&mut Sequence],
        is_prompt: bool,
        is_xlora: bool,
        device: &Device,
        no_kv_cache: bool,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        other_config: Option<Arc<dyn Any>>,
        mut paged_attn_metadata: Option<PagedAttentionMeta<'_>>,
        prompt_chunksize: Option<NonZeroUsize>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Box<dyn Iterator<Item = anyhow::Result<InputProcessorOutput>>> {
        if is_xlora {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Cannot make inputs for X-LoRA vision model.",
            ))));
        }
        if no_kv_cache {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Vision model must have kv cache.",
            ))));
        }
        // TODO(EricLBuehler): support this? Would require some handling of image tokens.
        if prompt_chunksize.is_some() {
            warn!("`prompt_chunksize` is set. Phi 3 does not support prompt batching.");
        }
        let Some(tokenizer) = tokenizer else {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Phi3InputsProcessor requires a specified tokenizer.",
            ))));
        };

        let config = other_config
            .clone()
            .expect("Need a PreProcessorConfig.");
        let config: &PreProcessorConfig = config.downcast_ref().expect("Downcast failed.");

        let has_images = input_seqs.iter().all(|seq| seq.has_images());

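        // Only take the vision path when every sequence in the batch still
        // carries images; otherwise defer to the text-only inputs processor.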
        let (pixel_values, image_sizes, num_img_tokens, n_images) = if has_images {
            let mut pixel_values_accum = Vec::new();
            let mut image_sizes_accum = Vec::new();
            let mut num_img_tokens_accum = Vec::new();
            let mut n_images = Vec::new();
            for seq in input_seqs.iter_mut() {
                let imgs = seq
                    .take_images()
                    .expect("Need to have images by this point.");
                let imgs_len = imgs.len();
                n_images.push(imgs_len);
                let PreprocessedImages {
                    pixel_values,
                    pixel_attention_mask: _,
                    image_sizes,
                    num_img_tokens,
                    aspect_ratio_ids: _,
                    aspect_ratio_mask: _,
                    num_tiles: _,
                    image_grid_thw: _,
                    video_grid_thw: _,
                    rows: _,
                    cols: _,
                    pixel_values_list: _,
                    tgt_sizes: _,
                    image_sizes_all: _,
                    num_crops: _,
                } = self
                    .preprocess(
                        imgs,
                        vec![],
                        config,
                        device,
                        (usize::MAX, usize::MAX), // Don't use it here...
                    )
                    .expect("Preprocessor failed");
                let image_sizes = image_sizes.unwrap();
                pixel_values_accum.push(pixel_values);
                image_sizes_accum.push(image_sizes);
                num_img_tokens_accum.push(num_img_tokens.unwrap());
            }
            (
                Some(Tensor::cat(&pixel_values_accum, 0).unwrap()),
                Some(image_sizes_accum),
                Some(num_img_tokens_accum),
                n_images,
            )
        } else {
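            // No images remain: delegate to the text inputs processor and
            // repackage its output as vision `ModelInputs` with no pixel values.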
            return Box::new(
                text_models_inputs_processor::TextInputsProcessor
                    .process_inputs(
                        Some(tokenizer),
                        input_seqs,
                        is_prompt,
                        is_xlora,
                        device,
                        no_kv_cache,
                        last_n_context_len,
                        return_raw_logits,
                        other_config,
                        paged_attn_metadata,
                        None, // TODO
                        mapper,
                    )
                    .map(|metadata| {
                        let InputProcessorOutput {
                            inputs,
                            seq_indices,
                        } = metadata?;

                        let text_models_inputs_processor::ModelInputs {
                            input_ids,
                            input_ids_full: _,
                            seqlen_offsets,
                            seqlen_offsets_full: _,
                            context_lens,
                            position_ids,
                            paged_attn_meta,
                            flash_meta,
                            flash_meta_full: _,
                        } = *inputs
                            .downcast::<text_models_inputs_processor::ModelInputs>()
                            .expect("Downcast failed.");

                        let inputs: Box<dyn Any> = Box::new(ModelInputs {
                            input_ids,
                            seqlen_offsets,
                            context_lens,
                            position_ids,
                            pixel_values: None,
                            model_specific_args: Box::new(Phi3VisionSpecificArgs {
                                image_sizes: None,
                            }),
                            paged_attn_meta,
                            flash_meta,
                        });
                        Ok(InputProcessorOutput {
                            inputs,
                            seq_indices,
                        })
                    }),
            );
        };

        let mut toks = Vec::new();
        let detokenized = tokenizer
            .decode_batch(
                &input_seqs
                    .iter()
                    .map(|seq| seq.get_toks())
                    .collect::<Vec<_>>(),
                false,
            )
            .expect("Decode failed");

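        // Rebuild each prompt: split the detokenized text on `<|image_N|>` tags,
        // re-encode the text chunks, and interleave them with runs of `-N`
        // repeated `num_img_tokens[N - 1]` times; the negative ids mark where
        // image `N`'s embeddings are spliced in downstream.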
        for (detokenized, (seq, (num_img_tokens, n_images))) in detokenized.into_iter().zip(
            input_seqs
                .iter_mut()
                .zip(num_img_tokens.unwrap().into_iter().zip(n_images)),
        ) {
            let splits = self
                .image_tag_splitter
                .split(&detokenized)
                .map(|span| &detokenized[span.range()])
                .collect::<Vec<_>>();
            let prompt_chunks = tokenizer
                .encode_batch(splits, true)
                .expect("Encode failed")
                .into_iter()
                .map(|enc| enc.get_ids().to_vec())
                .collect::<Vec<_>>();

            let image_tags = self.image_tag_splitter.find_iter(&detokenized);
            let image_ids = image_tags
                .into_iter()
                .map(|s| {
                    let s = &detokenized[s.range()];
                    s.split('|')
                        .nth(1)
                        .unwrap()
                        .split('_')
                        .nth(1)
                        .unwrap()
                        .parse::<u32>()
                        .expect("Failed to parse image id to u32")
                })
                .collect::<Vec<_>>();
            let unique_image_ids = image_ids
                .iter()
                .copied()
                .unique()
                .sorted()
                .collect::<Vec<_>>();
            // `image_ids` must start from 1 and be continuous, e.g. [1, 2, 3]; [1, 4, 5] is invalid.
            if unique_image_ids != (1u32..unique_image_ids.len() as u32 + 1).collect::<Vec<_>>() {
                return Box::new(std::iter::once(Err(anyhow::Error::msg(
                    "`image_ids` must start from 1 and be continuous, e.g. [1, 2, 3]; [1, 4, 5] is invalid.",
                ))));
            }
            // The total image count must equal the number of image tags.
            if unique_image_ids.len() != n_images {
                return Box::new(std::iter::once(Err(anyhow::Error::msg(
                    "Total images must be the same as the number of image tags.",
                ))));
            }

            // `TryInto` + `unwrap_or` guard against underflow if an id were 0;
            // the validation above guarantees ids start at 1.
            let image_ids_pad = image_ids
                .iter()
                .map(|id| {
                    [-(*id as i64)].repeat(
                        num_img_tokens[TryInto::<usize>::try_into(*id as isize - 1)
                            .unwrap_or(num_img_tokens.len() - 1)],
                    )
                })
                .collect::<Vec<_>>();

            let mut input_ids: Vec<i64> = Vec::new();
            for item in prompt_chunks
                .iter()
                .map(|x| x.iter().map(|x| *x as i64).collect::<Vec<_>>())
                .interleave(image_ids_pad)
            {
                input_ids.extend(item);
            }

            let new_ids = input_ids
                .iter()
                .map(|x| if *x < 0 { 0u32 } else { *x as u32 })
                .collect::<Vec<_>>();
            if !seq.has_changed_prompt {
                let new_prompt = tokenizer.decode(&new_ids, false).unwrap();
                seq.set_initial_prompt(new_prompt);
                // NOTE(EricLBuehler): Casting to u32 is fine, we don't care about the other toks
                seq.set_toks_and_reallocate(new_ids, paged_attn_metadata.as_mut());
                seq.has_changed_prompt = true;
            }

            toks.push(input_ids);
        }

        let iter = if is_prompt {
            get_prompt_input(
                toks,
                input_seqs,
                device,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                None, // TODO: evaluate if it is possible to batch this
                mapper,
            )
        } else {
            get_completion_input(
                toks,
                input_seqs,
                device,
                no_kv_cache,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                None, // TODO: evaluate if it is possible to batch this
                mapper,
            )
        };

        Box::new(iter.into_iter().map(move |metadata| {
            let text_models_inputs_processor::InnerInputProcessorOutput {
                inputs:
                    text_models_inputs_processor::InputMetadata {
                        input,
                        positions,
                        context_lens,
                        position_ids,
                        paged_attn_meta,
                        flash_meta,
                    },
                seq_indices,
            } = metadata?;
            let inputs: Box<dyn Any> = Box::new(ModelInputs {
                input_ids: input,
                seqlen_offsets: positions,
                context_lens,
                position_ids,
                pixel_values: pixel_values.clone(),
                model_specific_args: Box::new(Phi3VisionSpecificArgs {
                    image_sizes: image_sizes.clone(),
                }),
                paged_attn_meta,
                flash_meta,
            });
            Ok(InputProcessorOutput {
                inputs,
                seq_indices,
            })
        }))
    }
}

impl Phi3InputsProcessor {
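    /// Pads `image` with `pad_color` by the given per-side amounts, pasting the
    /// original at offset `(left, top)` in the enlarged canvas.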
    fn pad_image(
        image: &DynamicImage,
        top: u32,
        bottom: u32,
        left: u32,
        right: u32,
        pad_color: Rgba<u8>,
    ) -> DynamicImage {
        // Calculate the new dimensions
        let new_width = image.width() + left + right;
        let new_height = image.height() + top + bottom;

        // Create a new image with the new dimensions and fill it with the pad color
        let mut new_image = DynamicImage::new_rgb8(new_width, new_height);
        for x in 0..new_width {
            for y in 0..new_height {
                new_image.put_pixel(x, y, pad_color);
            }
        }

        // Paste the original image into the padded canvas at (left, top)
        new_image
            .copy_from(image, left, top)
            .expect("Failed to copy image");

        new_image
    }

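    /// Pads the image height up to the next multiple of 336 with white pixels,
    /// split roughly evenly between top and bottom; the width is left as-is.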
    fn padding_336(img: &DynamicImage) -> DynamicImage {
        let (_width, height) = img.dimensions();
        let tar = ((height as f64 / 336.0).ceil() * 336.0) as u32;
        let top_padding = ((tar as f64 - height as f64 + 1.) / 2.) as u32;
        let bottom_padding = tar - height - top_padding;
        let left_padding = 0u32;
        let right_padding = 0u32;
        Self::pad_image(
            img,
            top_padding,
            bottom_padding,
            left_padding,
            right_padding,
            Rgba([255u8, 255, 255, 255]),
        )
    }

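    /// The Phi 3 vision HD transform: rotate portrait images to landscape,
    /// scale so the result tiles into at most `hd_num` 336x336 crops, pad the
    /// height to a multiple of 336, then undo the rotation.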
    fn hd_transform(img: &DynamicImage, hd_num: usize) -> DynamicImage {
        let (mut width, mut height) = img.dimensions();
        let mut transposed = false;

        let img = if width < height {
            let img = img.rotate90();
            transposed = true;
            width = img.width();
            height = img.height();
            img
        } else {
            // NOTE: Don't love the clone.
            img.clone()
        };

        let ratio = width as f64 / height as f64;
        let mut scale = 1.0;
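        // Find the largest integer `scale` such that a grid of
        // `scale x ceil(scale / ratio)` 336-pixel tiles stays within `hd_num`.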
        while (scale * (scale / ratio).ceil()) <= hd_num as f64 {
            scale += 1.0;
        }
        scale -= 1.0;

        let new_width = (scale * 336.0) as u32;
        let new_height = (new_width as f64 / ratio) as u32;

        let resized_img = img.resize_exact(new_width, new_height, FilterType::Nearest);
        let padded_img = Self::padding_336(&resized_img);

        if transposed {
            return padded_img.rotate270();
        }

        padded_img
    }
}

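/// Zero-pads the crop dimension (dim 0) up to `max_crops` so that every image
/// in a batch carries the same number of crops.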
fn pad_to_max_num_crops_tensor(image: &Tensor, max_crops: usize) -> Result<Tensor> {
    let (b, _, h, w) = image.dims4()?;
    if b < max_crops {
        let pad = Tensor::zeros((max_crops - b, 3, h, w), image.dtype(), image.device())?;
        Tensor::cat(&[image, &pad], 0)
    } else {
        Ok(image.clone())
    }
}

impl ImagePreProcessor for Phi3InputsProcessor {
    #[allow(clippy::excessive_precision)]
    const DEFAULT_MEAN: [f64; 3] = [0.48145466, 0.4578275, 0.40821073];
    #[allow(clippy::excessive_precision)]
    const DEFAULT_STD: [f64; 3] = [0.26862954, 0.26130258, 0.27577711];

    fn preprocess(
        &self,
        mut images: Vec<DynamicImage>,
        videos: Vec<Vec<DynamicImage>>,
        config: &PreProcessorConfig,
        device: &Device,
        (_, _): (usize, usize),
    ) -> Result<PreprocessedImages> {
        // If no images, will not call this.
        assert!(!images.is_empty());
        assert!(videos.is_empty());

        let mut image_sizes = Vec::new();
        let mut padded_images = Vec::new();
        let mut num_img_tokens = Vec::new();
        // If >1 images, resize them all to the largest dimensions, potentially
        // distorting aspect ratio. Track the per-axis maxima: `dimensions()`
        // returns (width, height).
        let mut max_size = None;
        for image in images.iter() {
            let (w, h) = image.dimensions();
            let (max_w, max_h) = max_size.unwrap_or((0usize, 0usize));
            max_size = Some((max_w.max(w as usize), max_h.max(h as usize)));
        }
        let (max_w, max_h) = max_size.unwrap();
        for image in images.iter_mut() {
            *image = image.resize_exact(max_w as u32, max_h as u32, FilterType::Nearest);
        }

        for image in images.iter_mut() {
            // Convert to RGB (defaults to true)
            if config.do_convert_rgb.unwrap_or(true) {
                *image = DynamicImage::ImageRgb8(image.to_rgb8());
            }

            let hd_image = Self::hd_transform(image, config.num_crops.expect("Need `num_crops`"));

            // Both the HD and global images are normalized; the global image is
            // derived below from the already-normalized HD tensor.
            // Transforms for the HD image
            let transforms_hd = Transforms {
                input: &ToTensor,
                inner_transforms: &[&Normalize {
                    mean: config.image_mean.unwrap_or(Self::DEFAULT_MEAN).to_vec(),
                    std: config.image_std.unwrap_or(Self::DEFAULT_STD).to_vec(),
                }],
            };

            // (3,h,w)
            let hd_image = hd_image.apply(transforms_hd, device)?;

            // Resize with bicubic interpolation
            // (3,336,336)
            let global_image = hd_image.unsqueeze(0)?.interpolate2d(336, 336)?;

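            // Token budget per image: 144 tokens for each 336x336 crop plus the
            // global crop, and 12 more per 336-pixel row plus one extra row for
            // the global crop (these per-row extras correspond to newline
            // separator embeddings in the Phi-3-vision scheme), plus 1 token.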
            let (_, h, w) = hd_image.dims3()?;
            let num_image_tokens = ((h as f32 / 336. * w as f32 / 336. + 1.) * 144.
                + ((h as f32 / 336.) + 1.) * 12.
                + 1.) as usize;

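            // Split the normalized HD image into 336x336 crops
            // ((1, 3, H, W) -> (n_crops, 3, 336, 336)) and prepend the global view.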
            let hd_image_reshape = hd_image
                .reshape((
                    1,
                    3,
                    (h as f32 / 336.) as usize,
                    336,
                    (w as f32 / 336.) as usize,
                    336,
                ))?
                .permute((0, 2, 4, 1, 3, 5))?
                .reshape(((), 3, 336, 336))?;
            let hd_image_reshape = Tensor::cat(&[global_image, hd_image_reshape], 0)?;
            let image_transformed = pad_to_max_num_crops_tensor(
                &hd_image_reshape,
                config.num_crops.expect("Need `num_crops`") + 1,
            )?;
            image_sizes.push((h, w));
            padded_images.push(image_transformed);
            num_img_tokens.push(num_image_tokens);
        }
        if padded_images.len() > 1 {
            candle_core::bail!("Can only process one image per batch");
        }
        let image_sizes = image_sizes[0];

        Ok(PreprocessedImages {
            pixel_values: Tensor::stack(&padded_images, 0)?,
            image_sizes: Some((image_sizes.0, image_sizes.1)),
            pixel_attention_mask: None,
            num_img_tokens: Some(num_img_tokens),
            aspect_ratio_ids: None,
            aspect_ratio_mask: None,
            num_tiles: None,
            image_grid_thw: None,
            video_grid_thw: None,
            rows: None,
            cols: None,
            pixel_values_list: None,
            tgt_sizes: None,
            image_sizes_all: None,
            num_crops: None,
        })
    }
}