#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::{any::Any, num::NonZeroUsize, sync::Arc};

use candle_core::{Device, Result, Tensor};
use image::{DynamicImage, GenericImageView};
use itertools::Itertools;
use mistralrs_vision::{ApplyTransforms, Normalize, Rescale, ToTensorNoNorm, Transforms};
use regex::Regex;
use tokenizers::Tokenizer;
use tracing::warn;

use crate::{
    device_map::DeviceMapper,
    pipeline::{
        text_models_inputs_processor::{
            self, get_completion_input, get_prompt_input, PagedAttentionMeta,
        },
        InputProcessorOutput, InputsProcessor, InputsProcessorType, MessagesAction, Processor,
    },
    sequence::Sequence,
    vision_models::{
        image_processor::{ImagePreProcessor, PreprocessedImages},
        preprocessor_config::{PreProcessorConfig, ToFilter},
        processor_config::ProcessorConfig,
        ModelInputs,
    },
};

use super::Gemma3SpecificArgs;

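/// Inputs processor for Gemma 3 vision inputs, created by
/// [`Gemma3Processor::inputs_processor`]. It expands image placeholders in the
/// prompt and produces the pixel values passed to the model.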
struct Gemma3ImageProcessor {
    full_image_sequence: String,
    supports_images: bool,
}

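// Special tokens of the Gemma 3 prompt format: each `<start_of_image>`
// placeholder is replaced by a run of image soft tokens delimited by the
// begin/end-of-image tokens.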
const IMAGE_TOKEN: &str = "<image_soft_token>";
const BOI_TOKEN: &str = "<start_of_image>";
const EOI_TOKEN: &str = "<end_of_image>";

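/// Processor for Gemma 3 vision inputs. The full image sequence
/// (`image_seq_len` soft tokens, 256 by default, wrapped in `BOI_TOKEN`/
/// `EOI_TOKEN` and surrounding newlines) is built once and shared with the
/// inputs processor.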
pub struct Gemma3Processor {
    full_image_sequence: String,
    supports_images: bool,
}

impl Gemma3Processor {
    pub fn new(processor_config: ProcessorConfig, supports_images: bool) -> Self {
        let image_tokens_expanded =
            vec![IMAGE_TOKEN.to_string(); processor_config.image_seq_len.unwrap_or(256)].join("");
        let full_image_sequence = format!("\n\n{BOI_TOKEN}{image_tokens_expanded}{EOI_TOKEN}\n\n");

        Self {
            full_image_sequence,
            supports_images,
        }
    }
}

impl Processor for Gemma3Processor {
    fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
        Arc::new(Gemma3ImageProcessor {
            full_image_sequence: self.full_image_sequence.clone(),
            supports_images: self.supports_images,
        })
    }

    fn get_special_tokens(&self) -> &[&'static str] {
        &[BOI_TOKEN, EOI_TOKEN, IMAGE_TOKEN]
    }

    fn template_action(&self) -> MessagesAction {
        MessagesAction::Keep
    }
}

impl InputsProcessor for Gemma3ImageProcessor {
    fn get_type(&self) -> InputsProcessorType {
        InputsProcessorType::Vision
    }
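
    /// Build the [`ModelInputs`] for a batch of sequences. On a prompt pass
    /// with images attached, this preprocesses the images, rewrites each
    /// `<start_of_image>` in the prompt (adding crop placeholders when
    /// pan-and-scan produced crops) into the full image token sequence, and
    /// retokenizes the sequence in place.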
    fn process_inputs(
        &self,
        tokenizer: Option<Arc<Tokenizer>>,
        input_seqs: &mut [&mut Sequence],
        is_prompt: bool,
        is_xlora: bool,
        device: &Device,
        no_kv_cache: bool,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        other_config: Option<Arc<dyn Any>>,
        mut paged_attn_metadata: Option<PagedAttentionMeta<'_>>,
        prompt_chunksize: Option<NonZeroUsize>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Box<dyn Iterator<Item = anyhow::Result<InputProcessorOutput>>> {
        if is_xlora {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Cannot make inputs for X-LoRA vision model.",
            ))));
        }
        if no_kv_cache {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Vision model must have kv cache.",
            ))));
        }
        if prompt_chunksize.is_some() {
            warn!("`prompt_chunksize` is set. Gemma3 does not support prompt batching.");
        }
        let Some(tokenizer) = tokenizer else {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Gemma3ImageProcessor requires a specified tokenizer.",
            ))));
        };

        let config = other_config.expect("Need a PreProcessorConfig config.");
        let config: &PreProcessorConfig = config.downcast_ref().expect("Downcast failed.");

        let has_images = input_seqs.iter().all(|seq| seq.has_images());

        let pixel_values = if has_images {
            if !self.supports_images {
                return Box::new(std::iter::once(Err(anyhow::Error::msg(
                    "This image processor does not support images.",
                ))));
            }

            let mut pixel_values_accum = Vec::new();
            let re = Regex::new(BOI_TOKEN).unwrap();
            for seq in input_seqs.iter_mut() {
                let PreprocessedImages {
                    pixel_values,
                    pixel_attention_mask: _,
                    image_sizes: _,
                    num_img_tokens: _,
                    aspect_ratio_ids: _,
                    aspect_ratio_mask: _,
                    num_tiles: _,
                    image_grid_thw: _,
                    video_grid_thw: _,
                    rows: _,
                    cols: _,
                    pixel_values_list: _,
                    tgt_sizes: _,
                    image_sizes_all: _,
                    num_crops,
                } = self
                    .preprocess(
                        seq.take_images()
                            .expect("Need to have images by this point."),
                        vec![],
                        config,
                        device,
                        (usize::MAX, usize::MAX), // Batch size and max images are unused here.
                    )
                    .expect("Preprocessing failed");

                let num_crops = num_crops.unwrap();

                pixel_values_accum.push(pixel_values.clone());

                let mut prompt = tokenizer
                    .decode(seq.get_toks(), false)
                    .expect("Detokenization failed!");

                let image_indexes: Vec<usize> =
                    re.find_iter(&prompt).map(|mat| mat.start()).collect();

                // Expand each `<start_of_image>` that produced pan-and-scan crops,
                // iterating in reverse so earlier byte offsets stay valid.
                for (num, idx) in num_crops.into_iter().zip(image_indexes).rev() {
                    if num != 0 {
                        let formatted_image_text = format!(
                            "Here is the original image {BOI_TOKEN} and here are some crops to help you see better {}",
                            vec![BOI_TOKEN.to_string(); num].join(" ")
                        );
                        prompt = format!(
                            "{}{formatted_image_text}{}",
                            &prompt[..idx],
                            &prompt[idx + BOI_TOKEN.len()..]
                        );
                    }
                }

                prompt = prompt.replace(BOI_TOKEN, &self.full_image_sequence);

                seq.set_initial_prompt(prompt.clone());
                let toks = tokenizer
                    .encode_fast(prompt, false)
                    .expect("Tokenization failed!");

                let ids = toks.get_ids().to_vec();
                seq.set_toks_and_reallocate(ids, paged_attn_metadata.as_mut());
            }

            Some(Tensor::cat(&pixel_values_accum, 0).unwrap())
        } else {
            None
        };

        let text_models_inputs_processor::InnerInputProcessorOutput {
            inputs:
                text_models_inputs_processor::InputMetadata {
                    input,
                    positions,
                    context_lens,
                    position_ids,
                    paged_attn_meta,
                    flash_meta,
                },
            seq_indices,
        } = if is_prompt {
            get_prompt_input(
                input_seqs
                    .iter()
                    .map(|seq| seq.get_toks().to_vec())
                    .collect::<Vec<_>>(),
                input_seqs,
                device,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                None, // No prompt chunking; Gemma3 does not support prompt batching.
                mapper,
            )
            .nth(0)
            .unwrap()
            .unwrap()
        } else {
            get_completion_input(
                input_seqs
                    .iter()
                    .map(|seq| seq.get_toks().to_vec())
                    .collect::<Vec<_>>(),
                input_seqs,
                device,
                no_kv_cache,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                None, // No prompt chunking; Gemma3 does not support prompt batching.
                mapper,
            )
            .nth(0)
            .unwrap()
            .unwrap()
        };

        let inputs: Box<dyn Any> = Box::new(ModelInputs {
            input_ids: input,
            seqlen_offsets: positions,
            context_lens,
            position_ids,
            pixel_values,
            model_specific_args: Box::new(Gemma3SpecificArgs),
            paged_attn_meta,
            flash_meta,
        });
        Box::new(std::iter::once(Ok(InputProcessorOutput {
            inputs,
            seq_indices,
        })))
    }
}

impl Gemma3ImageProcessor {
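    /// Pan-and-scan cropping: split an elongated image into non-overlapping
    /// crops along its longer axis. It activates only when the aspect ratio is
    /// at least `pan_and_scan_min_ratio_to_activate`; the number of crops is
    /// roughly the aspect ratio, clamped to `[2, pan_and_scan_max_num_crops]`;
    /// and no crops are produced if any crop would be smaller than
    /// `pan_and_scan_min_crop_size` on a side.
    ///
    /// For example, a 1024x512 image with the defaults (256, 4, 1.2) has
    /// aspect ratio 2.0, so `num_crops_w = floor(2.0 + 0.5) = 2`, yielding two
    /// 512x512 crops.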
    fn pan_and_scan(
        &self,
        image: &DynamicImage,
        pan_and_scan_min_crop_size: usize,
        pan_and_scan_max_num_crops: usize,
        pan_and_scan_min_ratio_to_activate: f64,
    ) -> Vec<DynamicImage> {
        let (width, height) = image.dimensions();

        let (num_crops_w, num_crops_h) = if width >= height {
            if (width as f64 / height as f64) < pan_and_scan_min_ratio_to_activate {
                return vec![];
            }

            // Select an ideal number of crops close to the aspect ratio, such
            // that the cropped images are not smaller than the minimum size.
            let mut num_crops_w = (width as f64 / height as f64 + 0.5).floor() as usize;
            num_crops_w = num_crops_w
                .min((width as f64 / pan_and_scan_min_crop_size as f64).floor() as usize);

            // Clamp the number of crops to [2, pan_and_scan_max_num_crops].
            num_crops_w = num_crops_w.max(2);
            num_crops_w = num_crops_w.min(pan_and_scan_max_num_crops);

            (num_crops_w, 1)
        } else {
            if (height as f64 / width as f64) < pan_and_scan_min_ratio_to_activate {
                return vec![];
            }

            let mut num_crops_h = (height as f64 / width as f64 + 0.5).floor() as usize;
            num_crops_h = num_crops_h
                .min((height as f64 / pan_and_scan_min_crop_size as f64).floor() as usize);

            num_crops_h = num_crops_h.max(2);
            num_crops_h = num_crops_h.min(pan_and_scan_max_num_crops);

            (1, num_crops_h)
        };

        let crop_size_w = (width as f64 / num_crops_w as f64).ceil() as usize;
        let crop_size_h = (height as f64 / num_crops_h as f64).ceil() as usize;

        if crop_size_w.min(crop_size_h) < pan_and_scan_min_crop_size {
            return vec![];
        }

        let crop_positions_w = (0..num_crops_w)
            .map(|i| i * crop_size_w)
            .collect::<Vec<_>>();
        let crop_positions_h = (0..num_crops_h)
            .map(|i| i * crop_size_h)
            .collect::<Vec<_>>();

        let mut image_crops = Vec::new();
        for (pos_h, pos_w) in crop_positions_h
            .into_iter()
            .cartesian_product(crop_positions_w)
        {
            image_crops.push(image.crop_imm(
                pos_w as u32,
                pos_h as u32,
                crop_size_w as u32,
                crop_size_h as u32,
            ));
        }

        image_crops
    }

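    /// Apply [`Self::pan_and_scan`] to a batch of images, interleaving each
    /// original image with its crops and recording the per-image crop count
    /// (0 when pan-and-scan did not activate for that image).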
    fn process_images_for_pan_and_scan(
        &self,
        images: Vec<DynamicImage>,
        pan_and_scan_min_crop_size: usize,
        pan_and_scan_max_num_crops: usize,
        pan_and_scan_min_ratio_to_activate: f64,
    ) -> (Vec<DynamicImage>, Vec<usize>) {
        let mut pas_images_list = Vec::new();
        let mut num_crops = Vec::new();

        for image in images {
            let pas_images = self.pan_and_scan(
                &image,
                pan_and_scan_min_crop_size,
                pan_and_scan_max_num_crops,
                pan_and_scan_min_ratio_to_activate,
            );
            num_crops.push(pas_images.len());
            // Keep the original image first, followed by its crops.
            pas_images_list.push(image);
            pas_images_list.extend(pas_images);
        }

        (pas_images_list, num_crops)
    }
}

impl ImagePreProcessor for Gemma3ImageProcessor {
    const DEFAULT_MEAN: [f64; 3] = [0.5, 0.5, 0.5];
    const DEFAULT_STD: [f64; 3] = [0.5, 0.5, 0.5];

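    /// Preprocess Gemma 3 images: optionally convert to RGB, apply
    /// pan-and-scan cropping, resize to the configured `size`, then rescale
    /// and normalize into stacked pixel values. Per-image crop counts are
    /// returned in `num_crops`; videos are not supported.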
    fn preprocess(
        &self,
        mut images: Vec<DynamicImage>,
        videos: Vec<Vec<DynamicImage>>,
        config: &PreProcessorConfig,
        device: &Device,
        (_bs, _max_num_images): (usize, usize),
    ) -> Result<PreprocessedImages> {
        assert!(videos.is_empty());

        let do_resize = config.do_resize.unwrap();
        let size = config.size.as_ref().unwrap();
        let (height, width) = (size["height"], size["width"]);
        let resample = config.resampling.to_filter()?;
        let do_rescale = config.do_rescale.unwrap();
        let rescale_factor = config.rescale_factor.unwrap();
        let do_normalize = config.do_normalize.unwrap();
        let image_mean = config.image_mean.unwrap_or(Self::DEFAULT_MEAN);
        let image_std = config.image_std.unwrap_or(Self::DEFAULT_STD);
        let do_convert_rgb = config.do_convert_rgb.unwrap_or(true);
        let do_pan_and_scan = config.do_pan_and_scan.unwrap_or(do_convert_rgb);
        // Defaults match the Hugging Face Gemma 3 processor.
        let pan_and_scan_min_crop_size = config.pan_and_scan_min_crop_size.unwrap_or(256);
        let pan_and_scan_max_num_crops = config.pan_and_scan_max_num_crops.unwrap_or(4);
        let pan_and_scan_min_ratio_to_activate =
            config.pan_and_scan_min_ratio_to_activate.unwrap_or(1.2);

        for image in images.iter_mut() {
            // Convert to an rgb8 image.
            if do_convert_rgb {
                *image = DynamicImage::ImageRgb8(image.to_rgb8());
            }
        }

        let num_crops = if do_pan_and_scan {
            let (new_images, num_crops) = self.process_images_for_pan_and_scan(
                images,
                pan_and_scan_min_crop_size,
                pan_and_scan_max_num_crops,
                pan_and_scan_min_ratio_to_activate,
            );
            images = new_images;
            num_crops
        } else {
            vec![0; images.len()]
        };

        let mut pixel_values = Vec::new();
        for mut image in images {
            if do_resize {
                image = image.resize_exact(width, height, resample);
            }

            let transforms = Transforms {
                input: &ToTensorNoNorm,
                inner_transforms: &[
                    &do_rescale.then_some(Rescale {
                        factor: Some(rescale_factor),
                    }),
                    &do_normalize.then(|| Normalize {
                        mean: image_mean.to_vec(),
                        std: image_std.to_vec(),
                    }),
                ],
            };

            let image = image.apply(transforms, device)?;
            pixel_values.push(image.unsqueeze(0)?);
        }

        Ok(PreprocessedImages {
            pixel_values: Tensor::cat(&pixel_values, 0)?,
            pixel_attention_mask: None,
            image_sizes: None,
            num_img_tokens: None,
            aspect_ratio_ids: None,
            aspect_ratio_mask: None,
            num_tiles: None,
            image_grid_thw: None,
            video_grid_thw: None,
            rows: None,
            cols: None,
            pixel_values_list: None,
            tgt_sizes: None,
            image_sizes_all: None,
            num_crops: Some(num_crops),
        })
    }
}
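
// A minimal sanity check of the pan-and-scan crop math (a sketch, not part of
// the upstream test suite). It exercises `pan_and_scan` directly with the same
// defaults `preprocess` falls back to: min crop size 256, max 4 crops,
// activation ratio 1.2.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn pan_and_scan_crop_counts() {
        let processor = Gemma3ImageProcessor {
            // The image token sequence is irrelevant for the crop math.
            full_image_sequence: String::new(),
            supports_images: true,
        };

        // A square image never activates pan-and-scan (ratio 1.0 < 1.2).
        let square = DynamicImage::new_rgb8(512, 512);
        assert!(processor.pan_and_scan(&square, 256, 4, 1.2).is_empty());

        // A 2:1 landscape image yields two 512x512 crops:
        // num_crops_w = floor(2.0 + 0.5) = 2, clamped to [2, 4].
        let wide = DynamicImage::new_rgb8(1024, 512);
        assert_eq!(processor.pan_and_scan(&wide, 256, 4, 1.2).len(), 2);

        // A 2:1 portrait image likewise yields two crops along the height.
        let tall = DynamicImage::new_rgb8(512, 1024);
        assert_eq!(processor.pan_and_scan(&tall, 256, 4, 1.2).len(), 2);
    }
}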