#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::{any::Any, cmp, collections::HashMap, sync::Arc};

use candle_core::{Device, Result, Tensor};
use image::{imageops::FilterType, DynamicImage, GenericImageView};
use mistralrs_vision::{ApplyTransforms, Normalize, Rescale, ToTensorNoNorm, Transforms};
use tokenizers::Tokenizer;

use crate::{
    device_map::DeviceMapper,
    pipeline::{
        text_models_inputs_processor::{
            self, get_completion_input, get_prompt_input, PagedAttentionMeta,
        },
        InputProcessorOutput, InputsProcessor, InputsProcessorType, MessagesAction, Processor,
    },
    sequence::Sequence,
    vision_models::ModelInputs,
};

use crate::vision_models::{
    image_processor::{ImagePreProcessor, PreprocessedImages},
    preprocessor_config::{PreProcessorConfig, ToFilter},
    processor_config::ProcessorConfig,
};

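// Absolute upper bound on the longer image side after resizing (see `get_resize_output_image_size`).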
const MAX_IMAGE_SIZE: usize = 4096;
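// Special tokens used when expanding image placeholders in the prompt.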
const FAKE_IMAGE_TOKEN: &str = "<fake_token_around_image>";
const IMAGE_TOKEN: &str = "<image>";
const GLOBAL_IMAGE_TOKEN: &str = "<global-img>";

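/// Inputs processor for Idefics3: preprocesses images into pixel values and attention masks,
/// and expands each `<image>` placeholder in the prompt into the tiled token layout the model
/// expects.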
pub struct Idefics3ImageProcessor {
    max_edge: Option<u32>,
    image_seq_len: usize,
}

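/// Processor for the Idefics3 model family: builds the [`Idefics3ImageProcessor`] and exposes
/// the special tokens and chat-template behaviour used at input time.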
pub struct Idefics3Processor {
    config: ProcessorConfig,
    max_edge: Option<u32>,
}

impl Idefics3Processor {
    pub fn new(
        config: ProcessorConfig,
        _preprocessor_config: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Self {
        Self { config, max_edge }
    }
}

impl Processor for Idefics3Processor {
    fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
        Arc::new(Idefics3ImageProcessor {
            max_edge: self.max_edge,
            image_seq_len: self.config.image_seq_len.unwrap_or(169),
        })
    }

    fn get_special_tokens(&self) -> &[&'static str] {
        &["<fake_token_around_image>", "<image>", "<end_of_utterance>"]
    }

    fn template_action(&self) -> MessagesAction {
        MessagesAction::Keep
    }
}

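/// Builds the textual expansion for one image: per-tile `<row_i_col_j>` blocks of `<image>`
/// tokens followed by the global image block. `n_rows == 0 && n_cols == 0` means the image was
/// not split, so only the global block is emitted.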
fn get_image_prompt_string(n_rows: usize, n_cols: usize, image_seq_len: usize) -> String {
    if n_rows == 0 && n_cols == 0 {
        format!(
            "{FAKE_IMAGE_TOKEN}{GLOBAL_IMAGE_TOKEN}{}{FAKE_IMAGE_TOKEN}",
            IMAGE_TOKEN.repeat(image_seq_len)
        )
    } else {
        let mut text_split_images = String::new();
        for n_h in 0..n_rows {
            for n_w in 0..n_cols {
                text_split_images.push_str(&format!(
                    "{FAKE_IMAGE_TOKEN}<row_{}_col_{}>{}",
                    n_h + 1,
                    n_w + 1,
                    IMAGE_TOKEN.repeat(image_seq_len)
                ));
            }
            text_split_images.push('\n');
        }
        format!(
            "{text_split_images}\n{FAKE_IMAGE_TOKEN}{GLOBAL_IMAGE_TOKEN}{}{FAKE_IMAGE_TOKEN}",
            IMAGE_TOKEN.repeat(image_seq_len)
        )
    }
}

impl InputsProcessor for Idefics3ImageProcessor {
    fn get_type(&self) -> InputsProcessorType {
        InputsProcessorType::Vision
    }
    fn process_inputs(
        &self,
        tokenizer: Option<Arc<Tokenizer>>,
        input_seqs: &mut [&mut Sequence],
        is_prompt: bool,
        is_xlora: bool,
        device: &Device,
        no_kv_cache: bool,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        other_config: Option<Arc<dyn Any>>,
        mut paged_attn_metadata: Option<PagedAttentionMeta>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> anyhow::Result<InputProcessorOutput> {
        if is_xlora {
            return Err(anyhow::Error::msg(
                "Cannot make inputs for X-LoRA vision model.",
            ));
        }
        if no_kv_cache {
            return Err(anyhow::Error::msg("Vision model must have kv cache."));
        }
        let Some(tokenizer) = tokenizer else {
            return Err(anyhow::Error::msg(
                "Idefics3ImageProcessor requires a specified tokenizer.",
            ));
        };

        let config = other_config.expect("Need a PreProcessorConfig config.");
        let config: &PreProcessorConfig = config.downcast_ref().expect("Downcast failed.");

        let has_images = input_seqs.iter().all(|seq| seq.has_images());

        let (pixel_values, pixel_attention_mask) = if has_images {
            let mut pixel_values_accum = Vec::new();
            let mut pixel_attention_mask_accum = Vec::new();
            for seq in input_seqs.iter_mut() {
                let PreprocessedImages {
                    pixel_values,
                    pixel_attention_mask,
                    image_sizes: _,
                    num_img_tokens: _,
                    aspect_ratio_ids: _,
                    aspect_ratio_mask: _,
                    num_tiles: _,
                    image_grid_thw: _,
                    video_grid_thw: _,
                    rows,
                    cols,
                    pixel_values_list: _,
                    tgt_sizes: _,
                    image_sizes_all: _,
                    num_crops: _,
                } = self
                    .preprocess(
                        seq.take_images()
                            .expect("Need to have images by this point."),
                        vec![],
                        config,
                        device,
                        (usize::MAX, usize::MAX),
                    )
                    .expect("Preprocessing failed");
                pixel_values_accum.push(pixel_values.unsqueeze(0).unwrap());
                pixel_attention_mask_accum
                    .push(pixel_attention_mask.unwrap().unsqueeze(0).unwrap());

                if !seq.multimodal.has_changed_prompt {
                    let detok = tokenizer
                        .decode(seq.get_toks(), false)
                        .expect("Detokenization failed!");

                    let mut image_prompt_strings = Vec::new();
                    for (n_rows, n_cols) in rows.unwrap().into_iter().zip(cols.unwrap().into_iter())
                    {
                        let image_prompt_string =
                            get_image_prompt_string(n_rows, n_cols, self.image_seq_len);
                        image_prompt_strings.push(image_prompt_string);
                    }

                    let split_sample = detok.split(IMAGE_TOKEN).collect::<Vec<_>>();
                    let mut sample = split_sample
                        .first()
                        .expect("The image token <image> should be present in the text.")
                        .to_string();
                    for (i, image_prompt_string) in image_prompt_strings.into_iter().enumerate() {
                        sample.push_str(&format!(
                            "{image_prompt_string}{}",
                            split_sample
                                .get(i + 1)
                                .expect("Incorrect chat template. Use the one provided in `chat_templates` with the `--chat-template`/`chat_template` settings.")
                        ));
                    }

                    seq.set_initial_prompt(sample.clone());
                    let toks = tokenizer
                        .encode_fast(sample, false)
                        .expect("Tokenization failed!");

                    let ids = toks.get_ids().to_vec();
                    seq.set_toks_and_reallocate(ids, paged_attn_metadata.as_mut());
                    seq.multimodal.has_changed_prompt = true;
                }
            }

            (
                Some(Tensor::cat(&pixel_values_accum, 0).unwrap()),
                Some(Tensor::cat(&pixel_attention_mask_accum, 0).unwrap()),
            )
        } else {
            (None, None)
        };

        let text_models_inputs_processor::InnerInputProcessorOutput {
            inputs:
                text_models_inputs_processor::InputMetadata {
                    input,
                    positions,
                    context_lens,
                    position_ids,
                    paged_attn_meta,
                    flash_meta,
                },
            seq_indices,
        } = if is_prompt {
            get_prompt_input(
                input_seqs
                    .iter()
                    .map(|seq| seq.get_toks())
                    .collect::<Vec<_>>(),
                input_seqs,
                device,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                mapper,
            )
            .unwrap()
        } else {
            get_completion_input(
                input_seqs
                    .iter()
                    .map(|seq| seq.get_toks())
                    .collect::<Vec<_>>(),
                input_seqs,
                device,
                no_kv_cache,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                mapper,
            )
            .unwrap()
        };

        let inputs: Box<dyn Any> = Box::new(ModelInputs {
            input_ids: input,
            seqlen_offsets: positions,
            context_lens,
            position_ids,
            pixel_values,
            model_specific_args: Box::new(pixel_attention_mask),
            paged_attn_meta,
            flash_meta,
        });
        Ok(InputProcessorOutput {
            inputs,
            seq_indices,
        })
    }
}

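/// Rescales (height, width) so that the longer side equals `max_len` (rounded up to an even
/// number), preserving aspect ratio and clamping both sides to at least `min_len`.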
fn resize_output_size_rescale_to_max_len(
    height: usize,
    width: usize,
    min_len: Option<usize>,
    max_len: Option<usize>,
) -> (usize, usize) {
    let min_len = min_len.unwrap_or(1);
    let max_len = max_len.unwrap_or_else(|| cmp::max(height, width));
    let aspect_ratio = width as f32 / height as f32;
    let (mut height, mut width) = (height, width);

    if width >= height {
        width = max_len;
        height = (width as f32 / aspect_ratio).round() as usize;
        if height % 2 != 0 {
            height += 1;
        }
    } else {
        height = max_len;
        width = (height as f32 * aspect_ratio).round() as usize;
        if width % 2 != 0 {
            width += 1;
        }
    }

    height = cmp::max(height, min_len);
    width = cmp::max(width, min_len);

    (height, width)
}

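/// Scales (height, width) down, preserving aspect ratio, so that the longer side does not
/// exceed `max_len`; dimensions already within the bound are left unchanged.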
fn resize_output_size_scale_below_upper_bound(
    height: usize,
    width: usize,
    max_len: Option<usize>,
) -> (usize, usize) {
    let max_len = max_len.unwrap_or_else(|| cmp::max(height, width));
    let aspect_ratio = width as f32 / height as f32;
    let (mut height, mut width) = (height, width);

    if width >= height && width > max_len {
        width = max_len;
        height = (width as f32 / aspect_ratio).round() as usize;
    } else if height > width && height > max_len {
        height = max_len;
        width = (height as f32 * aspect_ratio).round() as usize;
    }

    height = cmp::max(height, 1);
    width = cmp::max(width, 1);

    (height, width)
}

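/// Computes the target (height, width) for an image: rescale so the longer side matches
/// `resolution_max_side`, then cap both dimensions at `MAX_IMAGE_SIZE`.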
fn get_resize_output_image_size(
    (h, w): (usize, usize),
    resolution_max_side: usize,
) -> (usize, usize) {
    let (h, w) = resize_output_size_rescale_to_max_len(h, w, None, Some(resolution_max_side));
    resize_output_size_scale_below_upper_bound(h, w, Some(MAX_IMAGE_SIZE))
}

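/// Rounds (h, w) up to the nearest multiples of `vision_encoder_max_size` while keeping the
/// aspect ratio approximately intact, so the image can later be split into full-size tiles.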
fn resize_for_vision_encoder(
    (h, w): (usize, usize),
    vision_encoder_max_size: usize,
) -> (usize, usize) {
    let aspect_ratio = w as f32 / h as f32;

    let (new_h, new_w) = if w >= h {
        let new_w = ((w as f32 / vision_encoder_max_size as f32).ceil()
            * vision_encoder_max_size as f32) as usize;
        let mut new_h = (new_w as f32 / aspect_ratio) as usize;
        new_h = ((new_h as f32 / vision_encoder_max_size as f32).ceil()
            * vision_encoder_max_size as f32) as usize;
        (new_h, new_w)
    } else {
        let new_h = ((h as f32 / vision_encoder_max_size as f32).ceil()
            * vision_encoder_max_size as f32) as usize;
        let mut new_w = (new_h as f32 * aspect_ratio) as usize;
        new_w = ((new_w as f32 / vision_encoder_max_size as f32).ceil()
            * vision_encoder_max_size as f32) as usize;
        (new_h, new_w)
    };

    (new_h, new_w)
}

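/// Resizes an image according to `size`: either `{"longest_edge": n}` (aspect-preserving via
/// [`get_resize_output_image_size`]) or explicit `{"height", "width"}`.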
fn resize(
    image: &DynamicImage,
    size: &HashMap<String, u32>,
    resampling: FilterType,
) -> Result<DynamicImage> {
    let (h, w) = if size.contains_key("longest_edge") {
        get_resize_output_image_size(
            (image.dimensions().1 as usize, image.dimensions().0 as usize),
            size["longest_edge"] as usize,
        )
    } else if size.contains_key("height") && size.contains_key("width") {
        (size["height"] as usize, size["width"] as usize)
    } else {
        candle_core::bail!(
379 "Size must be a map of `shortest_edge` and `longest_edge` or `height` and `width`."
        );
    };

    Ok(image.resize_exact(w as u32, h as u32, resampling))
}

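/// Splits an image that exceeds `longest_edge` into a row-major grid of tiles, appending a
/// resized copy of the full image as the final "global" frame; returns the frames together
/// with the number of rows and columns. Images within the bound are returned unsplit with
/// `(0, 0)` as the grid shape.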
fn split_image(
    image: &DynamicImage,
    longest_edge: usize,
) -> Result<(Vec<DynamicImage>, usize, usize)> {
    let (width, height) = image.dimensions();
    let mut frames = Vec::new();

    if width > longest_edge as u32 || height > longest_edge as u32 {
        let num_splits_h = (height as f64 / (longest_edge as f64)).ceil() as usize;
        let num_splits_w = (width as f64 / (longest_edge as f64)).ceil() as usize;

        let optimal_height = (height as f64 / num_splits_h as f64).ceil() as u32;
        let optimal_width = (width as f64 / num_splits_w as f64).ceil() as u32;

        for r in 0..num_splits_h {
            for c in 0..num_splits_w {
                let start_x = (c as u32) * optimal_width;
                let start_y = (r as u32) * optimal_height;

                let end_x = std::cmp::min(start_x + optimal_width, width);
                let end_y = std::cmp::min(start_y + optimal_height, height);

                let cropped_image =
                    image.crop_imm(start_x, start_y, end_x - start_x, end_y - start_y);
                frames.push(cropped_image);
            }
        }

        let resized_image = resize(
            image,
            &HashMap::from([
                ("height".to_string(), longest_edge as u32),
                ("width".to_string(), longest_edge as u32),
            ]),
            FilterType::Lanczos3,
        )?;
        frames.push(resized_image);

        Ok((frames, num_splits_h, num_splits_w))
    } else {
        frames.push(image.clone());
        Ok((frames, 0, 0))
    }
}

impl ImagePreProcessor for Idefics3ImageProcessor {
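    // Default normalization constants (the OpenAI CLIP mean/std); overridden by the
    // preprocessor config's `image_mean`/`image_std` when those are set.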
    #[allow(clippy::excessive_precision)]
    const DEFAULT_MEAN: [f64; 3] = [0.48145466, 0.4578275, 0.40821073];
    #[allow(clippy::excessive_precision)]
    const DEFAULT_STD: [f64; 3] = [0.26862954, 0.26130258, 0.27577711];

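    /// Preprocesses images for Idefics3: optional RGB conversion and resizing, resizing to
    /// multiples of the vision encoder tile size, splitting into tiles plus a global image,
    /// rescaling/normalization, and (optionally) padding to a common size with pixel
    /// attention masks.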
    fn preprocess(
        &self,
        mut images: Vec<DynamicImage>,
        videos: Vec<Vec<DynamicImage>>,
        config: &PreProcessorConfig,
        device: &Device,
        (_bs, _max_num_images): (usize, usize),
    ) -> Result<PreprocessedImages> {
        assert!(videos.is_empty());

        let mut patch_masks = Vec::new();
        let mut pixel_values = Vec::new();

        if let Some(max_edge) = self.max_edge {
            images = mistralrs_vision::pad_to_max_edge(&images, max_edge);
        }

        for image in images.iter_mut() {
            if config.do_convert_rgb.is_some_and(|x| x) {
                *image = DynamicImage::ImageRgb8(image.to_rgb8());
            }

            if config.do_resize.is_some_and(|x| x) {
                *image = resize(
                    image,
                    config.size.as_ref().unwrap(),
                    config.resampling.to_filter()?,
                )?;
            }
        }

        let mut image_rows = Vec::new();
        let mut image_cols = Vec::new();
        let mut new_images = Vec::new();
        let max_image_size = config
            .max_image_size
            .clone()
            .unwrap_or_else(|| HashMap::from([("longest_edge".to_string(), 364)]));
        if config.do_image_splitting.unwrap_or(true) {
            for image in images.iter_mut() {
                let (new_h, new_w) = resize_for_vision_encoder(
                    (image.dimensions().1 as usize, image.dimensions().0 as usize),
                    max_image_size["longest_edge"] as usize,
                );

                *image =
                    image.resize_exact(new_w as u32, new_h as u32, config.resampling.to_filter()?);

                let (split_image_array, rows, cols) =
                    split_image(image, max_image_size["longest_edge"] as usize)?;
                new_images.extend(split_image_array.into_iter());
                image_rows.push(rows);
                image_cols.push(cols);
            }
        } else {
            for image in images.iter_mut() {
                new_images.push(resize(
                    image,
                    &HashMap::from([
                        ("height".to_string(), max_image_size["longest_edge"]),
                        ("width".to_string(), max_image_size["longest_edge"]),
                    ]),
                    FilterType::Lanczos3,
                )?);
            }
            image_rows = vec![0; images.len()];
            image_cols = vec![0; images.len()];
        }
        images = new_images;

        let mut max_h = 0;
        let mut max_w = 0;
        for image in &images {
            let (w, h) = image.dimensions();
            if w > max_w {
                max_w = w;
            }
            if h > max_h {
                max_h = h;
            }
        }

        for image in images.iter_mut() {
            let transforms = Transforms {
                input: &ToTensorNoNorm,
                inner_transforms: &[
                    &config
                        .do_rescale
                        .is_some_and(|x| x)
                        .then_some(())
                        .map(|_| Rescale {
                            factor: config.rescale_factor,
                        }),
                    &config
                        .do_normalize
                        .is_some_and(|x| x)
                        .then_some(())
                        .map(|_| Normalize {
                            mean: config.image_mean.unwrap_or(Self::DEFAULT_MEAN).to_vec(),
                            std: config.image_std.unwrap_or(Self::DEFAULT_STD).to_vec(),
                        }),
                ],
            };

            let mut image = image.apply(transforms, device)?;
            if config.do_pad.is_some_and(|x| x) {
                let (_c, h, w) = image.dims3()?;
                let padded = mistralrs_vision::pad(&image, max_h as usize, max_w as usize)?;
                let mask = mistralrs_vision::make_pixel_mask(&padded, h, w)?;
                patch_masks.push(mask.unsqueeze(0)?);
                image = padded;
            }

            pixel_values.push(image.unsqueeze(0)?);
        }

        Ok(PreprocessedImages {
            pixel_values: Tensor::cat(&pixel_values, 0)?,
            pixel_attention_mask: Some(Tensor::cat(&patch_masks, 0)?),
            image_sizes: None,
            num_img_tokens: None,
            aspect_ratio_ids: None,
            aspect_ratio_mask: None,
            num_tiles: None,
            image_grid_thw: None,
            video_grid_thw: None,
            rows: Some(image_rows),
            cols: Some(image_cols),
            pixel_values_list: None,
            tgt_sizes: None,
            image_sizes_all: None,
            num_crops: None,
        })
    }
}