#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::{any::Any, sync::Arc};

use candle_core::{Device, Result, Tensor};
use image::{DynamicImage, GenericImageView};
use itertools::Itertools;
use mistralrs_vision::{ApplyTransforms, Normalize, Rescale, ToTensorNoNorm, Transforms};
use regex::Regex;
use tokenizers::Tokenizer;

use crate::{
    device_map::DeviceMapper,
    pipeline::{
        text_models_inputs_processor::{
            self, get_completion_input, get_prompt_input, PagedAttentionMeta,
        },
        InputProcessorOutput, InputsProcessor, InputsProcessorType, MessagesAction, Processor,
    },
    sequence::Sequence,
    vision_models::{
        image_processor::{ImagePreProcessor, PreprocessedImages},
        preprocessor_config::{PreProcessorConfig, ToFilter},
        processor_config::ProcessorConfig,
        ModelInputs,
    },
};

use super::Gemma3SpecificArgs;

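/// Per-request inputs processor: converts images into pixel values and expands the
/// image placeholder tokens in each prompt.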
struct Gemma3ImageProcessor {
    full_image_sequence: String,
    supports_images: bool,
}

const IMAGE_TOKEN: &str = "<image_soft_token>";
const BOI_TOKEN: &str = "<start_of_image>";
const EOI_TOKEN: &str = "<end_of_image>";

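/// Processor for Gemma 3 multimodal inputs. It pre-builds the expanded image token
/// sequence and hands per-request processing to [`Gemma3ImageProcessor`].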
pub struct Gemma3Processor {
    full_image_sequence: String,
    supports_images: bool,
}

impl Gemma3Processor {
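    /// Build the processor, expanding the image sequence to
    /// `<start_of_image>` + `image_seq_len` (default 256) copies of
    /// `<image_soft_token>` + `<end_of_image>`, padded with blank lines.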
    pub fn new(processor_config: ProcessorConfig, supports_images: bool) -> Self {
        let image_tokens_expanded =
            vec![IMAGE_TOKEN.to_string(); processor_config.image_seq_len.unwrap_or(256)].join("");
        let full_image_sequence = format!("\n\n{BOI_TOKEN}{image_tokens_expanded}{EOI_TOKEN}\n\n");

        Self {
            full_image_sequence,
            supports_images,
        }
    }
}

impl Processor for Gemma3Processor {
    fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
        Arc::new(Gemma3ImageProcessor {
            full_image_sequence: self.full_image_sequence.clone(),
            supports_images: self.supports_images,
        })
    }

    fn get_special_tokens(&self) -> &[&'static str] {
        &[BOI_TOKEN, EOI_TOKEN, IMAGE_TOKEN]
    }

    fn template_action(&self) -> MessagesAction {
        MessagesAction::Keep
    }
}

impl InputsProcessor for Gemma3ImageProcessor {
    fn get_type(&self) -> InputsProcessorType {
        InputsProcessorType::Vision
    }

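    /// Preprocess the images for each sequence, rewrite the prompt so that every
    /// `<start_of_image>` marker becomes the full expanded image token sequence
    /// (inserting extra markers first when pan-and-scan produced crops),
    /// re-tokenize the prompt once, and assemble the text-model inputs.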
    fn process_inputs(
        &self,
        tokenizer: Option<Arc<Tokenizer>>,
        input_seqs: &mut [&mut Sequence],
        is_prompt: bool,
        is_xlora: bool,
        device: &Device,
        no_kv_cache: bool,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        other_config: Option<Arc<dyn Any>>,
        mut paged_attn_metadata: Option<PagedAttentionMeta>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> anyhow::Result<InputProcessorOutput> {
        if is_xlora {
            return Err(anyhow::Error::msg(
                "Cannot make inputs for X-LoRA vision model.",
            ));
        }
        if no_kv_cache {
            return Err(anyhow::Error::msg("Vision model must have kv cache."));
        }
        let Some(tokenizer) = tokenizer else {
            return Err(anyhow::Error::msg(
                "Gemma3ImageProcessor requires a specified tokenizer.",
            ));
        };

        let config = other_config.expect("Need a PreProcessorConfig config.");
        let config: &PreProcessorConfig = config.downcast_ref().expect("Downcast failed.");

        let has_images = input_seqs.iter().all(|seq| seq.has_images());

        let pixel_values = if has_images {
            if !self.supports_images {
                return Err(anyhow::Error::msg(
                    "This image processor does not support images.",
                ));
            }

            let mut pixel_values_accum = Vec::new();
            let re = Regex::new(BOI_TOKEN).unwrap();
            for seq in input_seqs.iter_mut() {
                let PreprocessedImages {
                    pixel_values,
                    pixel_attention_mask: _,
                    image_sizes: _,
                    num_img_tokens: _,
                    aspect_ratio_ids: _,
                    aspect_ratio_mask: _,
                    num_tiles: _,
                    image_grid_thw: _,
                    video_grid_thw: _,
                    rows: _,
                    cols: _,
                    pixel_values_list: _,
                    tgt_sizes: _,
                    image_sizes_all: _,
                    num_crops,
                } = self
                    .preprocess(
                        seq.take_images()
                            .expect("Need to have images by this point."),
                        vec![],
                        config,
                        device,
                        (usize::MAX, usize::MAX),
                    )
                    .expect("Preprocessing failed");

                let num_crops = num_crops.unwrap();

                pixel_values_accum.push(pixel_values.clone());

                let mut prompt = tokenizer
                    .decode(seq.get_toks(), false)
                    .expect("Detokenization failed!");

                let image_indexes: Vec<usize> =
                    re.find_iter(&prompt).map(|mat| mat.start()).collect();

                for (num, idx) in num_crops.into_iter().zip(image_indexes).rev() {
                    if num != 0 {
                        let formatted_image_text = format!(
                            "Here is the original image {BOI_TOKEN} and here are some crops to help you see better {}",
                            vec![BOI_TOKEN.to_string(); num].join(" ")
                        );
                        prompt = format!(
                            "{}{formatted_image_text}{}",
                            &prompt[..idx],
                            &prompt[idx + BOI_TOKEN.len()..]
                        );
                    }
                }

                prompt = prompt.replace(BOI_TOKEN, &self.full_image_sequence);

                if !seq.multimodal.has_changed_prompt {
                    seq.set_initial_prompt(prompt.clone());
                    let toks = tokenizer
                        .encode_fast(prompt, false)
                        .expect("Tokenization failed!");

                    let ids = toks.get_ids().to_vec();
                    seq.set_toks_and_reallocate(ids, paged_attn_metadata.as_mut());
                    seq.multimodal.has_changed_prompt = true;
                }
            }

            Some(Tensor::cat(&pixel_values_accum, 0).unwrap())
        } else {
            None
        };

        let text_models_inputs_processor::InnerInputProcessorOutput {
            inputs:
                text_models_inputs_processor::InputMetadata {
                    input,
                    positions,
                    context_lens,
                    position_ids,
                    paged_attn_meta,
                    flash_meta,
                },
            seq_indices,
        } = if is_prompt {
            get_prompt_input(
                input_seqs
                    .iter()
                    .map(|seq| seq.get_toks())
                    .collect::<Vec<_>>(),
                input_seqs,
                device,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                mapper,
            )
            .unwrap()
        } else {
            get_completion_input(
                input_seqs
                    .iter()
                    .map(|seq| seq.get_toks())
                    .collect::<Vec<_>>(),
                input_seqs,
                device,
                no_kv_cache,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                mapper,
            )
            .unwrap()
        };

        let inputs: Box<dyn Any> = Box::new(ModelInputs {
            input_ids: input,
            seqlen_offsets: positions,
            context_lens,
            position_ids,
            pixel_values,
            model_specific_args: Box::new(Gemma3SpecificArgs),
            paged_attn_meta,
            flash_meta,
        });
        Ok(InputProcessorOutput {
            inputs,
            seq_indices,
        })
    }
}

impl Gemma3ImageProcessor {
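    /// Pan-and-scan cropping: for images whose aspect ratio is at least
    /// `pan_and_scan_min_ratio_to_activate`, split the long axis into between 2 and
    /// `pan_and_scan_max_num_crops` crops, each at least
    /// `pan_and_scan_min_crop_size` pixels along the split axis. Returns an empty
    /// vector when no crops should be taken.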
    fn pan_and_scan(
        &self,
        image: &DynamicImage,
        pan_and_scan_min_crop_size: usize,
        pan_and_scan_max_num_crops: usize,
        pan_and_scan_min_ratio_to_activate: f64,
    ) -> Vec<DynamicImage> {
        let (width, height) = image.dimensions();

        let (num_crops_w, num_crops_h) = if width >= height {
            if (width as f64 / height as f64) < pan_and_scan_min_ratio_to_activate {
                return vec![];
            }

            let mut num_crops_w = (width as f64 / height as f64 + 0.5).floor() as usize;
            num_crops_w = num_crops_w
                .min((width as f64 / pan_and_scan_min_crop_size as f64).floor() as usize);

            num_crops_w = num_crops_w.max(2);
            num_crops_w = num_crops_w.min(pan_and_scan_max_num_crops);

            (num_crops_w, 1)
        } else {
            if (height as f64 / width as f64) < pan_and_scan_min_ratio_to_activate {
                return vec![];
            }

            let mut num_crops_h = (height as f64 / width as f64 + 0.5).floor() as usize;
            num_crops_h = num_crops_h
                .min((height as f64 / pan_and_scan_min_crop_size as f64).floor() as usize);

            num_crops_h = num_crops_h.max(2);
            num_crops_h = num_crops_h.min(pan_and_scan_max_num_crops);

            (1, num_crops_h)
        };

        let crop_size_w = (width as f64 / num_crops_w as f64).ceil() as usize;
        let crop_size_h = (height as f64 / num_crops_h as f64).ceil() as usize;

        if crop_size_w.min(crop_size_h) < pan_and_scan_min_crop_size {
            return vec![];
        }

        let crop_positions_w = (0..num_crops_w)
            .map(|i| i * crop_size_w)
            .collect::<Vec<_>>();
        let crop_positions_h = (0..num_crops_h)
            .map(|i| i * crop_size_h)
            .collect::<Vec<_>>();

        let mut image_crops = Vec::new();
        for (pos_h, pos_w) in crop_positions_h
            .into_iter()
            .cartesian_product(crop_positions_w)
        {
            image_crops.push(image.crop_imm(
                pos_w as u32,
                pos_h as u32,
                crop_size_w as u32,
                crop_size_h as u32,
            ));
        }

        image_crops
    }

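    /// Apply pan-and-scan to each image, returning every original image followed by
    /// its crops, along with the number of crops produced per image.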
    fn process_images_for_pan_and_scan(
        &self,
        images: Vec<DynamicImage>,
        pan_and_scan_min_crop_size: usize,
        pan_and_scan_max_num_crops: usize,
        pan_and_scan_min_ratio_to_activate: f64,
    ) -> (Vec<DynamicImage>, Vec<usize>) {
        let mut pas_images_list = Vec::new();
        let mut num_crops = Vec::new();

        for image in images {
            let pas_images = self.pan_and_scan(
                &image,
                pan_and_scan_min_crop_size,
                pan_and_scan_max_num_crops,
                pan_and_scan_min_ratio_to_activate,
            );
            num_crops.push(pas_images.len());
            pas_images_list.extend([vec![image], pas_images].concat());
        }

        (pas_images_list, num_crops)
    }
}

impl ImagePreProcessor for Gemma3ImageProcessor {
    const DEFAULT_MEAN: [f64; 3] = [0.5, 0.5, 0.5];
    const DEFAULT_STD: [f64; 3] = [0.5, 0.5, 0.5];

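    /// Convert the images to RGB and optionally pan-and-scan crop them, then resize,
    /// rescale, and normalize, returning the stacked pixel values together with the
    /// per-image crop counts in `num_crops`.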
    fn preprocess(
        &self,
        mut images: Vec<DynamicImage>,
        videos: Vec<Vec<DynamicImage>>,
        config: &PreProcessorConfig,
        device: &Device,
        (_bs, _max_num_images): (usize, usize),
    ) -> Result<PreprocessedImages> {
        assert!(videos.is_empty());

        let do_resize = config.do_resize.unwrap();
        let size = config.size.as_ref().unwrap();
        let (height, width) = (size["height"], size["width"]);
        let resample = config.resampling.to_filter()?;
        let do_rescale = config.do_rescale.unwrap();
        let rescale_factor = config.rescale_factor.unwrap();
        let do_normalize = config.do_normalize.unwrap();
        let image_mean = config.image_mean.unwrap_or(Self::DEFAULT_MEAN);
        let image_std = config.image_std.unwrap_or(Self::DEFAULT_STD);
        let do_convert_rgb = config.do_convert_rgb.unwrap_or(true);
        let do_pan_and_scan = config.do_pan_and_scan.unwrap_or(do_convert_rgb);
        let pan_and_scan_min_crop_size = config.pan_and_scan_min_crop_size.unwrap_or(256);
        let pan_and_scan_max_num_crops = config.pan_and_scan_max_num_crops.unwrap_or(4);
        let pan_and_scan_min_ratio_to_activate =
            config.pan_and_scan_min_ratio_to_activate.unwrap_or(1.2);

        for image in images.iter_mut() {
            if do_convert_rgb {
                *image = DynamicImage::ImageRgb8(image.to_rgb8());
            }
        }

        let num_crops = if do_pan_and_scan {
            let (new_images, num_crops) = self.process_images_for_pan_and_scan(
                images,
                pan_and_scan_min_crop_size,
                pan_and_scan_max_num_crops,
                pan_and_scan_min_ratio_to_activate,
            );
            images = new_images;
            num_crops
        } else {
            vec![0]
        };

        let mut pixel_values = Vec::new();
        for mut image in images {
            if do_resize {
                image = image.resize_exact(width, height, resample);
            }

            let transforms = Transforms {
                input: &ToTensorNoNorm,
                inner_transforms: &[
                    &do_rescale.then_some(Rescale {
                        factor: Some(rescale_factor),
                    }),
                    &do_normalize.then(|| Normalize {
                        mean: image_mean.to_vec(),
                        std: image_std.to_vec(),
                    }),
                ],
            };

            let image = image.apply(transforms, device)?;
            pixel_values.push(image.unsqueeze(0)?);
        }

        Ok(PreprocessedImages {
            pixel_values: Tensor::cat(&pixel_values, 0)?,
            pixel_attention_mask: None,
            image_sizes: None,
            num_img_tokens: None,
            aspect_ratio_ids: None,
            aspect_ratio_mask: None,
            num_tiles: None,
            image_grid_thw: None,
            video_grid_thw: None,
            rows: None,
            cols: None,
            pixel_values_list: None,
            tgt_sizes: None,
            image_sizes_all: None,
            num_crops: Some(num_crops),
        })
    }
}