mistralrs_core/vision_models/gemma3/inputs_processor.rs

1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use std::{any::Any, num::NonZeroUsize, sync::Arc};
4
5use candle_core::{Device, Result, Tensor};
6use image::{DynamicImage, GenericImageView};
7use itertools::Itertools;
8use mistralrs_vision::{ApplyTransforms, Normalize, Rescale, ToTensorNoNorm, Transforms};
9use regex::Regex;
10use tokenizers::Tokenizer;
11use tracing::warn;
12
13use crate::{
14    device_map::DeviceMapper,
15    pipeline::{
16        text_models_inputs_processor::{
17            self, get_completion_input, get_prompt_input, PagedAttentionMeta,
18        },
19        InputProcessorOutput, InputsProcessor, InputsProcessorType, MessagesAction, Processor,
20    },
21    sequence::Sequence,
22    vision_models::{
23        image_processor::{ImagePreProcessor, PreprocessedImages},
24        preprocessor_config::{PreProcessorConfig, ToFilter},
25        processor_config::ProcessorConfig,
26        ModelInputs,
27    },
28};
29
30use super::Gemma3SpecificArgs;
31
/// Per-request inputs processor for Gemma 3 vision models: preprocesses
/// images into pixel values and expands image markers in the prompt.
struct Gemma3ImageProcessor {
    // Pre-rendered "\n\n<start_of_image><image_soft_token>…<end_of_image>\n\n"
    // sequence substituted for each <start_of_image> marker in the prompt.
    full_image_sequence: String,
    // Whether this model build accepts image inputs at all.
    supports_images: bool,
}
36
// Special tokens used by the Gemma 3 chat template for image content.
const IMAGE_TOKEN: &str = "<image_soft_token>"; // per-patch placeholder, repeated image_seq_len times
const BOI_TOKEN: &str = "<start_of_image>"; // begin-of-image marker
const EOI_TOKEN: &str = "<end_of_image>"; // end-of-image marker
40
/// Top-level Gemma 3 processor; holds state shared by all per-request
/// inputs processors it creates.
pub struct Gemma3Processor {
    // See Gemma3ImageProcessor::full_image_sequence — built once in `new`.
    full_image_sequence: String,
    // Whether this model build accepts image inputs at all.
    supports_images: bool,
}
45
46impl Gemma3Processor {
47    pub fn new(processor_config: ProcessorConfig, supports_images: bool) -> Self {
48        let image_tokens_expanded =
49            vec![IMAGE_TOKEN.to_string(); processor_config.image_seq_len.unwrap_or(256)].join("");
50        let full_image_sequence = format!("\n\n{BOI_TOKEN}{image_tokens_expanded}{EOI_TOKEN}\n\n");
51
52        Self {
53            full_image_sequence,
54            supports_images,
55        }
56    }
57}
58
impl Processor for Gemma3Processor {
    /// Creates the per-request inputs processor, sharing this processor's
    /// pre-rendered image sequence and image-support flag.
    fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
        Arc::new(Gemma3ImageProcessor {
            full_image_sequence: self.full_image_sequence.clone(),
            supports_images: self.supports_images,
        })
    }

    /// Tokens the tokenizer must treat as atomic special tokens.
    fn get_special_tokens(&self) -> &[&'static str] {
        &[BOI_TOKEN, EOI_TOKEN, IMAGE_TOKEN]
    }

    /// Keep chat messages untouched; image-marker expansion happens later in
    /// `process_inputs`, not during templating.
    fn template_action(&self) -> MessagesAction {
        MessagesAction::Keep
    }
}
75
76impl InputsProcessor for Gemma3ImageProcessor {
77    fn get_type(&self) -> InputsProcessorType {
78        InputsProcessorType::Vision
79    }
80    fn process_inputs(
81        &self,
82        tokenizer: Option<Arc<Tokenizer>>,
83        input_seqs: &mut [&mut Sequence],
84        is_prompt: bool,
85        is_xlora: bool,
86        device: &Device,
87        no_kv_cache: bool,
88        last_n_context_len: Option<(usize, usize)>,
89        return_raw_logits: bool,
90        other_config: Option<Arc<dyn Any>>,
91        mut paged_attn_metadata: Option<PagedAttentionMeta<'_>>,
92        prompt_chunksize: Option<NonZeroUsize>,
93        mapper: Option<&dyn DeviceMapper>,
94    ) -> Box<dyn Iterator<Item = anyhow::Result<InputProcessorOutput>>> {
95        if is_xlora {
96            return Box::new(std::iter::once(Err(anyhow::Error::msg(
97                "Cannot make inputs for X-LoRA vision model.",
98            ))));
99        }
100        if no_kv_cache {
101            return Box::new(std::iter::once(Err(anyhow::Error::msg(
102                "Vision model must have kv cache.",
103            ))));
104        }
105        // TODO(EricLBuehler): support this? Would require some handling of image tokens.
106        if prompt_chunksize.is_some() {
107            warn!("`prompt_chunksize` is set. Gemma3 does not support prompt batching.");
108        }
109        let Some(tokenizer) = tokenizer else {
110            return Box::new(std::iter::once(Err(anyhow::Error::msg(
111                "Idefics3ImageProcessor requires a specified tokenizer.",
112            ))));
113        };
114
115        let config = other_config.expect("Need a PreProcessorConfig config.");
116        let config: &PreProcessorConfig = config.downcast_ref().expect("Downcast failed.");
117
118        let has_images = input_seqs.iter().all(|seq| seq.has_images());
119
120        let pixel_values = if has_images {
121            if !self.supports_images {
122                return Box::new(std::iter::once(Err(anyhow::Error::msg(
123                    "This image processor does not support images.",
124                ))));
125            }
126
127            let mut pixel_values_accum = Vec::new();
128            let re = Regex::new(BOI_TOKEN).unwrap();
129            for seq in input_seqs.iter_mut() {
130                let PreprocessedImages {
131                    pixel_values,
132                    pixel_attention_mask: _,
133                    image_sizes: _,
134                    num_img_tokens: _,
135                    aspect_ratio_ids: _,
136                    aspect_ratio_mask: _,
137                    num_tiles: _,
138                    image_grid_thw: _,
139                    video_grid_thw: _,
140                    rows: _,
141                    cols: _,
142                    pixel_values_list: _,
143                    tgt_sizes: _,
144                    image_sizes_all: _,
145                    num_crops,
146                } = self
147                    .preprocess(
148                        seq.take_images()
149                            .expect("Need to have images by this point."),
150                        vec![],
151                        config,
152                        device,
153                        (usize::MAX, usize::MAX), // Don't use it here...
154                    )
155                    .expect("Preprocessing failed");
156
157                let num_crops = num_crops.unwrap();
158
159                // Deliberately no .unsqueeze here
160                pixel_values_accum.push(pixel_values.clone());
161
162                let mut prompt = tokenizer
163                    .decode(seq.get_toks(), false)
164                    .expect("Detokenization failed!");
165
166                let image_indexes: Vec<usize> =
167                    re.find_iter(&prompt).map(|mat| mat.start()).collect();
168
169                for (num, idx) in num_crops.into_iter().zip(image_indexes).rev() {
170                    if num != 0 {
171                        let formatted_image_text = format!(
172                            "Here is the original image {BOI_TOKEN} and here are some crops to help you see better {}", vec![BOI_TOKEN.to_string(); num].join(" ")
173                        );
174                        prompt = format!(
175                            "{}{formatted_image_text}{}",
176                            &prompt[..idx],
177                            &prompt[idx + BOI_TOKEN.len()..]
178                        );
179                    }
180                }
181
182                prompt = prompt.replace(BOI_TOKEN, &self.full_image_sequence);
183
184                seq.set_initial_prompt(prompt.clone());
185                let toks = tokenizer
186                    .encode_fast(prompt, false)
187                    .expect("Detokenization failed!");
188
189                let ids = toks.get_ids().to_vec();
190                seq.set_toks_and_reallocate(ids, paged_attn_metadata.as_mut());
191            }
192
193            Some(Tensor::cat(&pixel_values_accum, 0).unwrap())
194        } else {
195            None
196        };
197
198        let text_models_inputs_processor::InnerInputProcessorOutput {
199            inputs:
200                text_models_inputs_processor::InputMetadata {
201                    input,
202                    positions,
203                    context_lens,
204                    position_ids,
205                    paged_attn_meta,
206                    flash_meta,
207                },
208            seq_indices,
209        } = if is_prompt {
210            get_prompt_input(
211                input_seqs
212                    .iter()
213                    .map(|seq| seq.get_toks().to_vec())
214                    .collect::<Vec<_>>(),
215                input_seqs,
216                device,
217                last_n_context_len,
218                return_raw_logits,
219                paged_attn_metadata.as_mut(),
220                None, // TODO: evaluate if it is possible to batch this
221                mapper,
222            )
223            .nth(0)
224            .unwrap()
225            .unwrap()
226        } else {
227            get_completion_input(
228                input_seqs
229                    .iter()
230                    .map(|seq| seq.get_toks().to_vec())
231                    .collect::<Vec<_>>(),
232                input_seqs,
233                device,
234                no_kv_cache,
235                last_n_context_len,
236                return_raw_logits,
237                paged_attn_metadata.as_mut(),
238                None, // TODO: evaluate if it is possible to batch this
239                mapper,
240            )
241            .nth(0)
242            .unwrap()
243            .unwrap()
244        };
245
246        let inputs: Box<dyn Any> = Box::new(ModelInputs {
247            input_ids: input,
248            seqlen_offsets: positions,
249            context_lens,
250            position_ids,
251            pixel_values,
252            model_specific_args: Box::new(Gemma3SpecificArgs),
253            paged_attn_meta,
254            flash_meta,
255        });
256        Box::new(std::iter::once(Ok(InputProcessorOutput {
257            inputs,
258            seq_indices,
259        })))
260    }
261}
262
263impl Gemma3ImageProcessor {
264    fn pan_and_scan(
265        &self,
266        image: &DynamicImage,
267        pan_and_scan_min_crop_size: usize,
268        pan_and_scan_max_num_crops: usize,
269        pan_and_scan_min_ratio_to_activate: f64,
270    ) -> Vec<DynamicImage> {
271        let (width, height) = image.dimensions();
272
273        let (num_crops_w, num_crops_h) = if width >= height {
274            if (width as f64 / height as f64) < pan_and_scan_min_ratio_to_activate {
275                return vec![];
276            }
277
278            // Select ideal number of crops close to the image aspect ratio and such that crop_size > min_crop_size.
279            let mut num_crops_w = (width as f64 / height as f64 + 0.5).floor() as usize;
280            num_crops_w = num_crops_w
281                .min((width as f64 / pan_and_scan_min_crop_size as f64).floor() as usize);
282
283            // Make sure the number of crops is in range [2, pan_and_scan_max_num_crops].
284            num_crops_w = num_crops_w.max(2);
285            num_crops_w = num_crops_w.min(pan_and_scan_max_num_crops);
286
287            (num_crops_w, 1)
288        } else {
289            if (width as f64 / height as f64) < pan_and_scan_min_ratio_to_activate {
290                return vec![];
291            }
292
293            // Select ideal number of crops close to the image aspect ratio and such that crop_size > min_crop_size.
294            let mut num_crops_h = (height as f64 / width as f64 + 0.5).floor() as usize;
295            num_crops_h = num_crops_h
296                .min((height as f64 / pan_and_scan_min_crop_size as f64).floor() as usize);
297
298            // Make sure the number of crops is in range [2, pan_and_scan_max_num_crops].
299            num_crops_h = num_crops_h.max(2);
300            num_crops_h = num_crops_h.min(pan_and_scan_max_num_crops);
301
302            (1, num_crops_h)
303        };
304
305        let crop_size_w = (width as f64 / num_crops_w as f64).ceil() as usize;
306        let crop_size_h = (height as f64 / num_crops_h as f64).ceil() as usize;
307
308        if crop_size_w.min(crop_size_h) < pan_and_scan_min_crop_size {
309            return vec![];
310        }
311
312        let crop_positions_w = (0..num_crops_w)
313            .map(|i| i * crop_size_w)
314            .collect::<Vec<_>>();
315        let crop_positions_h = (0..num_crops_h)
316            .map(|i| i * crop_size_h)
317            .collect::<Vec<_>>();
318
319        let mut image_crops = Vec::new();
320        for (pos_h, pos_w) in crop_positions_h
321            .into_iter()
322            .cartesian_product(crop_positions_w)
323        {
324            image_crops.push(image.crop_imm(
325                pos_w as u32,
326                pos_h as u32,
327                crop_size_w as u32,
328                crop_size_h as u32,
329            ));
330        }
331
332        image_crops
333    }
334
335    fn process_images_for_pan_and_scan(
336        &self,
337        images: Vec<DynamicImage>,
338        pan_and_scan_min_crop_size: usize,
339        pan_and_scan_max_num_crops: usize,
340        pan_and_scan_min_ratio_to_activate: f64,
341    ) -> (Vec<DynamicImage>, Vec<usize>) {
342        let mut pas_images_list = Vec::new();
343        let mut num_crops = Vec::new();
344
345        for image in images {
346            let pas_images = self.pan_and_scan(
347                &image,
348                pan_and_scan_min_crop_size,
349                pan_and_scan_max_num_crops,
350                pan_and_scan_min_ratio_to_activate,
351            );
352            num_crops.push(pas_images.len());
353            pas_images_list.extend([vec![image], pas_images].concat());
354        }
355
356        (pas_images_list, num_crops)
357    }
358}
359
360impl ImagePreProcessor for Gemma3ImageProcessor {
361    const DEFAULT_MEAN: [f64; 3] = [0.5, 0.5, 0.5];
362    const DEFAULT_STD: [f64; 3] = [0.5, 0.5, 0.5];
363
364    fn preprocess(
365        &self,
366        mut images: Vec<DynamicImage>,
367        videos: Vec<Vec<DynamicImage>>,
368        config: &PreProcessorConfig,
369        device: &Device,
370        (_bs, _max_num_images): (usize, usize),
371    ) -> Result<PreprocessedImages> {
372        assert!(videos.is_empty());
373
374        let do_resize = config.do_resize.unwrap();
375        let size = config.size.as_ref().unwrap();
376        let (height, width) = (size["height"], size["width"]);
377        let resample = config.resampling.to_filter()?;
378        let do_rescale = config.do_rescale.unwrap();
379        let rescale_factor = config.rescale_factor.unwrap();
380        let do_normalize = config.do_normalize.unwrap();
381        let image_mean = config.image_mean.unwrap_or(Self::DEFAULT_MEAN);
382        let image_std = config.image_std.unwrap_or(Self::DEFAULT_STD);
383        let do_convert_rgb = config.do_convert_rgb.unwrap_or(true);
384        let do_pan_and_scan = config.do_pan_and_scan.unwrap_or(do_convert_rgb);
385        // https://github.com/huggingface/transformers/blob/ea219ed164bead55a5513e8cfaa17a25d5613b9e/src/transformers/models/gemma3/processing_gemma3.py#L42
386        let pan_and_scan_min_crop_size = config.pan_and_scan_min_crop_size.unwrap_or(256);
387        let pan_and_scan_max_num_crops = config.pan_and_scan_max_num_crops.unwrap_or(4);
388        let pan_and_scan_min_ratio_to_activate =
389            config.pan_and_scan_min_ratio_to_activate.unwrap_or(1.2);
390
391        for image in images.iter_mut() {
392            // Convert to rgb
393            if do_convert_rgb {
394                *image = DynamicImage::ImageRgb8(image.to_rgb8());
395            }
396        }
397
398        let num_crops = if do_pan_and_scan {
399            let (new_images, num_crops) = self.process_images_for_pan_and_scan(
400                images,
401                pan_and_scan_min_crop_size,
402                pan_and_scan_max_num_crops,
403                pan_and_scan_min_ratio_to_activate,
404            );
405            images = new_images;
406            num_crops
407        } else {
408            vec![0]
409        };
410
411        let mut pixel_values = Vec::new();
412        for mut image in images {
413            if do_resize {
414                image = image.resize_exact(width, height, resample);
415            }
416
417            let transforms = Transforms {
418                input: &ToTensorNoNorm,
419                inner_transforms: &[
420                    &do_rescale.then_some(Rescale {
421                        factor: Some(rescale_factor),
422                    }),
423                    &do_normalize.then(|| Normalize {
424                        mean: image_mean.to_vec(),
425                        std: image_std.to_vec(),
426                    }),
427                ],
428            };
429
430            let image = image.apply(transforms, device)?;
431            pixel_values.push(image.unsqueeze(0)?);
432        }
433
434        Ok(PreprocessedImages {
435            pixel_values: Tensor::cat(&pixel_values, 0)?,
436            pixel_attention_mask: None,
437            image_sizes: None,
438            num_img_tokens: None,
439            aspect_ratio_ids: None,
440            aspect_ratio_mask: None,
441            num_tiles: None,
442            image_grid_thw: None,
443            video_grid_thw: None,
444            rows: None,
445            cols: None,
446            pixel_values_list: None,
447            tgt_sizes: None,
448            image_sizes_all: None,
449            num_crops: Some(num_crops),
450        })
451    }
452}