mistralrs_core/vision_models/mistral3/inputs_processor.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::{any::Any, sync::Arc};

use candle_core::{Device, Result, Tensor};
use image::{imageops::FilterType, DynamicImage, GenericImageView};
use mistralrs_vision::{ApplyTransforms, Normalize, Rescale, ToTensorNoNorm, Transforms};
use tokenizers::Tokenizer;

use crate::{
    device_map::DeviceMapper,
    pipeline::{
        text_models_inputs_processor::{
            self, get_completion_input, get_prompt_input, PagedAttentionMeta,
        },
        InputProcessorOutput, InputsProcessor, InputsProcessorType, MessagesAction, Processor,
    },
    sequence::Sequence,
    vision_models::{
        image_processor::{ImagePreProcessor, PreprocessedImages},
        preprocessor_config::{PreProcessorConfig, ToFilter},
        processor_config::ProcessorConfig,
        ModelInputs,
    },
};

use super::Mistral3SpecificArgs;

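// Temporary stand-in used while expanding image tokens: the expansion itself
// contains `image_token`, so a direct replacement would re-match its own output.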
const PLACEHOLDER: &str = "<placeholder>";

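/// Inputs processor: expands each image token in a prompt into a grid of image
/// tokens and preprocesses the images into pixel values.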
struct Mistral3ImageProcessor {
    image_break_token: String,
    image_end_token: String,
    image_token: String,
    patch_size: usize,
    spatial_merge_size: usize,
}

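/// Processor for Mistral 3 vision inputs, configured from the model's
/// `ProcessorConfig`.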
pub struct Mistral3Processor {
    image_break_token: String,
    image_end_token: String,
    image_token: String,
    patch_size: usize,
    spatial_merge_size: usize,
}

impl Mistral3Processor {
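    /// Build the processor from the model's `ProcessorConfig`; panics if any of
    /// the image token or patch fields are missing.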
    pub fn new(processor_config: ProcessorConfig) -> Self {
        Self {
            image_break_token: processor_config.image_break_token.unwrap(),
            image_end_token: processor_config.image_end_token.unwrap(),
            image_token: processor_config.image_token.unwrap(),
            patch_size: processor_config.patch_size.unwrap(),
            spatial_merge_size: processor_config.spatial_merge_size.unwrap(),
        }
    }
}

impl Processor for Mistral3Processor {
    fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
        Arc::new(Mistral3ImageProcessor {
            image_break_token: self.image_break_token.clone(),
            image_end_token: self.image_end_token.clone(),
            image_token: self.image_token.clone(),
            patch_size: self.patch_size,
            spatial_merge_size: self.spatial_merge_size,
        })
    }

    fn get_special_tokens(&self) -> &[&'static str] {
        &[]
    }

    fn template_action(&self) -> MessagesAction {
        MessagesAction::Keep
    }
}

impl InputsProcessor for Mistral3ImageProcessor {
    fn get_type(&self) -> InputsProcessorType {
        InputsProcessorType::Vision
    }
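    /// Expand each sequence's image tokens, preprocess its images, and assemble
    /// `ModelInputs` for the Mistral 3 vision pipeline.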
    fn process_inputs(
        &self,
        tokenizer: Option<Arc<Tokenizer>>,
        input_seqs: &mut [&mut Sequence],
        is_prompt: bool,
        is_xlora: bool,
        device: &Device,
        no_kv_cache: bool,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        other_config: Option<Arc<dyn Any>>,
        mut paged_attn_metadata: Option<PagedAttentionMeta>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> anyhow::Result<InputProcessorOutput> {
        if is_xlora {
            return Err(anyhow::Error::msg(
                "Cannot make inputs for X-LoRA vision model.",
            ));
        }
        if no_kv_cache {
            return Err(anyhow::Error::msg("Vision model must have kv cache."));
        }
        let Some(tokenizer) = tokenizer else {
            return Err(anyhow::Error::msg(
                "Mistral3ImageProcessor requires a specified tokenizer.",
            ));
        };

        let config = other_config.expect("Need a PreProcessorConfig config.");
        let config: &PreProcessorConfig = config.downcast_ref().expect("Downcast failed.");

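        // Only build pixel values when every sequence in the batch carries images.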
        let has_images = input_seqs.iter().all(|seq| seq.has_images());

        let (pixel_values, image_sizes) = if has_images {
            let mut pixel_values_accum = Vec::new();
            let mut image_sizes_accum = Vec::new();

            for seq in input_seqs.iter_mut() {
                let PreprocessedImages {
                    pixel_values,
                    pixel_attention_mask: _,
                    image_sizes: _,
                    num_img_tokens: _,
                    aspect_ratio_ids: _,
                    aspect_ratio_mask: _,
                    num_tiles: _,
                    image_grid_thw: _,
                    video_grid_thw: _,
                    rows: _,
                    cols: _,
                    pixel_values_list: _,
                    tgt_sizes: _,
                    image_sizes_all,
                    num_crops: _,
                } = self
                    .preprocess(
                        seq.take_images()
                            .expect("Need to have images by this point."),
                        vec![],
                        config,
                        device,
                        (usize::MAX, usize::MAX), // (bs, max_num_images) are unused here
                    )
                    .expect("Preprocessing failed");
                let image_sizes_all = image_sizes_all.unwrap();

                // Deliberately no .unsqueeze here
                pixel_values_accum.push(pixel_values);
                image_sizes_accum.extend_from_slice(&image_sizes_all);

                let mut prompt = tokenizer
                    .decode(seq.get_toks(), false)
                    .expect("Detokenization failed!");

                let mut image_sizes_all_iter = image_sizes_all.into_iter();
                let mut replace_strings = Vec::new();
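                // Each image expands to a grid: `num_width_tokens` image tokens per
                // row, each row terminated by an image-break token, for
                // `num_height_tokens` rows; the final break is swapped for the
                // image-end token. Rows and columns are counted in merged patches
                // of `patch_size * spatial_merge_size` pixels per side.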
                while prompt.contains(&self.image_token) {
                    let (height, width) = image_sizes_all_iter.next().unwrap();
                    let num_height_tokens =
                        (height as usize) / (self.patch_size * self.spatial_merge_size);
                    let num_width_tokens =
                        (width as usize) / (self.patch_size * self.spatial_merge_size);

                    let mut replace_tokens = vec![
                        [
                            vec![self.image_token.clone(); num_width_tokens],
                            vec![self.image_break_token.clone()],
                        ]
                        .concat();
                        num_height_tokens
                    ]
                    .into_iter()
                    .flatten()
                    .collect::<Vec<_>>();

                    *replace_tokens.last_mut().unwrap() = self.image_end_token.clone();

                    replace_strings.push(replace_tokens.join(""));
                    // Replace one occurrence at a time so each image token consumes
                    // its own image size.
                    prompt = prompt.replacen(&self.image_token, PLACEHOLDER, 1);
                }

                while prompt.contains(PLACEHOLDER) {
                    // Consume the expansions front-to-back so they pair with the
                    // images in prompt order.
                    let replace_str = replace_strings.remove(0);
                    prompt = prompt.replacen(PLACEHOLDER, &replace_str, 1);
                }

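                // Re-tokenize the expanded prompt once and swap it into the sequence.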
                if !seq.multimodal.has_changed_prompt {
                    seq.set_initial_prompt(prompt.clone());
                    let toks = tokenizer
                        .encode_fast(prompt, false)
                        .expect("Tokenization failed!");

                    let ids = toks.get_ids().to_vec();
                    seq.set_toks_and_reallocate(ids, paged_attn_metadata.as_mut());
                    seq.multimodal.has_changed_prompt = true;
                }
            }

            (
                Some(Tensor::cat(&pixel_values_accum, 0).unwrap()),
                Some(image_sizes_accum),
            )
        } else {
            (None, None)
        };

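        // Build the standard text-model input metadata (token ids, positions,
        // paged-attention/flash metadata) for the prompt or completion phase.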
        let text_models_inputs_processor::InnerInputProcessorOutput {
            inputs:
                text_models_inputs_processor::InputMetadata {
                    input,
                    positions,
                    context_lens,
                    position_ids,
                    paged_attn_meta,
                    flash_meta,
                },
            seq_indices,
        } = if is_prompt {
            get_prompt_input(
                input_seqs
                    .iter()
                    .map(|seq| seq.get_toks())
                    .collect::<Vec<_>>(),
                input_seqs,
                device,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                mapper,
            )
            .unwrap()
        } else {
            get_completion_input(
                input_seqs
                    .iter()
                    .map(|seq| seq.get_toks())
                    .collect::<Vec<_>>(),
                input_seqs,
                device,
                no_kv_cache,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                mapper,
            )
            .unwrap()
        };

        let inputs: Box<dyn Any> = Box::new(ModelInputs {
            input_ids: input,
            seqlen_offsets: positions,
            context_lens,
            position_ids,
            pixel_values,
            model_specific_args: Box::new(Mistral3SpecificArgs { image_sizes }),
            paged_attn_meta,
            flash_meta,
        });
        Ok(InputProcessorOutput {
            inputs,
            seq_indices,
        })
    }
}

impl Mistral3ImageProcessor {
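    /// Scale `(height, width)` down to fit within `(max_height, max_width)` while
    /// keeping the aspect ratio, then round each side up to a multiple of
    /// `patch_size`.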
    #[allow(clippy::too_many_arguments)]
    fn resize(
        &self,
        image: &DynamicImage,
        mut height: usize,
        mut width: usize,
        max_height: usize,
        max_width: usize,
        patch_size: usize,
        filter: FilterType,
    ) -> DynamicImage {
        let ratio = (height as f64 / max_height as f64).max(width as f64 / max_width as f64);
        if ratio > 1. {
            height = (height as f64 / ratio).floor() as usize;
            width = (width as f64 / ratio).floor() as usize;
        }

        let num_height_tokens = (height - 1) / patch_size + 1;
        let num_width_tokens = (width - 1) / patch_size + 1;

        let resize_height = num_height_tokens * patch_size;
        let resize_width = num_width_tokens * patch_size;

        image.resize_exact(resize_width as u32, resize_height as u32, filter)
    }
}

impl ImagePreProcessor for Mistral3ImageProcessor {
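    // CLIP normalization statistics, used when the preprocessor config does not
    // provide `image_mean`/`image_std`.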
    #[allow(clippy::excessive_precision)]
    const DEFAULT_MEAN: [f64; 3] = [0.48145466, 0.4578275, 0.40821073];
    #[allow(clippy::excessive_precision)]
    const DEFAULT_STD: [f64; 3] = [0.26862954, 0.26130258, 0.27577711];

    // https://github.com/huggingface/transformers/blob/main/src/transformers/models/pixtral/image_processing_pixtral.py
    fn preprocess(
        &self,
        mut images: Vec<DynamicImage>,
        videos: Vec<Vec<DynamicImage>>,
        config: &PreProcessorConfig,
        device: &Device,
        (_bs, _max_num_images): (usize, usize),
    ) -> Result<PreprocessedImages> {
        assert!(videos.is_empty());

        let do_resize = config.do_resize.unwrap();
        let do_rescale = config.do_rescale.unwrap();
        let rescale_factor = config.rescale_factor.unwrap();
        let do_normalize = config.do_normalize.unwrap();
        let image_mean = config.image_mean.unwrap_or(Self::DEFAULT_MEAN);
        let image_std = config.image_std.unwrap_or(Self::DEFAULT_STD);
        let do_convert_rgb = config.do_convert_rgb.unwrap_or(true);
        let patch_size = config.patch_size.unwrap();
        let size = config.size.as_ref().unwrap();
        let resample = config.resampling.to_filter()?;

        let default_to_square = config.default_to_square.unwrap();
        assert!(default_to_square);

        let mut pixel_values = Vec::new();
        let mut image_sizes = Vec::new();

        let (max_height, max_width) = if size.contains_key("longest_edge") {
            (size["longest_edge"] as usize, size["longest_edge"] as usize)
        } else if size.contains_key("height") && size.contains_key("width") {
            (size["height"] as usize, size["width"] as usize)
        } else {
            candle_core::bail!("Size must be a map of `longest_edge` or `height` and `width`.");
        };

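        // First pass: convert to RGB and resize each image, recording its final
        // (height, width) for prompt expansion.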
        for image in images.iter_mut() {
            let (width, height) = image.dimensions();

            // Convert to RGB
            if do_convert_rgb {
                *image = DynamicImage::ImageRgb8(image.to_rgb8());
            }

            if do_resize {
                *image = self.resize(
                    image,
                    height as usize,
                    width as usize,
                    max_height,
                    max_width,
                    patch_size,
                    resample,
                );
            }

            let (width, height) = image.dimensions();

            image_sizes.push((height, width));
        }

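        // Pad all images in the batch to a common size so they can be concatenated
        // into a single tensor.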
        images = mistralrs_vision::pad_to_max_image_size(images);

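        // Second pass: tensorize, then optionally rescale and normalize each image.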
        for image in images.iter_mut() {
            let transforms = Transforms {
                input: &ToTensorNoNorm,
                inner_transforms: &[
                    &do_rescale.then_some(Rescale {
                        factor: Some(rescale_factor),
                    }),
                    &do_normalize.then(|| Normalize {
                        mean: image_mean.to_vec(),
                        std: image_std.to_vec(),
                    }),
                ],
            };

            let image = image.apply(transforms, device)?;
            pixel_values.push(image.unsqueeze(0)?);
        }

        Ok(PreprocessedImages {
            pixel_values: Tensor::cat(&pixel_values, 0)?,
            pixel_attention_mask: None,
            image_sizes: None,
            num_img_tokens: None,
            aspect_ratio_ids: None,
            aspect_ratio_mask: None,
            num_tiles: None,
            image_grid_thw: None,
            video_grid_thw: None,
            rows: None,
            cols: None,
            pixel_values_list: None,
            tgt_sizes: None,
            image_sizes_all: Some(image_sizes),
            num_crops: None,
        })
    }
}