#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::{any::Any, num::NonZeroUsize, sync::Arc};

use candle_core::{Device, Result, Tensor};
use image::{DynamicImage, GenericImageView};
use itertools::Itertools;
use mistralrs_vision::{ApplyTransforms, Normalize, Rescale, ToTensorNoNorm, Transforms};
use regex::Regex;
use tokenizers::Tokenizer;
use tracing::warn;

use crate::{
    device_map::DeviceMapper,
    pipeline::{
        text_models_inputs_processor::{
            self, get_completion_input, get_prompt_input, PagedAttentionMeta,
        },
        InputProcessorOutput, InputsProcessor, InputsProcessorType, MessagesAction, Processor,
    },
    sequence::Sequence,
    vision_models::{
        image_processor::{ImagePreProcessor, PreprocessedImages},
        preprocessor_config::{PreProcessorConfig, ToFilter},
        processor_config::ProcessorConfig,
        ModelInputs,
    },
};

use super::Gemma3SpecificArgs;

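/// Per-request inputs processor for Gemma 3 vision models. It expands the
/// `<start_of_image>` placeholder in each prompt into the full image token
/// sequence and produces the pixel values consumed by the vision tower.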
struct Gemma3ImageProcessor {
    full_image_sequence: String,
    supports_images: bool,
}

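// Special tokens used by the Gemma 3 chat template: `<image_soft_token>` is the
// placeholder that is repeated `image_seq_len` times per image, while
// `<start_of_image>`/`<end_of_image>` delimit each image in the prompt.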
const IMAGE_TOKEN: &str = "<image_soft_token>";
const BOI_TOKEN: &str = "<start_of_image>";
const EOI_TOKEN: &str = "<end_of_image>";

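/// Prompt/image processor for Gemma 3. Precomputes `full_image_sequence`
/// (`\n\n<start_of_image>` + `image_seq_len` soft tokens + `<end_of_image>\n\n`)
/// once so it can be substituted into prompts cheaply.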
pub struct Gemma3Processor {
    full_image_sequence: String,
    supports_images: bool,
}

impl Gemma3Processor {
    pub fn new(processor_config: ProcessorConfig, supports_images: bool) -> Self {
        let image_tokens_expanded =
            vec![IMAGE_TOKEN.to_string(); processor_config.image_seq_len.unwrap_or(256)].join("");
        let full_image_sequence = format!("\n\n{BOI_TOKEN}{image_tokens_expanded}{EOI_TOKEN}\n\n");

        Self {
            full_image_sequence,
            supports_images,
        }
    }
}

impl Processor for Gemma3Processor {
    fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
        Arc::new(Gemma3ImageProcessor {
            full_image_sequence: self.full_image_sequence.clone(),
            supports_images: self.supports_images,
        })
    }

    fn get_special_tokens(&self) -> &[&'static str] {
        &[BOI_TOKEN, EOI_TOKEN, IMAGE_TOKEN]
    }

    fn template_action(&self) -> MessagesAction {
        MessagesAction::Keep
    }
}

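// Converts tokenized sequences plus any attached images into `ModelInputs`:
// images are preprocessed (optionally pan-and-scanned into crops), each prompt's
// `<start_of_image>` markers are expanded into the full image token sequence,
// and the sequence token ids are re-encoded accordingly.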
impl InputsProcessor for Gemma3ImageProcessor {
    fn get_type(&self) -> InputsProcessorType {
        InputsProcessorType::Vision
    }

    fn process_inputs(
        &self,
        tokenizer: Option<Arc<Tokenizer>>,
        input_seqs: &mut [&mut Sequence],
        is_prompt: bool,
        is_xlora: bool,
        device: &Device,
        no_kv_cache: bool,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        other_config: Option<Arc<dyn Any>>,
        mut paged_attn_metadata: Option<PagedAttentionMeta<'_>>,
        prompt_chunksize: Option<NonZeroUsize>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Box<dyn Iterator<Item = anyhow::Result<InputProcessorOutput>>> {
        if is_xlora {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Cannot make inputs for X-LoRA vision model.",
            ))));
        }
        if no_kv_cache {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Vision model must have kv cache.",
            ))));
        }
        if prompt_chunksize.is_some() {
            warn!("`prompt_chunksize` is set. Gemma3 does not support prompt batching.");
        }
        let Some(tokenizer) = tokenizer else {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Gemma3ImageProcessor requires a specified tokenizer.",
            ))));
        };

        let config = other_config.expect("Need a PreProcessorConfig config.");
        let config: &PreProcessorConfig = config.downcast_ref().expect("Downcast failed.");

        let has_images = input_seqs.iter().all(|seq| seq.has_images());

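        // Preprocess the images of each sequence, rewrite its prompt's image markers,
        // and collect pixel values; `None` when the batch carries no images.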
        let pixel_values = if has_images {
            if !self.supports_images {
                return Box::new(std::iter::once(Err(anyhow::Error::msg(
                    "This image processor does not support images.",
                ))));
            }

            let mut pixel_values_accum = Vec::new();
            let re = Regex::new(BOI_TOKEN).unwrap();
            for seq in input_seqs.iter_mut() {
                let PreprocessedImages {
                    pixel_values,
                    pixel_attention_mask: _,
                    image_sizes: _,
                    num_img_tokens: _,
                    aspect_ratio_ids: _,
                    aspect_ratio_mask: _,
                    num_tiles: _,
                    image_grid_thw: _,
                    video_grid_thw: _,
                    rows: _,
                    cols: _,
                    pixel_values_list: _,
                    tgt_sizes: _,
                    image_sizes_all: _,
                    num_crops,
                } = self
                    .preprocess(
                        seq.take_images()
                            .expect("Need to have images by this point."),
                        vec![],
                        config,
                        device,
                        (usize::MAX, usize::MAX),
                    )
                    .expect("Preprocessing failed");

                let num_crops = num_crops.unwrap();

                pixel_values_accum.push(pixel_values.clone());

                let mut prompt = tokenizer
                    .decode(seq.get_toks(), false)
                    .expect("Detokenization failed!");

                // Byte offsets of each `<start_of_image>` marker in the decoded prompt.
                let image_indexes: Vec<usize> =
                    re.find_iter(&prompt).map(|mat| mat.start()).collect();

                // For images that were pan-and-scanned, announce the original image plus
                // one marker per crop. Iterate in reverse so earlier byte offsets stay
                // valid while the prompt is rewritten.
                for (num, idx) in num_crops.into_iter().zip(image_indexes).rev() {
                    if num != 0 {
                        let formatted_image_text = format!(
                            "Here is the original image {BOI_TOKEN} and here are some crops to help you see better {}",
                            vec![BOI_TOKEN.to_string(); num].join(" ")
                        );
                        prompt = format!(
                            "{}{formatted_image_text}{}",
                            &prompt[..idx],
                            &prompt[idx + BOI_TOKEN.len()..]
                        );
                    }
                }

                // Expand every `<start_of_image>` marker into the full image token sequence.
                prompt = prompt.replace(BOI_TOKEN, &self.full_image_sequence);

                if !seq.has_changed_prompt {
                    seq.set_initial_prompt(prompt.clone());
                    let toks = tokenizer
                        .encode_fast(prompt, false)
                        .expect("Tokenization failed!");

                    let ids = toks.get_ids().to_vec();
                    seq.set_toks_and_reallocate(ids, paged_attn_metadata.as_mut());
                    seq.has_changed_prompt = true;
                }
            }

            Some(Tensor::cat(&pixel_values_accum, 0).unwrap())
        } else {
            None
        };

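        // Build the standard text-model input metadata (token ids, positions, paged
        // attention / flash metadata) for either the prompt or completion phase.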
        let text_models_inputs_processor::InnerInputProcessorOutput {
            inputs:
                text_models_inputs_processor::InputMetadata {
                    input,
                    positions,
                    context_lens,
                    position_ids,
                    paged_attn_meta,
                    flash_meta,
                },
            seq_indices,
        } = if is_prompt {
            get_prompt_input(
                input_seqs
                    .iter()
                    .map(|seq| seq.get_toks().to_vec())
                    .collect::<Vec<_>>(),
                input_seqs,
                device,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                None,
                mapper,
            )
            .nth(0)
            .unwrap()
            .unwrap()
        } else {
            get_completion_input(
                input_seqs
                    .iter()
                    .map(|seq| seq.get_toks().to_vec())
                    .collect::<Vec<_>>(),
                input_seqs,
                device,
                no_kv_cache,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                None,
                mapper,
            )
            .nth(0)
            .unwrap()
            .unwrap()
        };

        let inputs: Box<dyn Any> = Box::new(ModelInputs {
            input_ids: input,
            seqlen_offsets: positions,
            context_lens,
            position_ids,
            pixel_values,
            model_specific_args: Box::new(Gemma3SpecificArgs),
            paged_attn_meta,
            flash_meta,
        });
        Box::new(std::iter::once(Ok(InputProcessorOutput {
            inputs,
            seq_indices,
        })))
    }
}

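// Pan-and-scan helpers: images whose aspect ratio exceeds
// `pan_and_scan_min_ratio_to_activate` are split into a row or column of crops
// along their longer side so the model sees both the full image and its crops.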
impl Gemma3ImageProcessor {
    fn pan_and_scan(
        &self,
        image: &DynamicImage,
        pan_and_scan_min_crop_size: usize,
        pan_and_scan_max_num_crops: usize,
        pan_and_scan_min_ratio_to_activate: f64,
    ) -> Vec<DynamicImage> {
        let (width, height) = image.dimensions();

        let (num_crops_w, num_crops_h) = if width >= height {
            // Landscape: only activate if the image is wide enough.
            if (width as f64 / height as f64) < pan_and_scan_min_ratio_to_activate {
                return vec![];
            }

            // Select a number of crops close to the aspect ratio (half-up rounding),
            // while keeping each crop at least `pan_and_scan_min_crop_size` wide.
            let mut num_crops_w = (width as f64 / height as f64 + 0.5).floor() as usize;
            num_crops_w = num_crops_w
                .min((width as f64 / pan_and_scan_min_crop_size as f64).floor() as usize);

            // Clamp the number of crops to [2, pan_and_scan_max_num_crops].
            num_crops_w = num_crops_w.max(2);
            num_crops_w = num_crops_w.min(pan_and_scan_max_num_crops);

            (num_crops_w, 1)
        } else {
            // Portrait: only activate if the image is tall enough.
            if (height as f64 / width as f64) < pan_and_scan_min_ratio_to_activate {
                return vec![];
            }

            let mut num_crops_h = (height as f64 / width as f64 + 0.5).floor() as usize;
            num_crops_h = num_crops_h
                .min((height as f64 / pan_and_scan_min_crop_size as f64).floor() as usize);

            num_crops_h = num_crops_h.max(2);
            num_crops_h = num_crops_h.min(pan_and_scan_max_num_crops);

            (1, num_crops_h)
        };

        let crop_size_w = (width as f64 / num_crops_w as f64).ceil() as usize;
        let crop_size_h = (height as f64 / num_crops_h as f64).ceil() as usize;

        if crop_size_w.min(crop_size_h) < pan_and_scan_min_crop_size {
            return vec![];
        }

        let crop_positions_w = (0..num_crops_w)
            .map(|i| i * crop_size_w)
            .collect::<Vec<_>>();
        let crop_positions_h = (0..num_crops_h)
            .map(|i| i * crop_size_h)
            .collect::<Vec<_>>();

        let mut image_crops = Vec::new();
        for (pos_h, pos_w) in crop_positions_h
            .into_iter()
            .cartesian_product(crop_positions_w)
        {
            image_crops.push(image.crop_imm(
                pos_w as u32,
                pos_h as u32,
                crop_size_w as u32,
                crop_size_h as u32,
            ));
        }

        image_crops
    }

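    /// Runs pan-and-scan over each image. Returns the flattened list of
    /// `[image, crops...]` per input image and the number of crops produced for
    /// each image (0 when pan-and-scan did not activate).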
    fn process_images_for_pan_and_scan(
        &self,
        images: Vec<DynamicImage>,
        pan_and_scan_min_crop_size: usize,
        pan_and_scan_max_num_crops: usize,
        pan_and_scan_min_ratio_to_activate: f64,
    ) -> (Vec<DynamicImage>, Vec<usize>) {
        let mut pas_images_list = Vec::new();
        let mut num_crops = Vec::new();

        for image in images {
            let pas_images = self.pan_and_scan(
                &image,
                pan_and_scan_min_crop_size,
                pan_and_scan_max_num_crops,
                pan_and_scan_min_ratio_to_activate,
            );
            num_crops.push(pas_images.len());
            pas_images_list.extend([vec![image], pas_images].concat());
        }

        (pas_images_list, num_crops)
    }
}

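// Image preprocessing: optional RGB conversion and pan-and-scan, then resize to
// the configured `size`, rescale, and normalize (default mean/std of 0.5 per channel).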
impl ImagePreProcessor for Gemma3ImageProcessor {
    const DEFAULT_MEAN: [f64; 3] = [0.5, 0.5, 0.5];
    const DEFAULT_STD: [f64; 3] = [0.5, 0.5, 0.5];

    fn preprocess(
        &self,
        mut images: Vec<DynamicImage>,
        videos: Vec<Vec<DynamicImage>>,
        config: &PreProcessorConfig,
        device: &Device,
        (_bs, _max_num_images): (usize, usize),
    ) -> Result<PreprocessedImages> {
        assert!(videos.is_empty());

        let do_resize = config.do_resize.unwrap();
        let size = config.size.as_ref().unwrap();
        let (height, width) = (size["height"], size["width"]);
        let resample = config.resampling.to_filter()?;
        let do_rescale = config.do_rescale.unwrap();
        let rescale_factor = config.rescale_factor.unwrap();
        let do_normalize = config.do_normalize.unwrap();
        let image_mean = config.image_mean.unwrap_or(Self::DEFAULT_MEAN);
        let image_std = config.image_std.unwrap_or(Self::DEFAULT_STD);
        let do_convert_rgb = config.do_convert_rgb.unwrap_or(true);
        let do_pan_and_scan = config.do_pan_and_scan.unwrap_or(do_convert_rgb);
        let pan_and_scan_min_crop_size = config.pan_and_scan_min_crop_size.unwrap_or(256);
        let pan_and_scan_max_num_crops = config.pan_and_scan_max_num_crops.unwrap_or(4);
        let pan_and_scan_min_ratio_to_activate =
            config.pan_and_scan_min_ratio_to_activate.unwrap_or(1.2);

        for image in images.iter_mut() {
            if do_convert_rgb {
                *image = DynamicImage::ImageRgb8(image.to_rgb8());
            }
        }

        let num_crops = if do_pan_and_scan {
            let (new_images, num_crops) = self.process_images_for_pan_and_scan(
                images,
                pan_and_scan_min_crop_size,
                pan_and_scan_max_num_crops,
                pan_and_scan_min_ratio_to_activate,
            );
            images = new_images;
            num_crops
        } else {
            vec![0]
        };

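        // Resize, rescale, and normalize each image (original plus any crops) and
        // stack them into a single `pixel_values` tensor.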
        let mut pixel_values = Vec::new();
        for mut image in images {
            if do_resize {
                image = image.resize_exact(width, height, resample);
            }

            let transforms = Transforms {
                input: &ToTensorNoNorm,
                inner_transforms: &[
                    &do_rescale.then_some(Rescale {
                        factor: Some(rescale_factor),
                    }),
                    &do_normalize.then(|| Normalize {
                        mean: image_mean.to_vec(),
                        std: image_std.to_vec(),
                    }),
                ],
            };

            let image = image.apply(transforms, device)?;
            pixel_values.push(image.unsqueeze(0)?);
        }

        Ok(PreprocessedImages {
            pixel_values: Tensor::cat(&pixel_values, 0)?,
            pixel_attention_mask: None,
            image_sizes: None,
            num_img_tokens: None,
            aspect_ratio_ids: None,
            aspect_ratio_mask: None,
            num_tiles: None,
            image_grid_thw: None,
            video_grid_thw: None,
            rows: None,
            cols: None,
            pixel_values_list: None,
            tgt_sizes: None,
            image_sizes_all: None,
            num_crops: Some(num_crops),
        })
    }
}