mistralrs_core/vision_models/mistral3/inputs_processor.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::{any::Any, num::NonZeroUsize, sync::Arc};

use candle_core::{Device, Result, Tensor};
use image::{imageops::FilterType, DynamicImage, GenericImageView};
use mistralrs_vision::{ApplyTransforms, Normalize, Rescale, ToTensorNoNorm, Transforms};
use tokenizers::Tokenizer;
use tracing::warn;

use crate::{
    device_map::DeviceMapper,
    pipeline::{
        text_models_inputs_processor::{
            self, get_completion_input, get_prompt_input, PagedAttentionMeta,
        },
        InputProcessorOutput, InputsProcessor, InputsProcessorType, MessagesAction, Processor,
    },
    sequence::Sequence,
    vision_models::{
        image_processor::{ImagePreProcessor, PreprocessedImages},
        preprocessor_config::{PreProcessorConfig, ToFilter},
        processor_config::ProcessorConfig,
        ModelInputs,
    },
};

use super::Mistral3SpecificArgs;

const PLACEHOLDER: &str = "<placeholder>";

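/// Vision inputs processor: expands each image token in the prompt into the
/// image's full patch-grid token layout and converts the images into model
/// tensors.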
struct Mistral3ImageProcessor {
    image_break_token: String,
    image_end_token: String,
    image_token: String,
    patch_size: usize,
    spatial_merge_size: usize,
}

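/// Processor holding the special image tokens and patch geometry from the
/// model's `ProcessorConfig`; hands out `Mistral3ImageProcessor`s.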
pub struct Mistral3Processor {
    image_break_token: String,
    image_end_token: String,
    image_token: String,
    patch_size: usize,
    spatial_merge_size: usize,
}

impl Mistral3Processor {
    pub fn new(processor_config: ProcessorConfig) -> Self {
        Self {
            image_break_token: processor_config.image_break_token.unwrap(),
            image_end_token: processor_config.image_end_token.unwrap(),
            image_token: processor_config.image_token.unwrap(),
            patch_size: processor_config.patch_size.unwrap(),
            spatial_merge_size: processor_config.spatial_merge_size.unwrap(),
        }
    }
}

impl Processor for Mistral3Processor {
    fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
        Arc::new(Mistral3ImageProcessor {
            image_break_token: self.image_break_token.clone(),
            image_end_token: self.image_end_token.clone(),
            image_token: self.image_token.clone(),
            patch_size: self.patch_size,
            spatial_merge_size: self.spatial_merge_size,
        })
    }

    fn get_special_tokens(&self) -> &[&'static str] {
        &[]
    }

    fn template_action(&self) -> MessagesAction {
        MessagesAction::Keep
    }
}

impl InputsProcessor for Mistral3ImageProcessor {
    fn get_type(&self) -> InputsProcessorType {
        InputsProcessorType::Vision
    }

    fn process_inputs(
        &self,
        tokenizer: Option<Arc<Tokenizer>>,
        input_seqs: &mut [&mut Sequence],
        is_prompt: bool,
        is_xlora: bool,
        device: &Device,
        no_kv_cache: bool,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        other_config: Option<Arc<dyn Any>>,
        mut paged_attn_metadata: Option<PagedAttentionMeta<'_>>,
        prompt_chunksize: Option<NonZeroUsize>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Box<dyn Iterator<Item = anyhow::Result<InputProcessorOutput>>> {
        if is_xlora {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Cannot make inputs for X-LoRA vision model.",
            ))));
        }
        if no_kv_cache {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Vision model must have kv cache.",
            ))));
        }
        // TODO(EricLBuehler): support this? Would require some handling of image tokens.
        if prompt_chunksize.is_some() {
            warn!("`prompt_chunksize` is set. Mistral3 does not support prompt chunking.");
        }
        let Some(tokenizer) = tokenizer else {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Mistral3ImageProcessor requires a specified tokenizer.",
            ))));
        };

        let config = other_config.expect("Need a PreProcessorConfig config.");
        let config: &PreProcessorConfig = config.downcast_ref().expect("Downcast failed.");

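        // Pixel values are only gathered when every sequence in the batch
        // carries images; otherwise the batch is treated as text-only.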
        let has_images = input_seqs.iter().all(|seq| seq.has_images());

        let (pixel_values, image_sizes) = if has_images {
            let mut pixel_values_accum = Vec::new();
            let mut image_sizes_accum = Vec::new();

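            // For each sequence: preprocess its images, then rewrite the prompt
            // so every image token expands to that image's post-resize token grid.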
            for seq in input_seqs.iter_mut() {
                let PreprocessedImages {
                    pixel_values,
                    pixel_attention_mask: _,
                    image_sizes: _,
                    num_img_tokens: _,
                    aspect_ratio_ids: _,
                    aspect_ratio_mask: _,
                    num_tiles: _,
                    image_grid_thw: _,
                    video_grid_thw: _,
                    rows: _,
                    cols: _,
                    pixel_values_list: _,
                    tgt_sizes: _,
                    image_sizes_all,
                    num_crops: _,
                } = self
                    .preprocess(
                        seq.take_images()
                            .expect("Need to have images by this point."),
                        vec![],
                        config,
                        device,
                        (usize::MAX, usize::MAX), // Don't use it here...
                    )
                    .expect("Preprocessing failed");
                let image_sizes_all = image_sizes_all.unwrap();

                // Deliberately no .unsqueeze here
                pixel_values_accum.push(pixel_values.clone());
                image_sizes_accum.extend_from_slice(&image_sizes_all);

                let mut prompt = tokenizer
                    .decode(seq.get_toks(), false)
                    .expect("Detokenization failed!");

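                // Each image expands to a row-major grid: one image token per
                // (patch_size * spatial_merge_size) pixels along each axis, an
                // image-break token after every row, and an image-end token in
                // place of the final break. E.g. a 64x96 image with
                // patch_size = 16 and spatial_merge_size = 2 gives 2 rows of 3
                // tokens: [IMG][IMG][IMG][BREAK][IMG][IMG][IMG][END].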
                let mut image_sizes_all_iter = image_sizes_all.into_iter();
                let mut replace_strings = Vec::new();
                while prompt.contains(&self.image_token) {
                    let (height, width) = image_sizes_all_iter.next().unwrap();
                    let num_height_tokens =
                        (height as usize) / (self.patch_size * self.spatial_merge_size);
                    let num_width_tokens =
                        (width as usize) / (self.patch_size * self.spatial_merge_size);

                    let mut replace_tokens = vec![
                        [
                            vec![self.image_token.clone(); num_width_tokens],
                            vec![self.image_break_token.clone()],
                        ]
                        .concat();
                        num_height_tokens
                    ]
                    .into_iter()
                    .flatten()
                    .collect::<Vec<_>>();

                    *replace_tokens.last_mut().unwrap() = self.image_end_token.clone();

                    replace_strings.push(replace_tokens.join(""));
                    // Replace one occurrence at a time so each image token is
                    // paired with its own image's size.
                    prompt = prompt.replacen(&self.image_token, PLACEHOLDER, 1);
                }

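                // Two-phase substitution: the expanded grids themselves contain
                // image tokens, so they are spliced in only after every original
                // image token has been swapped for a neutral placeholder.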
                let mut replace_strings = replace_strings.into_iter();
                while prompt.contains(PLACEHOLDER) {
                    let replace_str = replace_strings.next().unwrap();
                    prompt = prompt.replacen(PLACEHOLDER, &replace_str, 1);
                }

                if !seq.has_changed_prompt {
                    seq.set_initial_prompt(prompt.clone());
                    let toks = tokenizer
                        .encode_fast(prompt, false)
                        .expect("Tokenization failed!");

                    let ids = toks.get_ids().to_vec();
                    seq.set_toks_and_reallocate(ids, paged_attn_metadata.as_mut());
                    seq.has_changed_prompt = true;
                }
            }

            (
                Some(Tensor::cat(&pixel_values_accum, 0).unwrap()),
                Some(image_sizes_accum),
            )
        } else {
            (None, None)
        };

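        // Token-level metadata (positions, context lengths, paged-attention and
        // flash-attention metadata) is built by the shared text-model input
        // processor; only the pixel values and image sizes above are vision-specific.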
        let text_models_inputs_processor::InnerInputProcessorOutput {
            inputs:
                text_models_inputs_processor::InputMetadata {
                    input,
                    positions,
                    context_lens,
                    position_ids,
                    paged_attn_meta,
                    flash_meta,
                },
            seq_indices,
        } = if is_prompt {
            get_prompt_input(
                input_seqs
                    .iter()
                    .map(|seq| seq.get_toks().to_vec())
                    .collect::<Vec<_>>(),
                input_seqs,
                device,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                None, // TODO: evaluate if it is possible to batch this
                mapper,
            )
            .next()
            .unwrap()
            .unwrap()
        } else {
            get_completion_input(
                input_seqs
                    .iter()
                    .map(|seq| seq.get_toks().to_vec())
                    .collect::<Vec<_>>(),
                input_seqs,
                device,
                no_kv_cache,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                None, // TODO: evaluate if it is possible to batch this
                mapper,
            )
            .next()
            .unwrap()
            .unwrap()
        };

        let inputs: Box<dyn Any> = Box::new(ModelInputs {
            input_ids: input,
            seqlen_offsets: positions,
            context_lens,
            position_ids,
            pixel_values,
            model_specific_args: Box::new(Mistral3SpecificArgs { image_sizes }),
            paged_attn_meta,
            flash_meta,
        });
        Box::new(std::iter::once(Ok(InputProcessorOutput {
            inputs,
            seq_indices,
        })))
    }
}

impl Mistral3ImageProcessor {
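    // Shrinks the image (preserving aspect ratio) to fit within
    // (max_height, max_width), then rounds each side up to the next multiple
    // of `patch_size` so the result tiles evenly into vision patches.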
    #[allow(clippy::too_many_arguments)]
    fn resize(
        &self,
        image: &DynamicImage,
        mut height: usize,
        mut width: usize,
        max_height: usize,
        max_width: usize,
        patch_size: usize,
        filter: FilterType,
    ) -> DynamicImage {
        let ratio = (height as f64 / max_height as f64).max(width as f64 / max_width as f64);
        if ratio > 1. {
            height = (height as f64 / ratio).floor() as usize;
            width = (width as f64 / ratio).floor() as usize;
        }

        let num_height_tokens = (height - 1) / patch_size + 1;
        let num_width_tokens = (width - 1) / patch_size + 1;

        let resize_height = num_height_tokens * patch_size;
        let resize_width = num_width_tokens * patch_size;

        image.resize_exact(resize_width as u32, resize_height as u32, filter)
    }
}

impl ImagePreProcessor for Mistral3ImageProcessor {
    #[allow(clippy::excessive_precision)]
    const DEFAULT_MEAN: [f64; 3] = [0.48145466, 0.4578275, 0.40821073];
    #[allow(clippy::excessive_precision)]
    const DEFAULT_STD: [f64; 3] = [0.26862954, 0.26130258, 0.27577711];

    // https://github.com/huggingface/transformers/blob/main/src/transformers/models/pixtral/image_processing_pixtral.py
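    // Pipeline: optional RGB conversion and patch-aligned resize per image,
    // pad the batch to a common size, then rescale/normalize and stack into a
    // single (N, C, H, W) tensor, recording each image's post-resize size.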
    fn preprocess(
        &self,
        mut images: Vec<DynamicImage>,
        videos: Vec<Vec<DynamicImage>>,
        config: &PreProcessorConfig,
        device: &Device,
        (_bs, _max_num_images): (usize, usize),
    ) -> Result<PreprocessedImages> {
        assert!(videos.is_empty());

        let do_resize = config.do_resize.unwrap();
        let do_rescale = config.do_rescale.unwrap();
        let rescale_factor = config.rescale_factor.unwrap();
        let do_normalize = config.do_normalize.unwrap();
        let image_mean = config.image_mean.unwrap_or(Self::DEFAULT_MEAN);
        let image_std = config.image_std.unwrap_or(Self::DEFAULT_STD);
        let do_convert_rgb = config.do_convert_rgb.unwrap_or(true);
        let patch_size = config.patch_size.unwrap();
        let size = config.size.as_ref().unwrap();
        let resample = config.resampling.to_filter()?;

        let default_to_square = config.default_to_square.unwrap();
        assert!(default_to_square);

        let mut pixel_values = Vec::new();
        let mut image_sizes = Vec::new();

        let (max_height, max_width) = if size.contains_key("longest_edge") {
            (size["longest_edge"] as usize, size["longest_edge"] as usize)
        } else if size.contains_key("height") && size.contains_key("width") {
            (size["height"] as usize, size["width"] as usize)
        } else {
            candle_core::bail!("Size must contain `longest_edge`, or both `height` and `width`.");
        };

        for image in images.iter_mut() {
            let (width, height) = image.dimensions();

            // Convert to rgb
            if do_convert_rgb {
                *image = DynamicImage::ImageRgb8(image.to_rgb8());
            }

            if do_resize {
                *image = self.resize(
                    image,
                    height as usize,
                    width as usize,
                    max_height,
                    max_width,
                    patch_size,
                    resample,
                );
            }

            let (width, height) = image.dimensions();

            image_sizes.push((height, width));
        }

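        // Pad every image to the batch's maximum dimensions so the per-image
        // tensors can be concatenated along the batch axis below.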
        images = mistralrs_vision::pad_to_max_image_size(images);

        for image in images.iter_mut() {
            let transforms = Transforms {
                input: &ToTensorNoNorm,
                inner_transforms: &[
                    &do_rescale.then_some(Rescale {
                        factor: Some(rescale_factor),
                    }),
                    &do_normalize.then(|| Normalize {
                        mean: image_mean.to_vec(),
                        std: image_std.to_vec(),
                    }),
                ],
            };

            let image = image.apply(transforms, device)?;
            pixel_values.push(image.unsqueeze(0)?);
        }

        Ok(PreprocessedImages {
            pixel_values: Tensor::cat(&pixel_values, 0)?,
            pixel_attention_mask: None,
            image_sizes: None,
            num_img_tokens: None,
            aspect_ratio_ids: None,
            aspect_ratio_mask: None,
            num_tiles: None,
            image_grid_thw: None,
            video_grid_thw: None,
            rows: None,
            cols: None,
            pixel_values_list: None,
            tgt_sizes: None,
            image_sizes_all: Some(image_sizes),
            num_crops: None,
        })
    }
}
415}