mistralrs_core/vision_models/gemma3/inputs_processor.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::{any::Any, num::NonZeroUsize, sync::Arc};

use candle_core::{Device, Result, Tensor};
use image::{DynamicImage, GenericImageView};
use itertools::Itertools;
use mistralrs_vision::{ApplyTransforms, Normalize, Rescale, ToTensorNoNorm, Transforms};
use regex::Regex;
use tokenizers::Tokenizer;
use tracing::warn;

use crate::{
    device_map::DeviceMapper,
    pipeline::{
        text_models_inputs_processor::{
            self, get_completion_input, get_prompt_input, PagedAttentionMeta,
        },
        InputProcessorOutput, InputsProcessor, InputsProcessorType, MessagesAction, Processor,
    },
    sequence::Sequence,
    vision_models::{
        image_processor::{ImagePreProcessor, PreprocessedImages},
        preprocessor_config::{PreProcessorConfig, ToFilter},
        processor_config::ProcessorConfig,
        ModelInputs,
    },
};

use super::Gemma3SpecificArgs;

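/// Input processor for Gemma 3 vision models. Turns the per-sequence images into
/// `pixel_values` and rewrites each `<start_of_image>` placeholder in the prompt into
/// the full image token sequence expected by the model.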
struct Gemma3ImageProcessor {
    full_image_sequence: String,
    supports_images: bool,
}

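// Special tokens Gemma 3 uses to delimit image content in the prompt.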
const IMAGE_TOKEN: &str = "<image_soft_token>";
const BOI_TOKEN: &str = "<start_of_image>";
const EOI_TOKEN: &str = "<end_of_image>";

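/// Builds the expanded image token sequence once from the `ProcessorConfig` and hands it
/// to `Gemma3ImageProcessor` via `inputs_processor`.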
pub struct Gemma3Processor {
    full_image_sequence: String,
    supports_images: bool,
}

impl Gemma3Processor {
    pub fn new(processor_config: ProcessorConfig, supports_images: bool) -> Self {
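        // One image placeholder expands to
        // "\n\n<start_of_image>{image_seq_len x <image_soft_token>}<end_of_image>\n\n",
        // with `image_seq_len` defaulting to 256 soft tokens.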
        let image_tokens_expanded =
            vec![IMAGE_TOKEN.to_string(); processor_config.image_seq_len.unwrap_or(256)].join("");
        let full_image_sequence = format!("\n\n{BOI_TOKEN}{image_tokens_expanded}{EOI_TOKEN}\n\n");

        Self {
            full_image_sequence,
            supports_images,
        }
    }
}

impl Processor for Gemma3Processor {
    fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
        Arc::new(Gemma3ImageProcessor {
            full_image_sequence: self.full_image_sequence.clone(),
            supports_images: self.supports_images,
        })
    }

    fn get_special_tokens(&self) -> &[&'static str] {
        &[BOI_TOKEN, EOI_TOKEN, IMAGE_TOKEN]
    }

    fn template_action(&self) -> MessagesAction {
        MessagesAction::Keep
    }
}

impl InputsProcessor for Gemma3ImageProcessor {
    fn get_type(&self) -> InputsProcessorType {
        InputsProcessorType::Vision
    }
    fn process_inputs(
        &self,
        tokenizer: Option<Arc<Tokenizer>>,
        input_seqs: &mut [&mut Sequence],
        is_prompt: bool,
        is_xlora: bool,
        device: &Device,
        no_kv_cache: bool,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        other_config: Option<Arc<dyn Any>>,
        mut paged_attn_metadata: Option<PagedAttentionMeta<'_>>,
        prompt_chunksize: Option<NonZeroUsize>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Box<dyn Iterator<Item = anyhow::Result<InputProcessorOutput>>> {
        if is_xlora {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Cannot make inputs for X-LoRA vision model.",
            ))));
        }
        if no_kv_cache {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Vision model must have kv cache.",
            ))));
        }
        // TODO(EricLBuehler): support this? Would require some handling of image tokens.
        if prompt_chunksize.is_some() {
            warn!("`prompt_chunksize` is set. Gemma3 does not support prompt batching.");
        }
        let Some(tokenizer) = tokenizer else {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Gemma3ImageProcessor requires a specified tokenizer.",
            ))));
        };

        let config = other_config.expect("Need a PreProcessorConfig config.");
        let config: &PreProcessorConfig = config.downcast_ref().expect("Downcast failed.");

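        // Pixel values are only produced when every sequence in this batch carries images.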
        let has_images = input_seqs.iter().all(|seq| seq.has_images());

        let pixel_values = if has_images {
            if !self.supports_images {
                return Box::new(std::iter::once(Err(anyhow::Error::msg(
                    "This image processor does not support images.",
                ))));
            }

            let mut pixel_values_accum = Vec::new();
            let re = Regex::new(BOI_TOKEN).unwrap();
            for seq in input_seqs.iter_mut() {
                let PreprocessedImages {
                    pixel_values,
                    pixel_attention_mask: _,
                    image_sizes: _,
                    num_img_tokens: _,
                    aspect_ratio_ids: _,
                    aspect_ratio_mask: _,
                    num_tiles: _,
                    image_grid_thw: _,
                    video_grid_thw: _,
                    rows: _,
                    cols: _,
                    pixel_values_list: _,
                    tgt_sizes: _,
                    image_sizes_all: _,
                    num_crops,
                } = self
                    .preprocess(
                        seq.take_images()
                            .expect("Need to have images by this point."),
                        vec![],
                        config,
                        device,
                        (usize::MAX, usize::MAX), // Don't use it here...
                    )
                    .expect("Preprocessing failed");

                let num_crops = num_crops.unwrap();

                // Deliberately no .unsqueeze here
                pixel_values_accum.push(pixel_values.clone());

                let mut prompt = tokenizer
                    .decode(seq.get_toks(), false)
                    .expect("Detokenization failed!");

                let image_indexes: Vec<usize> =
                    re.find_iter(&prompt).map(|mat| mat.start()).collect();

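                // For every image that produced pan-and-scan crops, replace its single
                // <start_of_image> with a blurb containing one <start_of_image> for the
                // original image plus one per crop. Iterate in reverse so earlier byte
                // offsets remain valid while the prompt is edited in place.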
                for (num, idx) in num_crops.into_iter().zip(image_indexes).rev() {
                    if num != 0 {
                        let formatted_image_text = format!(
                            "Here is the original image {BOI_TOKEN} and here are some crops to help you see better {}", vec![BOI_TOKEN.to_string(); num].join(" ")
                        );
                        prompt = format!(
                            "{}{formatted_image_text}{}",
                            &prompt[..idx],
                            &prompt[idx + BOI_TOKEN.len()..]
                        );
                    }
                }

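                // Expand every remaining <start_of_image> into the full image token sequence.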
                prompt = prompt.replace(BOI_TOKEN, &self.full_image_sequence);

                if !seq.has_changed_prompt {
                    seq.set_initial_prompt(prompt.clone());
                    let toks = tokenizer
                        .encode_fast(prompt, false)
                        .expect("Tokenization failed!");

                    let ids = toks.get_ids().to_vec();
                    seq.set_toks_and_reallocate(ids, paged_attn_metadata.as_mut());
                    seq.has_changed_prompt = true;
                }
            }

            Some(Tensor::cat(&pixel_values_accum, 0).unwrap())
        } else {
            None
        };

        let text_models_inputs_processor::InnerInputProcessorOutput {
            inputs:
                text_models_inputs_processor::InputMetadata {
                    input,
                    positions,
                    context_lens,
                    position_ids,
                    paged_attn_meta,
                    flash_meta,
                },
            seq_indices,
        } = if is_prompt {
            get_prompt_input(
                input_seqs
                    .iter()
                    .map(|seq| seq.get_toks().to_vec())
                    .collect::<Vec<_>>(),
                input_seqs,
                device,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                None, // TODO: evaluate if it is possible to batch this
                mapper,
            )
            .nth(0)
            .unwrap()
            .unwrap()
        } else {
            get_completion_input(
                input_seqs
                    .iter()
                    .map(|seq| seq.get_toks().to_vec())
                    .collect::<Vec<_>>(),
                input_seqs,
                device,
                no_kv_cache,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                None, // TODO: evaluate if it is possible to batch this
                mapper,
            )
            .nth(0)
            .unwrap()
            .unwrap()
        };

        let inputs: Box<dyn Any> = Box::new(ModelInputs {
            input_ids: input,
            seqlen_offsets: positions,
            context_lens,
            position_ids,
            pixel_values,
            model_specific_args: Box::new(Gemma3SpecificArgs),
            paged_attn_meta,
            flash_meta,
        });
        Box::new(std::iter::once(Ok(InputProcessorOutput {
            inputs,
            seq_indices,
        })))
    }
}

impl Gemma3ImageProcessor {
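    /// Pan-and-scan cropping: when the image's aspect ratio reaches
    /// `pan_and_scan_min_ratio_to_activate`, split the longer axis into between 2 and
    /// `pan_and_scan_max_num_crops` crops. Returns an empty vector when the ratio is
    /// below the threshold or the crops would be smaller than
    /// `pan_and_scan_min_crop_size`. For example, with the defaults used in `preprocess`
    /// (ratio 1.2, min crop 256, max 4 crops), a 1024x512 image yields two 512x512 crops.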
    fn pan_and_scan(
        &self,
        image: &DynamicImage,
        pan_and_scan_min_crop_size: usize,
        pan_and_scan_max_num_crops: usize,
        pan_and_scan_min_ratio_to_activate: f64,
    ) -> Vec<DynamicImage> {
        let (width, height) = image.dimensions();

        let (num_crops_w, num_crops_h) = if width >= height {
            if (width as f64 / height as f64) < pan_and_scan_min_ratio_to_activate {
                return vec![];
            }

            // Select ideal number of crops close to the image aspect ratio and such that crop_size > min_crop_size.
            let mut num_crops_w = (width as f64 / height as f64 + 0.5).floor() as usize;
            num_crops_w = num_crops_w
                .min((width as f64 / pan_and_scan_min_crop_size as f64).floor() as usize);

            // Make sure the number of crops is in range [2, pan_and_scan_max_num_crops].
            num_crops_w = num_crops_w.max(2);
            num_crops_w = num_crops_w.min(pan_and_scan_max_num_crops);

            (num_crops_w, 1)
        } else {
            // For tall images the aspect ratio is height / width.
            if (height as f64 / width as f64) < pan_and_scan_min_ratio_to_activate {
                return vec![];
            }

            // Select ideal number of crops close to the image aspect ratio and such that crop_size > min_crop_size.
            let mut num_crops_h = (height as f64 / width as f64 + 0.5).floor() as usize;
            num_crops_h = num_crops_h
                .min((height as f64 / pan_and_scan_min_crop_size as f64).floor() as usize);

            // Make sure the number of crops is in range [2, pan_and_scan_max_num_crops].
            num_crops_h = num_crops_h.max(2);
            num_crops_h = num_crops_h.min(pan_and_scan_max_num_crops);

            (1, num_crops_h)
        };

        let crop_size_w = (width as f64 / num_crops_w as f64).ceil() as usize;
        let crop_size_h = (height as f64 / num_crops_h as f64).ceil() as usize;

        if crop_size_w.min(crop_size_h) < pan_and_scan_min_crop_size {
            return vec![];
        }

        let crop_positions_w = (0..num_crops_w)
            .map(|i| i * crop_size_w)
            .collect::<Vec<_>>();
        let crop_positions_h = (0..num_crops_h)
            .map(|i| i * crop_size_h)
            .collect::<Vec<_>>();

        let mut image_crops = Vec::new();
        for (pos_h, pos_w) in crop_positions_h
            .into_iter()
            .cartesian_product(crop_positions_w)
        {
            image_crops.push(image.crop_imm(
                pos_w as u32,
                pos_h as u32,
                crop_size_w as u32,
                crop_size_h as u32,
            ));
        }

        image_crops
    }

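    /// Applies `pan_and_scan` to each image, returning every original image followed by
    /// its crops, along with the number of crops produced per image.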
    fn process_images_for_pan_and_scan(
        &self,
        images: Vec<DynamicImage>,
        pan_and_scan_min_crop_size: usize,
        pan_and_scan_max_num_crops: usize,
        pan_and_scan_min_ratio_to_activate: f64,
    ) -> (Vec<DynamicImage>, Vec<usize>) {
        let mut pas_images_list = Vec::new();
        let mut num_crops = Vec::new();

        for image in images {
            let pas_images = self.pan_and_scan(
                &image,
                pan_and_scan_min_crop_size,
                pan_and_scan_max_num_crops,
                pan_and_scan_min_ratio_to_activate,
            );
            num_crops.push(pas_images.len());
            pas_images_list.extend([vec![image], pas_images].concat());
        }

        (pas_images_list, num_crops)
    }
}

impl ImagePreProcessor for Gemma3ImageProcessor {
    const DEFAULT_MEAN: [f64; 3] = [0.5, 0.5, 0.5];
    const DEFAULT_STD: [f64; 3] = [0.5, 0.5, 0.5];

    fn preprocess(
        &self,
        mut images: Vec<DynamicImage>,
        videos: Vec<Vec<DynamicImage>>,
        config: &PreProcessorConfig,
        device: &Device,
        (_bs, _max_num_images): (usize, usize),
    ) -> Result<PreprocessedImages> {
        assert!(videos.is_empty());

        let do_resize = config.do_resize.unwrap();
        let size = config.size.as_ref().unwrap();
        let (height, width) = (size["height"], size["width"]);
        let resample = config.resampling.to_filter()?;
        let do_rescale = config.do_rescale.unwrap();
        let rescale_factor = config.rescale_factor.unwrap();
        let do_normalize = config.do_normalize.unwrap();
        let image_mean = config.image_mean.unwrap_or(Self::DEFAULT_MEAN);
        let image_std = config.image_std.unwrap_or(Self::DEFAULT_STD);
        let do_convert_rgb = config.do_convert_rgb.unwrap_or(true);
        let do_pan_and_scan = config.do_pan_and_scan.unwrap_or(do_convert_rgb);
        // https://github.com/huggingface/transformers/blob/ea219ed164bead55a5513e8cfaa17a25d5613b9e/src/transformers/models/gemma3/processing_gemma3.py#L42
        let pan_and_scan_min_crop_size = config.pan_and_scan_min_crop_size.unwrap_or(256);
        let pan_and_scan_max_num_crops = config.pan_and_scan_max_num_crops.unwrap_or(4);
        let pan_and_scan_min_ratio_to_activate =
            config.pan_and_scan_min_ratio_to_activate.unwrap_or(1.2);

        for image in images.iter_mut() {
            // Convert to rgb
            if do_convert_rgb {
                *image = DynamicImage::ImageRgb8(image.to_rgb8());
            }
        }

        let num_crops = if do_pan_and_scan {
            let (new_images, num_crops) = self.process_images_for_pan_and_scan(
                images,
                pan_and_scan_min_crop_size,
                pan_and_scan_max_num_crops,
                pan_and_scan_min_ratio_to_activate,
            );
            images = new_images;
            num_crops
        } else {
            vec![0]
        };

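        // Resize, rescale, and normalize each image (originals plus any crops), then
        // stack them along the batch dimension.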
        let mut pixel_values = Vec::new();
        for mut image in images {
            if do_resize {
                image = image.resize_exact(width, height, resample);
            }

            let transforms = Transforms {
                input: &ToTensorNoNorm,
                inner_transforms: &[
                    &do_rescale.then_some(Rescale {
                        factor: Some(rescale_factor),
                    }),
                    &do_normalize.then(|| Normalize {
                        mean: image_mean.to_vec(),
                        std: image_std.to_vec(),
                    }),
                ],
            };

            let image = image.apply(transforms, device)?;
            pixel_values.push(image.unsqueeze(0)?);
        }

        Ok(PreprocessedImages {
            pixel_values: Tensor::cat(&pixel_values, 0)?,
            pixel_attention_mask: None,
            image_sizes: None,
            num_img_tokens: None,
            aspect_ratio_ids: None,
            aspect_ratio_mask: None,
            num_tiles: None,
            image_grid_thw: None,
            video_grid_thw: None,
            rows: None,
            cols: None,
            pixel_values_list: None,
            tgt_sizes: None,
            image_sizes_all: None,
            num_crops: Some(num_crops),
        })
    }
}