#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::{any::Any, cmp, collections::HashMap, num::NonZeroUsize, sync::Arc};

use candle_core::{Device, Result, Tensor};
use image::{imageops::FilterType, DynamicImage, GenericImageView};
use mistralrs_vision::{ApplyTransforms, Normalize, Rescale, ToTensorNoNorm, Transforms};
use tokenizers::Tokenizer;
use tracing::warn;

use crate::{
    device_map::DeviceMapper,
    pipeline::{
        text_models_inputs_processor::{
            self, get_completion_input, get_prompt_input, PagedAttentionMeta,
        },
        InputProcessorOutput, InputsProcessor, InputsProcessorType, MessagesAction, Processor,
    },
    sequence::Sequence,
    vision_models::ModelInputs,
};

use crate::vision_models::{
    image_processor::{ImagePreProcessor, PreprocessedImages},
    preprocessor_config::{PreProcessorConfig, ToFilter},
    processor_config::ProcessorConfig,
};

const MAX_IMAGE_SIZE: usize = 4096;

const FAKE_IMAGE_TOKEN: &str = "<fake_token_around_image>";
const IMAGE_TOKEN: &str = "<image>";
const GLOBAL_IMAGE_TOKEN: &str = "<global-img>";

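/// Image inputs processor for Idefics 3. Preprocesses images into pixel values and
/// pixel attention masks, and expands each `<image>` placeholder in the prompt into the
/// full per-patch image-token layout.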
pub struct Idefics3ImageProcessor {
    max_edge: Option<u32>,
    image_seq_len: usize,
}

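/// Processor for Idefics 3 that constructs [`Idefics3ImageProcessor`] instances from the
/// model's processor configuration.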
pub struct Idefics3Processor {
    config: ProcessorConfig,
    max_edge: Option<u32>,
}

impl Idefics3Processor {
    pub fn new(
        config: ProcessorConfig,
        _preprocessor_config: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Self {
        Self { config, max_edge }
    }
}

impl Processor for Idefics3Processor {
    fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
        Arc::new(Idefics3ImageProcessor {
            max_edge: self.max_edge,
            image_seq_len: self.config.image_seq_len.unwrap_or(169),
        })
    }

    fn get_special_tokens(&self) -> &[&'static str] {
        &["<fake_token_around_image>", "<image>", "<end_of_utterance>"]
    }

    fn template_action(&self) -> MessagesAction {
        MessagesAction::Keep
    }
}

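/// Build the text that replaces a single `<image>` placeholder in the prompt.
///
/// If the image was not split (`n_rows == 0 && n_cols == 0`), only the global image tokens
/// are emitted; otherwise a `<row_{r}_col_{c}>` block of image tokens is emitted for every
/// patch, followed by the global image tokens.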
fn get_image_prompt_string(n_rows: usize, n_cols: usize, image_seq_len: usize) -> String {
    if n_rows == 0 && n_cols == 0 {
        format!(
            "{FAKE_IMAGE_TOKEN}{GLOBAL_IMAGE_TOKEN}{}{FAKE_IMAGE_TOKEN}",
            IMAGE_TOKEN.repeat(image_seq_len)
        )
    } else {
        let mut text_split_images = String::new();
        for n_h in 0..n_rows {
            for n_w in 0..n_cols {
                text_split_images.push_str(&format!(
                    "{FAKE_IMAGE_TOKEN}<row_{}_col_{}>{}",
                    n_h + 1,
                    n_w + 1,
                    IMAGE_TOKEN.repeat(image_seq_len)
                ));
            }
            text_split_images.push('\n');
        }
        format!(
            "{text_split_images}\n{FAKE_IMAGE_TOKEN}{GLOBAL_IMAGE_TOKEN}{}{FAKE_IMAGE_TOKEN}",
            IMAGE_TOKEN.repeat(image_seq_len)
        )
    }
}

impl InputsProcessor for Idefics3ImageProcessor {
    fn get_type(&self) -> InputsProcessorType {
        InputsProcessorType::Vision
    }
    fn process_inputs(
        &self,
        tokenizer: Option<Arc<Tokenizer>>,
        input_seqs: &mut [&mut Sequence],
        is_prompt: bool,
        is_xlora: bool,
        device: &Device,
        no_kv_cache: bool,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        other_config: Option<Arc<dyn Any>>,
        mut paged_attn_metadata: Option<PagedAttentionMeta>,
        prompt_chunksize: Option<NonZeroUsize>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Box<dyn Iterator<Item = anyhow::Result<InputProcessorOutput>>> {
        if is_xlora {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Cannot make inputs for X-LoRA vision model.",
            ))));
        }
        if no_kv_cache {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Vision model must have kv cache.",
            ))));
        }
        if prompt_chunksize.is_some() {
            warn!("`prompt_chunksize` is set. Idefics 3 does not support prompt chunking.");
        }
        let Some(tokenizer) = tokenizer else {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Idefics3ImageProcessor requires a specified tokenizer.",
            ))));
        };

        let config = other_config.expect("Need a PreProcessorConfig.");
        let config: &PreProcessorConfig = config.downcast_ref().expect("Downcast failed.");

        let has_images = input_seqs.iter().all(|seq| seq.has_images());

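        // Preprocess every sequence's images and, on the first pass over a sequence, rewrite
        // its prompt so each `<image>` placeholder expands to the per-patch token layout.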
        let (pixel_values, pixel_attention_mask) = if has_images {
            let mut pixel_values_accum = Vec::new();
            let mut pixel_attention_mask_accum = Vec::new();
            for seq in input_seqs.iter_mut() {
                let PreprocessedImages {
                    pixel_values,
                    pixel_attention_mask,
                    image_sizes: _,
                    num_img_tokens: _,
                    aspect_ratio_ids: _,
                    aspect_ratio_mask: _,
                    num_tiles: _,
                    image_grid_thw: _,
                    video_grid_thw: _,
                    rows,
                    cols,
                    pixel_values_list: _,
                    tgt_sizes: _,
                    image_sizes_all: _,
                    num_crops: _,
                } = self
                    .preprocess(
                        seq.take_images()
                            .expect("Need to have images by this point."),
                        vec![],
                        config,
                        device,
                        (usize::MAX, usize::MAX),
                    )
                    .expect("Preprocessing failed");
                pixel_values_accum.push(pixel_values.unsqueeze(0).unwrap());
                pixel_attention_mask_accum
                    .push(pixel_attention_mask.unwrap().unsqueeze(0).unwrap());

                if !seq.multimodal.has_changed_prompt {
                    let detok = tokenizer
                        .decode(seq.get_toks(), false)
                        .expect("Detokenization failed!");

                    let mut image_prompt_strings = Vec::new();
                    for (n_rows, n_cols) in rows.unwrap().into_iter().zip(cols.unwrap().into_iter())
                    {
                        let image_prompt_string =
                            get_image_prompt_string(n_rows, n_cols, self.image_seq_len);
                        image_prompt_strings.push(image_prompt_string);
                    }

                    let split_sample = detok.split(IMAGE_TOKEN).collect::<Vec<_>>();
                    let mut sample = split_sample
                        .first()
                        .expect("The image token <image> should be present in the text.")
                        .to_string();
                    for (i, image_prompt_string) in image_prompt_strings.into_iter().enumerate() {
                        sample.push_str(&format!(
                            "{image_prompt_string}{}",
                            split_sample
                                .get(i + 1)
                                .expect("Incorrect chat template. Use the one provided in `chat_templates` with the `--chat-template`/`chat_template` settings.")
                        ));
                    }

                    seq.set_initial_prompt(sample.clone());
                    let toks = tokenizer
                        .encode_fast(sample, false)
                        .expect("Tokenization failed!");

                    let ids = toks.get_ids().to_vec();
                    seq.set_toks_and_reallocate(ids, paged_attn_metadata.as_mut());
                    seq.multimodal.has_changed_prompt = true;
                }
            }

            (
                Some(Tensor::cat(&pixel_values_accum, 0).unwrap()),
                Some(Tensor::cat(&pixel_attention_mask_accum, 0).unwrap()),
            )
        } else {
            (None, None)
        };

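        // Build the standard text-model input metadata (token ids, positions, context lengths,
        // paged-attention metadata) for either the prompt or the completion phase.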
        let text_models_inputs_processor::InnerInputProcessorOutput {
            inputs:
                text_models_inputs_processor::InputMetadata {
                    input,
                    positions,
                    context_lens,
                    position_ids,
                    paged_attn_meta,
                    flash_meta,
                },
            seq_indices,
        } = if is_prompt {
            get_prompt_input(
                input_seqs
                    .iter()
                    .map(|seq| seq.get_toks())
                    .collect::<Vec<_>>(),
                input_seqs,
                device,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                None,
                mapper,
            )
            .nth(0)
            .unwrap()
            .unwrap()
        } else {
            get_completion_input(
                input_seqs
                    .iter()
                    .map(|seq| seq.get_toks())
                    .collect::<Vec<_>>(),
                input_seqs,
                device,
                no_kv_cache,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                None,
                mapper,
            )
            .nth(0)
            .unwrap()
            .unwrap()
        };

        let inputs: Box<dyn Any> = Box::new(ModelInputs {
            input_ids: input,
            seqlen_offsets: positions,
            context_lens,
            position_ids,
            pixel_values,
            model_specific_args: Box::new(pixel_attention_mask),
            paged_attn_meta,
            flash_meta,
        });
        Box::new(std::iter::once(Ok(InputProcessorOutput {
            inputs,
            seq_indices,
        })))
    }
}

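/// Rescale `(height, width)` so that the longer side equals `max_len` (default: the current
/// longer side), preserving aspect ratio, rounding the shorter side up to an even value, and
/// clamping both sides to at least `min_len` (default 1).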
fn resize_output_size_rescale_to_max_len(
    height: usize,
    width: usize,
    min_len: Option<usize>,
    max_len: Option<usize>,
) -> (usize, usize) {
    let min_len = min_len.unwrap_or(1);
    let max_len = max_len.unwrap_or_else(|| cmp::max(height, width));
    let aspect_ratio = width as f32 / height as f32;
    let (mut height, mut width) = (height, width);

    if width >= height {
        width = max_len;
        height = (width as f32 / aspect_ratio).round() as usize;
        if height % 2 != 0 {
            height += 1;
        }
    } else {
        height = max_len;
        width = (height as f32 * aspect_ratio).round() as usize;
        if width % 2 != 0 {
            width += 1;
        }
    }

    height = cmp::max(height, min_len);
    width = cmp::max(width, min_len);

    (height, width)
}

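/// Shrink `(height, width)` so that neither side exceeds `max_len`, preserving aspect ratio;
/// dimensions already within the bound are left unchanged.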
fn resize_output_size_scale_below_upper_bound(
    height: usize,
    width: usize,
    max_len: Option<usize>,
) -> (usize, usize) {
    let max_len = max_len.unwrap_or_else(|| cmp::max(height, width));
    let aspect_ratio = width as f32 / height as f32;
    let (mut height, mut width) = (height, width);

    if width >= height && width > max_len {
        width = max_len;
        height = (width as f32 / aspect_ratio).round() as usize;
    } else if height > width && height > max_len {
        height = max_len;
        width = (height as f32 * aspect_ratio).round() as usize;
    }

    height = cmp::max(height, 1);
    width = cmp::max(width, 1);

    (height, width)
}

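/// Compute the target size for an `(h, w)` image: rescale the longest side to
/// `resolution_max_side`, then clamp both sides to the absolute `MAX_IMAGE_SIZE` bound.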
fn get_resize_output_image_size(
    (h, w): (usize, usize),
    resolution_max_side: usize,
) -> (usize, usize) {
    let (h, w) = resize_output_size_rescale_to_max_len(h, w, None, Some(resolution_max_side));
    resize_output_size_scale_below_upper_bound(h, w, Some(MAX_IMAGE_SIZE))
}

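/// Round `(h, w)` up so that both sides become multiples of `vision_encoder_max_size`,
/// approximately preserving the aspect ratio.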
fn resize_for_vision_encoder(
    (h, w): (usize, usize),
    vision_encoder_max_size: usize,
) -> (usize, usize) {
    let aspect_ratio = w as f32 / h as f32;

    let (new_h, new_w) = if w >= h {
        let new_w = ((w as f32 / vision_encoder_max_size as f32).ceil()
            * vision_encoder_max_size as f32) as usize;
        let mut new_h = (new_w as f32 / aspect_ratio) as usize;
        new_h = ((new_h as f32 / vision_encoder_max_size as f32).ceil()
            * vision_encoder_max_size as f32) as usize;
        (new_h, new_w)
    } else {
        let new_h = ((h as f32 / vision_encoder_max_size as f32).ceil()
            * vision_encoder_max_size as f32) as usize;
        let mut new_w = (new_h as f32 * aspect_ratio) as usize;
        new_w = ((new_w as f32 / vision_encoder_max_size as f32).ceil()
            * vision_encoder_max_size as f32) as usize;
        (new_h, new_w)
    };

    (new_h, new_w)
}

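/// Resize an image according to a `size` map: either `{"longest_edge": n}` (aspect-ratio
/// preserving, bounded by `MAX_IMAGE_SIZE`) or an explicit `{"height": h, "width": w}`.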
fn resize(
    image: &DynamicImage,
    size: &HashMap<String, u32>,
    resampling: FilterType,
) -> Result<DynamicImage> {
    let (h, w) = if size.contains_key("longest_edge") {
        get_resize_output_image_size(
            (image.dimensions().1 as usize, image.dimensions().0 as usize),
            size["longest_edge"] as usize,
        )
    } else if size.contains_key("height") && size.contains_key("width") {
        (size["height"] as usize, size["width"] as usize)
    } else {
        candle_core::bail!(
            "Size must be a map containing `longest_edge` or both `height` and `width`."
        );
    };

    Ok(image.resize_exact(w as u32, h as u32, resampling))
}

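/// If either side exceeds `longest_edge`, split the image into a grid of crops of at most
/// `longest_edge` per side, append a downscaled copy of the full image as the global image,
/// and return `(frames, n_rows, n_cols)`. Otherwise return the image as-is with `(frames, 0, 0)`.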
fn split_image(
    image: &DynamicImage,
    longest_edge: usize,
) -> Result<(Vec<DynamicImage>, usize, usize)> {
    let (width, height) = image.dimensions();
    let mut frames = Vec::new();

    if width > longest_edge as u32 || height > longest_edge as u32 {
        let num_splits_h = (height as f64 / (longest_edge as f64)).ceil() as usize;
        let num_splits_w = (width as f64 / (longest_edge as f64)).ceil() as usize;

        let optimal_height = (height as f64 / num_splits_h as f64).ceil() as u32;
        let optimal_width = (width as f64 / num_splits_w as f64).ceil() as u32;

        for r in 0..num_splits_h {
            for c in 0..num_splits_w {
                let start_x = (c as u32) * optimal_width;
                let start_y = (r as u32) * optimal_height;

                let end_x = std::cmp::min(start_x + optimal_width, width);
                let end_y = std::cmp::min(start_y + optimal_height, height);

                let cropped_image =
                    image.crop_imm(start_x, start_y, end_x - start_x, end_y - start_y);
                frames.push(cropped_image);
            }
        }

        let resized_image = resize(
            image,
            &HashMap::from([
                ("height".to_string(), longest_edge as u32),
                ("width".to_string(), longest_edge as u32),
            ]),
            FilterType::Lanczos3,
        )?;
        frames.push(resized_image);

        Ok((frames, num_splits_h, num_splits_w))
    } else {
        frames.push(image.clone());
        Ok((frames, 0, 0))
    }
}


impl ImagePreProcessor for Idefics3ImageProcessor {
    #[allow(clippy::excessive_precision)]
    const DEFAULT_MEAN: [f64; 3] = [0.48145466, 0.4578275, 0.40821073];
    #[allow(clippy::excessive_precision)]
    const DEFAULT_STD: [f64; 3] = [0.26862954, 0.26130258, 0.27577711];

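    /// Preprocess images for Idefics 3: optional RGB conversion and resizing, optional
    /// splitting into patches plus a global image, rescaling/normalization, and padding to a
    /// common size with a pixel attention mask.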
    fn preprocess(
        &self,
        mut images: Vec<DynamicImage>,
        videos: Vec<Vec<DynamicImage>>,
        config: &PreProcessorConfig,
        device: &Device,
        (_bs, _max_num_images): (usize, usize),
    ) -> Result<PreprocessedImages> {
        assert!(videos.is_empty());

        let mut patch_masks = Vec::new();
        let mut pixel_values = Vec::new();

        if let Some(max_edge) = self.max_edge {
            images = mistralrs_vision::pad_to_max_edge(&images, max_edge);
        }

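        // Convert to RGB and resize each image according to the preprocessor config.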
        for image in images.iter_mut() {
            if config.do_convert_rgb.is_some_and(|x| x) {
                *image = DynamicImage::ImageRgb8(image.to_rgb8());
            }

            if config.do_resize.is_some_and(|x| x) {
                *image = resize(
                    image,
                    config.size.as_ref().unwrap(),
                    config.resampling.to_filter()?,
                )?;
            }
        }

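        // Either split each image into patches of at most `longest_edge` per side (plus a
        // downscaled global image), or square each image to `longest_edge`.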
        let mut image_rows = Vec::new();
        let mut image_cols = Vec::new();
        let mut new_images = Vec::new();
        let max_image_size = config
            .max_image_size
            .clone()
            .unwrap_or_else(|| HashMap::from([("longest_edge".to_string(), 364)]));
        if config.do_image_splitting.unwrap_or(true) {
            for image in images.iter_mut() {
                let (new_h, new_w) = resize_for_vision_encoder(
                    (image.dimensions().1 as usize, image.dimensions().0 as usize),
                    max_image_size["longest_edge"] as usize,
                );

                *image =
                    image.resize_exact(new_w as u32, new_h as u32, config.resampling.to_filter()?);

                let (split_image_array, rows, cols) =
                    split_image(image, max_image_size["longest_edge"] as usize)?;
                new_images.extend(split_image_array.into_iter());
                image_rows.push(rows);
                image_cols.push(cols);
            }
        } else {
            for image in images.iter_mut() {
                new_images.push(resize(
                    image,
                    &HashMap::from([
                        ("height".to_string(), max_image_size["longest_edge"]),
                        ("width".to_string(), max_image_size["longest_edge"]),
                    ]),
                    FilterType::Lanczos3,
                )?);
            }
            image_rows = vec![0; images.len()];
            image_cols = vec![0; images.len()];
        }
        images = new_images;

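        // Find the maximum height and width across all images; padded tensors share this size.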
        let mut max_h = 0;
        let mut max_w = 0;
        for image in &images {
            let (w, h) = image.dimensions();
            if w > max_w {
                max_w = w;
            }
            if h > max_h {
                max_h = h;
            }
        }

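        // Rescale/normalize each image and, if padding is enabled, pad it to the common size
        // while recording a pixel attention mask over the valid region.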
        for image in images.iter_mut() {
            let transforms = Transforms {
                input: &ToTensorNoNorm,
                inner_transforms: &[
                    &config
                        .do_rescale
                        .is_some_and(|x| x)
                        .then_some(())
                        .map(|_| Rescale {
                            factor: config.rescale_factor,
                        }),
                    &config
                        .do_normalize
                        .is_some_and(|x| x)
                        .then_some(())
                        .map(|_| Normalize {
                            mean: config.image_mean.unwrap_or(Self::DEFAULT_MEAN).to_vec(),
                            std: config.image_std.unwrap_or(Self::DEFAULT_STD).to_vec(),
                        }),
                ],
            };

            let mut image = image.apply(transforms, device)?;

            if config.do_pad.is_some_and(|x| x) {
                let (_c, h, w) = image.dims3()?;
                let padded = mistralrs_vision::pad(&image, max_h as usize, max_w as usize)?;
                let mask = mistralrs_vision::make_pixel_mask(&padded, h, w)?;
                patch_masks.push(mask.unsqueeze(0)?);
                image = padded;
            }

            pixel_values.push(image.unsqueeze(0)?);
        }

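        // Concatenate the per-image tensors into batched outputs.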
        Ok(PreprocessedImages {
            pixel_values: Tensor::cat(&pixel_values, 0)?,
            pixel_attention_mask: Some(Tensor::cat(&patch_masks, 0)?),
            image_sizes: None,
            num_img_tokens: None,
            aspect_ratio_ids: None,
            aspect_ratio_mask: None,
            num_tiles: None,
            image_grid_thw: None,
            video_grid_thw: None,
            rows: Some(image_rows),
            cols: Some(image_cols),
            pixel_values_list: None,
            tgt_sizes: None,
            image_sizes_all: None,
            num_crops: None,
        })
    }
}