mistralrs_core/vision_models/gemma3/inputs_processor.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::{any::Any, sync::Arc};

use candle_core::{Device, Result, Tensor};
use image::{DynamicImage, GenericImageView};
use itertools::Itertools;
use mistralrs_vision::{ApplyTransforms, Normalize, Rescale, ToTensorNoNorm, Transforms};
use regex::Regex;
use tokenizers::Tokenizer;

use crate::{
    device_map::DeviceMapper,
    pipeline::{
        text_models_inputs_processor::{
            self, get_completion_input, get_prompt_input, PagedAttentionMeta,
        },
        InputProcessorOutput, InputsProcessor, InputsProcessorType, MessagesAction, Processor,
    },
    sequence::Sequence,
    vision_models::{
        image_processor::{ImagePreProcessor, PreprocessedImages},
        preprocessor_config::{PreProcessorConfig, ToFilter},
        processor_config::ProcessorConfig,
        ModelInputs,
    },
};

use super::Gemma3SpecificArgs;

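/// Vision inputs processor for Gemma 3: preprocesses per-sequence images into
/// pixel values and rewrites prompts so each `<start_of_image>` marker expands
/// into the full image token sequence.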
struct Gemma3ImageProcessor {
    full_image_sequence: String,
    supports_images: bool,
}

const IMAGE_TOKEN: &str = "<image_soft_token>";
const BOI_TOKEN: &str = "<start_of_image>";
const EOI_TOKEN: &str = "<end_of_image>";

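/// Processor that builds the expanded per-image token sequence once and hands
/// it to the `Gemma3ImageProcessor` created for each request.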
pub struct Gemma3Processor {
    full_image_sequence: String,
    supports_images: bool,
}

impl Gemma3Processor {
    pub fn new(processor_config: ProcessorConfig, supports_images: bool) -> Self {
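        // Build the per-image replacement text once: `image_seq_len` soft tokens
        // (256 by default) wrapped in begin/end-of-image markers.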
        let image_tokens_expanded =
            vec![IMAGE_TOKEN.to_string(); processor_config.image_seq_len.unwrap_or(256)].join("");
        let full_image_sequence = format!("\n\n{BOI_TOKEN}{image_tokens_expanded}{EOI_TOKEN}\n\n");

        Self {
            full_image_sequence,
            supports_images,
        }
    }
}

impl Processor for Gemma3Processor {
    fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
        Arc::new(Gemma3ImageProcessor {
            full_image_sequence: self.full_image_sequence.clone(),
            supports_images: self.supports_images,
        })
    }

    fn get_special_tokens(&self) -> &[&'static str] {
        &[BOI_TOKEN, EOI_TOKEN, IMAGE_TOKEN]
    }

    fn template_action(&self) -> MessagesAction {
        MessagesAction::Keep
    }
}

impl InputsProcessor for Gemma3ImageProcessor {
    fn get_type(&self) -> InputsProcessorType {
        InputsProcessorType::Vision
    }

    fn process_inputs(
        &self,
        tokenizer: Option<Arc<Tokenizer>>,
        input_seqs: &mut [&mut Sequence],
        is_prompt: bool,
        is_xlora: bool,
        device: &Device,
        no_kv_cache: bool,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        other_config: Option<Arc<dyn Any>>,
        mut paged_attn_metadata: Option<PagedAttentionMeta>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> anyhow::Result<InputProcessorOutput> {
        if is_xlora {
            return Err(anyhow::Error::msg(
                "Cannot make inputs for X-LoRA vision model.",
            ));
        }
        if no_kv_cache {
            return Err(anyhow::Error::msg("Vision model must have kv cache."));
        }
        let Some(tokenizer) = tokenizer else {
            return Err(anyhow::Error::msg(
                "Gemma3ImageProcessor requires a specified tokenizer.",
            ));
        };

        let config = other_config.expect("Need a PreProcessorConfig config.");
        let config: &PreProcessorConfig = config.downcast_ref().expect("Downcast failed.");

        let has_images = input_seqs.iter().all(|seq| seq.has_images());

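        // If the sequences carry images, run image preprocessing and rewrite each
        // prompt so its image markers expand to the full image token sequence;
        // otherwise no pixel values are passed to the model.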
        let pixel_values = if has_images {
            if !self.supports_images {
                return Err(anyhow::Error::msg(
                    "This image processor does not support images.",
                ));
            }

            let mut pixel_values_accum = Vec::new();
            let re = Regex::new(BOI_TOKEN).unwrap();
            for seq in input_seqs.iter_mut() {
                let PreprocessedImages {
                    pixel_values,
                    pixel_attention_mask: _,
                    image_sizes: _,
                    num_img_tokens: _,
                    aspect_ratio_ids: _,
                    aspect_ratio_mask: _,
                    num_tiles: _,
                    image_grid_thw: _,
                    video_grid_thw: _,
                    rows: _,
                    cols: _,
                    pixel_values_list: _,
                    tgt_sizes: _,
                    image_sizes_all: _,
                    num_crops,
                } = self
                    .preprocess(
                        seq.take_images()
                            .expect("Need to have images by this point."),
                        vec![],
                        config,
                        device,
                        (usize::MAX, usize::MAX), // Don't use it here...
                    )
                    .expect("Preprocessing failed");

                let num_crops = num_crops.unwrap();

                // Deliberately no .unsqueeze here
                pixel_values_accum.push(pixel_values.clone());

                let mut prompt = tokenizer
                    .decode(seq.get_toks(), false)
                    .expect("Detokenization failed!");

                let image_indexes: Vec<usize> =
                    re.find_iter(&prompt).map(|mat| mat.start()).collect();

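                // For each image that was pan-and-scanned, swap its single marker for
                // text referencing the original image plus one marker per crop.
                // Iterate in reverse so earlier byte offsets remain valid while splicing.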
                for (num, idx) in num_crops.into_iter().zip(image_indexes).rev() {
                    if num != 0 {
                        let formatted_image_text = format!(
                            "Here is the original image {BOI_TOKEN} and here are some crops to help you see better {}", vec![BOI_TOKEN.to_string(); num].join(" ")
                        );
                        prompt = format!(
                            "{}{formatted_image_text}{}",
                            &prompt[..idx],
                            &prompt[idx + BOI_TOKEN.len()..]
                        );
                    }
                }

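                // Every marker (original image or crop) now expands to the full
                // image token sequence.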
                prompt = prompt.replace(BOI_TOKEN, &self.full_image_sequence);

                if !seq.multimodal.has_changed_prompt {
                    seq.set_initial_prompt(prompt.clone());
                    let toks = tokenizer
                        .encode_fast(prompt, false)
                        .expect("Tokenization failed!");

                    let ids = toks.get_ids().to_vec();
                    seq.set_toks_and_reallocate(ids, paged_attn_metadata.as_mut());
                    seq.multimodal.has_changed_prompt = true;
                }
            }

            Some(Tensor::cat(&pixel_values_accum, 0).unwrap())
        } else {
            None
        };

        let text_models_inputs_processor::InnerInputProcessorOutput {
            inputs:
                text_models_inputs_processor::InputMetadata {
                    input,
                    positions,
                    context_lens,
                    position_ids,
                    paged_attn_meta,
                    flash_meta,
                },
            seq_indices,
        } = if is_prompt {
            get_prompt_input(
                input_seqs
                    .iter()
                    .map(|seq| seq.get_toks())
                    .collect::<Vec<_>>(),
                input_seqs,
                device,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                mapper,
            )
            .unwrap()
        } else {
            get_completion_input(
                input_seqs
                    .iter()
                    .map(|seq| seq.get_toks())
                    .collect::<Vec<_>>(),
                input_seqs,
                device,
                no_kv_cache,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                mapper,
            )
            .unwrap()
        };

        let inputs: Box<dyn Any> = Box::new(ModelInputs {
            input_ids: input,
            seqlen_offsets: positions,
            context_lens,
            position_ids,
            pixel_values,
            model_specific_args: Box::new(Gemma3SpecificArgs),
            paged_attn_meta,
            flash_meta,
        });
        Ok(InputProcessorOutput {
            inputs,
            seq_indices,
        })
    }
}

impl Gemma3ImageProcessor {
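    /// Pan-and-scan: split a sufficiently elongated image into a row or column
    /// of crops. Returns an empty Vec when the aspect ratio or the resulting
    /// crop size does not justify cropping.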
    fn pan_and_scan(
        &self,
        image: &DynamicImage,
        pan_and_scan_min_crop_size: usize,
        pan_and_scan_max_num_crops: usize,
        pan_and_scan_min_ratio_to_activate: f64,
    ) -> Vec<DynamicImage> {
        let (width, height) = image.dimensions();

        let (num_crops_w, num_crops_h) = if width >= height {
            if (width as f64 / height as f64) < pan_and_scan_min_ratio_to_activate {
                return vec![];
            }

            // Select ideal number of crops close to the image aspect ratio and such that crop_size > min_crop_size.
            let mut num_crops_w = (width as f64 / height as f64 + 0.5).floor() as usize;
            num_crops_w = num_crops_w
                .min((width as f64 / pan_and_scan_min_crop_size as f64).floor() as usize);

            // Make sure the number of crops is in range [2, pan_and_scan_max_num_crops].
            num_crops_w = num_crops_w.max(2);
            num_crops_w = num_crops_w.min(pan_and_scan_max_num_crops);

            (num_crops_w, 1)
        } else {
            if (height as f64 / width as f64) < pan_and_scan_min_ratio_to_activate {
                return vec![];
            }

            // Select ideal number of crops close to the image aspect ratio and such that crop_size > min_crop_size.
            let mut num_crops_h = (height as f64 / width as f64 + 0.5).floor() as usize;
            num_crops_h = num_crops_h
                .min((height as f64 / pan_and_scan_min_crop_size as f64).floor() as usize);

            // Make sure the number of crops is in range [2, pan_and_scan_max_num_crops].
            num_crops_h = num_crops_h.max(2);
            num_crops_h = num_crops_h.min(pan_and_scan_max_num_crops);

            (1, num_crops_h)
        };

        let crop_size_w = (width as f64 / num_crops_w as f64).ceil() as usize;
        let crop_size_h = (height as f64 / num_crops_h as f64).ceil() as usize;

        if crop_size_w.min(crop_size_h) < pan_and_scan_min_crop_size {
            return vec![];
        }

        let crop_positions_w = (0..num_crops_w)
            .map(|i| i * crop_size_w)
            .collect::<Vec<_>>();
        let crop_positions_h = (0..num_crops_h)
            .map(|i| i * crop_size_h)
            .collect::<Vec<_>>();

        let mut image_crops = Vec::new();
        for (pos_h, pos_w) in crop_positions_h
            .into_iter()
            .cartesian_product(crop_positions_w)
        {
            image_crops.push(image.crop_imm(
                pos_w as u32,
                pos_h as u32,
                crop_size_w as u32,
                crop_size_h as u32,
            ));
        }

        image_crops
    }

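    /// Run pan-and-scan over each image, returning the originals interleaved with
    /// their crops along with the number of crops produced per image.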
    fn process_images_for_pan_and_scan(
        &self,
        images: Vec<DynamicImage>,
        pan_and_scan_min_crop_size: usize,
        pan_and_scan_max_num_crops: usize,
        pan_and_scan_min_ratio_to_activate: f64,
    ) -> (Vec<DynamicImage>, Vec<usize>) {
        let mut pas_images_list = Vec::new();
        let mut num_crops = Vec::new();

        for image in images {
            let pas_images = self.pan_and_scan(
                &image,
                pan_and_scan_min_crop_size,
                pan_and_scan_max_num_crops,
                pan_and_scan_min_ratio_to_activate,
            );
            num_crops.push(pas_images.len());
            pas_images_list.extend([vec![image], pas_images].concat());
        }

        (pas_images_list, num_crops)
    }
}

impl ImagePreProcessor for Gemma3ImageProcessor {
    const DEFAULT_MEAN: [f64; 3] = [0.5, 0.5, 0.5];
    const DEFAULT_STD: [f64; 3] = [0.5, 0.5, 0.5];

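    /// Convert images (and any pan-and-scan crops) into a single stacked
    /// pixel-value tensor: optional RGB conversion and pan-and-scan, then resize,
    /// rescale, and normalize.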
    fn preprocess(
        &self,
        mut images: Vec<DynamicImage>,
        videos: Vec<Vec<DynamicImage>>,
        config: &PreProcessorConfig,
        device: &Device,
        (_bs, _max_num_images): (usize, usize),
    ) -> Result<PreprocessedImages> {
        assert!(videos.is_empty());

        let do_resize = config.do_resize.unwrap();
        let size = config.size.as_ref().unwrap();
        let (height, width) = (size["height"], size["width"]);
        let resample = config.resampling.to_filter()?;
        let do_rescale = config.do_rescale.unwrap();
        let rescale_factor = config.rescale_factor.unwrap();
        let do_normalize = config.do_normalize.unwrap();
        let image_mean = config.image_mean.unwrap_or(Self::DEFAULT_MEAN);
        let image_std = config.image_std.unwrap_or(Self::DEFAULT_STD);
        let do_convert_rgb = config.do_convert_rgb.unwrap_or(true);
        let do_pan_and_scan = config.do_pan_and_scan.unwrap_or(do_convert_rgb);
        // https://github.com/huggingface/transformers/blob/ea219ed164bead55a5513e8cfaa17a25d5613b9e/src/transformers/models/gemma3/processing_gemma3.py#L42
        let pan_and_scan_min_crop_size = config.pan_and_scan_min_crop_size.unwrap_or(256);
        let pan_and_scan_max_num_crops = config.pan_and_scan_max_num_crops.unwrap_or(4);
        let pan_and_scan_min_ratio_to_activate =
            config.pan_and_scan_min_ratio_to_activate.unwrap_or(1.2);

        for image in images.iter_mut() {
            // Convert to rgb
            if do_convert_rgb {
                *image = DynamicImage::ImageRgb8(image.to_rgb8());
            }
        }

        let num_crops = if do_pan_and_scan {
            let (new_images, num_crops) = self.process_images_for_pan_and_scan(
                images,
                pan_and_scan_min_crop_size,
                pan_and_scan_max_num_crops,
                pan_and_scan_min_ratio_to_activate,
            );
            images = new_images;
            num_crops
        } else {
            vec![0]
        };

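        // Resize each image (original and crops) to the model's fixed input size
        // and turn it into a normalized tensor.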
        let mut pixel_values = Vec::new();
        for mut image in images {
            if do_resize {
                image = image.resize_exact(width, height, resample);
            }

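            // To tensor without normalization, then optionally rescale by
            // `rescale_factor` and normalize with the configured mean/std.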
            let transforms = Transforms {
                input: &ToTensorNoNorm,
                inner_transforms: &[
                    &do_rescale.then_some(Rescale {
                        factor: Some(rescale_factor),
                    }),
                    &do_normalize.then(|| Normalize {
                        mean: image_mean.to_vec(),
                        std: image_std.to_vec(),
                    }),
                ],
            };

            let image = image.apply(transforms, device)?;
            pixel_values.push(image.unsqueeze(0)?);
        }

        Ok(PreprocessedImages {
            pixel_values: Tensor::cat(&pixel_values, 0)?,
            pixel_attention_mask: None,
            image_sizes: None,
            num_img_tokens: None,
            aspect_ratio_ids: None,
            aspect_ratio_mask: None,
            num_tiles: None,
            image_grid_thw: None,
            video_grid_thw: None,
            rows: None,
            cols: None,
            pixel_values_list: None,
            tgt_sizes: None,
            image_sizes_all: None,
            num_crops: Some(num_crops),
        })
    }
}