#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::{any::Any, num::NonZeroUsize, sync::Arc};

use candle_core::{Device, Result, Tensor};
use image::{DynamicImage, GenericImageView};
use indexmap::IndexMap;
use mistralrs_vision::{ApplyTransforms, Normalize, Rescale, ToTensorNoNorm, Transforms};
use tokenizers::Tokenizer;
use tracing::warn;

use crate::{
    device_map::DeviceMapper,
    pipeline::{
        apply_chat_template,
        text_models_inputs_processor::{
            self, get_completion_input, get_prompt_input, PagedAttentionMeta,
        },
        InputProcessorOutput, InputsProcessor, InputsProcessorType, MessagesAction, Processor,
    },
    sequence::Sequence,
    vision_models::ModelInputs,
    MessageContent, Pipeline, Tool,
};

use crate::vision_models::{
    image_processor::{ImagePreProcessor, PreprocessedImages},
    preprocessor_config::{PreProcessorConfig, ToFilter},
    processor_config::ProcessorConfig,
};

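/// Image preprocessor for Idefics 2. Resizes, optionally splits, rescales,
/// normalizes, and pads images, producing batched pixel values plus a pixel
/// attention mask.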
pub struct Idefics2ImageProcessor {
    max_edge: Option<u32>,
}
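
/// Prompt processor for Idefics 2: expands `<image>` placeholders into the
/// per-image token sequence the model expects and tokenizes the result.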
pub struct Idefics2Processor {
    config: ProcessorConfig,
    preprocessor_config: PreProcessorConfig,
    fake_image_token: &'static str,
    image_token: &'static str,
    max_edge: Option<u32>,
}

impl Idefics2Processor {
    pub fn new(
        config: ProcessorConfig,
        preprocessor_config: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Self {
        Self {
            config,
            preprocessor_config,
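            // Special tokens Idefics 2 uses to delimit image regions in the prompt.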
            fake_image_token: "<fake_token_around_image>",
            image_token: "<image>",
            max_edge,
        }
    }
}

impl Processor for Idefics2Processor {
    fn process(
        &self,
        pipeline: &dyn Pipeline,
        messages: Vec<IndexMap<String, MessageContent>>,
        add_generation_prompt: bool,
        add_special_tokens: bool,
        enable_thinking: Option<bool>,
        tools: Vec<Tool>,
    ) -> anyhow::Result<(Vec<u32>, String)> {
        let mut prompt = apply_chat_template(
            pipeline,
            messages,
            add_generation_prompt,
            enable_thinking,
            self.template_action(),
            tools,
        )?;

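        // Each image is represented as `image_seq_len` copies of the image token,
        // bracketed on both sides by the fake image token.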
        let mut image_str = format!(
            "{}{}{}",
            self.fake_image_token,
            self.image_token.repeat(
                self.config
                    .image_seq_len
                    .expect("Idefics 2 model needs `image_seq_len`")
            ),
            self.fake_image_token
        );
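        // With image splitting enabled, each image becomes 4 crops plus the
        // original, so the token block is repeated 5 times (see `preprocess`).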
        if self
            .preprocessor_config
            .do_image_splitting
            .is_some_and(|x| x)
        {
            image_str = image_str.repeat(5);
        }

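        // Expand every `<image>` placeholder, then collapse doubled fake image
        // tokens left behind by adjacent images.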
        prompt = prompt.replace(self.image_token, &image_str);
        prompt = prompt.replace(
            &format!("{}{}", self.fake_image_token, self.fake_image_token),
            self.fake_image_token,
        );

        let Some(tokenizer) = &pipeline.tokenizer() else {
            anyhow::bail!("Idefics2InputProcessor requires a specified tokenizer.");
        };
        let encoding = tokenizer
            .encode_fast(prompt.clone(), add_special_tokens)
            .map_err(anyhow::Error::msg)?;
        Ok((encoding.get_ids().to_vec(), prompt))
    }

    fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
        Arc::new(Idefics2ImageProcessor {
            max_edge: self.max_edge,
        })
    }

    fn get_special_tokens(&self) -> &[&'static str] {
        &["<fake_token_around_image>", "<image>", "<end_of_utterance>"]
    }

    fn template_action(&self) -> MessagesAction {
        MessagesAction::Keep
    }
}

impl InputsProcessor for Idefics2ImageProcessor {
    fn get_type(&self) -> InputsProcessorType {
        InputsProcessorType::Vision
    }
    fn process_inputs(
        &self,
        _: Option<Arc<Tokenizer>>,
        input_seqs: &mut [&mut Sequence],
        is_prompt: bool,
        is_xlora: bool,
        device: &Device,
        no_kv_cache: bool,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        other_config: Option<Arc<dyn Any>>,
        mut paged_attn_metadata: Option<PagedAttentionMeta>,
        prompt_chunksize: Option<NonZeroUsize>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Box<dyn Iterator<Item = anyhow::Result<InputProcessorOutput>>> {
        if is_xlora {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Cannot make inputs for X-LoRA vision model.",
            ))));
        }
        if no_kv_cache {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Vision model must have a KV cache.",
            ))));
        }
        if prompt_chunksize.is_some() {
            warn!("`prompt_chunksize` is set. Idefics 2 does not support prompt batching.");
        }

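        // Prompt runs encode the full token sequence; completion runs feed only
        // the newly generated tokens, relying on the KV cache for the rest.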
        let text_models_inputs_processor::InnerInputProcessorOutput {
            inputs:
                text_models_inputs_processor::InputMetadata {
                    input,
                    positions,
                    context_lens,
                    position_ids,
                    paged_attn_meta,
                    flash_meta,
                },
            seq_indices,
        } = if is_prompt {
            get_prompt_input(
                input_seqs
                    .iter()
                    .map(|seq| seq.get_toks())
                    .collect::<Vec<_>>(),
                input_seqs,
                device,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                None,
                mapper,
            )
            .nth(0)
            .unwrap()
            .unwrap()
        } else {
            get_completion_input(
                input_seqs
                    .iter()
                    .map(|seq| seq.get_toks())
                    .collect::<Vec<_>>(),
                input_seqs,
                device,
                no_kv_cache,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                None,
                mapper,
            )
            .nth(0)
            .unwrap()
            .unwrap()
        };
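        // The preprocessor config is threaded through as an opaque `Any`;
        // recover the concrete `PreProcessorConfig` here.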
        let config = other_config.expect("Need a PreProcessorConfig config.");
        let config: &PreProcessorConfig = config.downcast_ref().expect("Downcast failed.");

        let has_images = input_seqs.iter().all(|seq| seq.has_images());

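        // Preprocess each sequence's images and stack the results along a new
        // batch dimension.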
        let (pixel_values, pixel_attention_mask) = if has_images {
            let mut pixel_values_accum = Vec::new();
            let mut pixel_attention_mask_accum = Vec::new();
            for seq in input_seqs.iter_mut() {
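                // Only `pixel_values` and `pixel_attention_mask` are produced for
                // Idefics 2; every other field stays `None`.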
                let PreprocessedImages {
                    pixel_values,
                    pixel_attention_mask,
                    image_sizes: _,
                    num_img_tokens: _,
                    aspect_ratio_ids: _,
                    aspect_ratio_mask: _,
                    num_tiles: _,
                    image_grid_thw: _,
                    video_grid_thw: _,
                    rows: _,
                    cols: _,
                    pixel_values_list: _,
                    tgt_sizes: _,
                    image_sizes_all: _,
                    num_crops: _,
                } = self
                    .preprocess(
                        seq.take_images()
                            .expect("Need to have images by this point."),
                        vec![],
                        config,
                        device,
                        (usize::MAX, usize::MAX),
                    )
                    .expect("Preprocessing failed");
                pixel_values_accum.push(pixel_values.unsqueeze(0).unwrap());
                pixel_attention_mask_accum
                    .push(pixel_attention_mask.unwrap().unsqueeze(0).unwrap());
            }
            (
                Some(Tensor::cat(&pixel_values_accum, 0).unwrap()),
                Some(Tensor::cat(&pixel_attention_mask_accum, 0).unwrap()),
            )
        } else {
            (None, None)
        };

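        // Bundle everything into the generic vision `ModelInputs`; the pixel
        // attention mask rides along as a model-specific argument.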
        let inputs: Box<dyn Any> = Box::new(ModelInputs {
            input_ids: input,
            seqlen_offsets: positions,
            context_lens,
            position_ids,
            pixel_values,
            model_specific_args: Box::new(pixel_attention_mask),
            paged_attn_meta,
            flash_meta,
        });
        Box::new(std::iter::once(Ok(InputProcessorOutput {
            inputs,
            seq_indices,
        })))
    }
}

impl ImagePreProcessor for Idefics2ImageProcessor {
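    // Default normalization statistics: the standard CLIP image mean and std.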
    #[allow(clippy::excessive_precision)]
    const DEFAULT_MEAN: [f64; 3] = [0.48145466, 0.4578275, 0.40821073];
    #[allow(clippy::excessive_precision)]
    const DEFAULT_STD: [f64; 3] = [0.26862954, 0.26130258, 0.27577711];

    fn preprocess(
        &self,
        mut images: Vec<DynamicImage>,
        videos: Vec<Vec<DynamicImage>>,
        config: &PreProcessorConfig,
        device: &Device,
        (_bs, _max_num_images): (usize, usize),
    ) -> Result<PreprocessedImages> {
        assert!(videos.is_empty());

        let mut patch_masks = Vec::new();
        let mut pixel_values = Vec::new();

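        // Image splitting: quarter each image into its four quadrants and keep
        // the original as a fifth view, matching the 5x token expansion in
        // `Idefics2Processor::process`. `crop_imm` clamps the region to the image
        // bounds, so passing the full `w`/`h` as widths/heights yields the
        // right/bottom halves.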
        if config.do_image_splitting.is_some_and(|x| x) {
            let mut new_images = Vec::new();
            for image in images {
                let (w, h) = image.dimensions();
                let mid_w = w / 2;
                let mid_h = h / 2;
                new_images.push(image.crop_imm(0, 0, mid_w, mid_h));
                new_images.push(image.crop_imm(mid_w, 0, w, mid_h));
                new_images.push(image.crop_imm(0, mid_h, mid_w, h));
                new_images.push(image.crop_imm(mid_w, mid_h, w, h));
                new_images.push(image);
            }
            images = new_images;
        }

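        // Resize either to fit within (shortest_edge, longest_edge) while keeping
        // the aspect ratio, or to an exact (height, width).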
        for image in images.iter_mut() {
            if config.do_resize.is_some_and(|x| x) {
                let size = config.size.as_ref().unwrap();
                let (h, w) = if size.contains_key("shortest_edge")
                    && size.contains_key("longest_edge")
                {
                    mistralrs_vision::get_resize_image_size(
                        (image.dimensions().1 as usize, image.dimensions().0 as usize),
                        (
                            size["shortest_edge"] as usize,
                            size["longest_edge"] as usize,
                        ),
                    )
                } else if size.contains_key("height") && size.contains_key("width") {
                    (size["height"] as usize, size["width"] as usize)
                } else {
                    candle_core::bail!("Size must be a map of `shortest_edge` and `longest_edge` or `height` and `width`.");
                };

                *image = image.resize_exact(w as u32, h as u32, config.resampling.to_filter()?);
            }
        }

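        // Optionally pad images up to the configured maximum edge length.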
        if let Some(max_edge) = self.max_edge {
            images = mistralrs_vision::pad_to_max_edge(&images, max_edge);
        }

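        // Find the largest height and width in the batch so every image can be
        // padded to a common shape.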
        let mut max_h = 0;
        let mut max_w = 0;
        for image in &images {
            let (w, h) = image.dimensions();
            if w > max_w {
                max_w = w;
            }
            if h > max_h {
                max_h = h;
            }
        }

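        // Per-image pipeline: convert to RGB, turn into a tensor, then optionally
        // rescale and normalize; disabled steps become `None` transforms.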
        for image in images.iter_mut() {
            if config.do_convert_rgb.is_some_and(|x| x) {
                *image = DynamicImage::ImageRgb8(image.to_rgb8());
            }

            let transforms = Transforms {
                input: &ToTensorNoNorm,
                inner_transforms: &[
                    &config
                        .do_rescale
                        .is_some_and(|x| x)
                        .then_some(())
                        .map(|_| Rescale {
                            factor: config.rescale_factor,
                        }),
                    &config
                        .do_normalize
                        .is_some_and(|x| x)
                        .then_some(())
                        .map(|_| Normalize {
                            mean: config.image_mean.unwrap_or(Self::DEFAULT_MEAN).to_vec(),
                            std: config.image_std.unwrap_or(Self::DEFAULT_STD).to_vec(),
                        }),
                ],
            };

            let mut image = image.apply(transforms, device)?;
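            // Pad to the batch maximum and record which pixels are real versus
            // padding in the pixel mask.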
            if config.do_pad.is_some_and(|x| x) {
                let (_c, h, w) = image.dims3()?;
                let padded = mistralrs_vision::pad(&image, max_h as usize, max_w as usize)?;
                let mask = mistralrs_vision::make_pixel_mask(&padded, h, w)?;
                patch_masks.push(mask.unsqueeze(0)?);
                image = padded;
            }

            pixel_values.push(image.unsqueeze(0)?)
        }

        Ok(PreprocessedImages {
            pixel_values: Tensor::cat(&pixel_values, 0)?,
            pixel_attention_mask: Some(Tensor::cat(&patch_masks, 0)?),
            image_sizes: None,
            num_img_tokens: None,
            aspect_ratio_ids: None,
            aspect_ratio_mask: None,
            num_tiles: None,
            image_grid_thw: None,
            video_grid_thw: None,
            rows: None,
            cols: None,
            pixel_values_list: None,
            tgt_sizes: None,
            image_sizes_all: None,
            num_crops: None,
        })
    }
}