#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::{any::Any, sync::Arc};

use candle_core::{Device, Result, Tensor};
use image::{DynamicImage, GenericImageView};
use indexmap::IndexMap;
use mistralrs_vision::{ApplyTransforms, Normalize, Rescale, ToTensorNoNorm, Transforms};
use tokenizers::Tokenizer;

use crate::{
    device_map::DeviceMapper,
    pipeline::{
        apply_chat_template,
        text_models_inputs_processor::{
            self, get_completion_input, get_prompt_input, PagedAttentionMeta,
        },
        InputProcessorOutput, InputsProcessor, InputsProcessorType, MessagesAction, Processor,
    },
    sequence::Sequence,
    vision_models::ModelInputs,
    MessageContent, Pipeline, Tool,
};

use crate::vision_models::{
    image_processor::{ImagePreProcessor, PreprocessedImages},
    preprocessor_config::{PreProcessorConfig, ToFilter},
    processor_config::ProcessorConfig,
};

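/// Image input processor for Idefics 2: preprocesses each sequence's images into
/// batched `pixel_values` plus a matching pixel attention mask.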
pub struct Idefics2ImageProcessor {
    max_edge: Option<u32>,
}
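/// Prompt processor for Idefics 2: applies the chat template and expands each
/// `<image>` placeholder into the image-token sequence the model expects.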
pub struct Idefics2Processor {
    config: ProcessorConfig,
    preprocessor_config: PreProcessorConfig,
    fake_image_token: &'static str,
    image_token: &'static str,
    max_edge: Option<u32>,
}

impl Idefics2Processor {
    pub fn new(
        config: ProcessorConfig,
        preprocessor_config: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Self {
        Self {
            config,
            preprocessor_config,
            fake_image_token: "<fake_token_around_image>",
            image_token: "<image>",
            max_edge,
        }
    }
}

impl Processor for Idefics2Processor {
    fn process(
        &self,
        pipeline: &dyn Pipeline,
        messages: Vec<IndexMap<String, MessageContent>>,
        add_generation_prompt: bool,
        add_special_tokens: bool,
        enable_thinking: Option<bool>,
        tools: Vec<Tool>,
    ) -> anyhow::Result<(Vec<u32>, String)> {
        let mut prompt = apply_chat_template(
            pipeline,
            messages,
            add_generation_prompt,
            enable_thinking,
            self.template_action(),
            tools,
        )?;

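        // Each `<image>` placeholder expands to
        // `<fake_token_around_image><image>{image_seq_len}<fake_token_around_image>`.
        // With image splitting enabled, the model sees 5 sub-images per input image
        // (4 crops plus the original), so the whole sequence is repeated 5 times.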
        let mut image_str = format!(
            "{}{}{}",
            self.fake_image_token,
            self.image_token.repeat(
                self.config
                    .image_seq_len
                    .expect("Idefics 2 model needs `image_seq_len`")
            ),
            self.fake_image_token
        );
        if self
            .preprocessor_config
            .do_image_splitting
            .is_some_and(|x| x)
        {
            image_str = image_str.repeat(5);
        }

        prompt = prompt.replace(self.image_token, &image_str);
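        // Adjacent expansions would otherwise produce doubled fake tokens
        // (`...<fake><fake>...`); collapse them into a single one.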
        prompt = prompt.replace(
            &format!("{}{}", self.fake_image_token, self.fake_image_token),
            self.fake_image_token,
        );

        let Some(tokenizer) = &pipeline.tokenizer() else {
            anyhow::bail!("Idefics2InputProcessor requires a specified tokenizer.");
        };
        let encoding = tokenizer
            .encode_fast(prompt.clone(), add_special_tokens)
            .map_err(anyhow::Error::msg)?;
        Ok((encoding.get_ids().to_vec(), prompt))
    }

    fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
        Arc::new(Idefics2ImageProcessor {
            max_edge: self.max_edge,
        })
    }

    fn get_special_tokens(&self) -> &[&'static str] {
        &["<fake_token_around_image>", "<image>", "<end_of_utterance>"]
    }

    fn template_action(&self) -> MessagesAction {
        MessagesAction::Keep
    }
}

impl InputsProcessor for Idefics2ImageProcessor {
    fn get_type(&self) -> InputsProcessorType {
        InputsProcessorType::Vision
    }
    fn process_inputs(
        &self,
        _: Option<Arc<Tokenizer>>,
        input_seqs: &mut [&mut Sequence],
        is_prompt: bool,
        is_xlora: bool,
        device: &Device,
        no_kv_cache: bool,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        other_config: Option<Arc<dyn Any>>,
        mut paged_attn_metadata: Option<PagedAttentionMeta>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> anyhow::Result<InputProcessorOutput> {
        if is_xlora {
            return Err(anyhow::Error::msg(
                "Cannot make inputs for X-LoRA vision model.",
            ));
        }
        if no_kv_cache {
            return Err(anyhow::Error::msg("Vision model must have kv cache."));
        }

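        // Assemble the standard text-model inputs (token ids, positions, context
        // lens, paged-attention and flash metadata); the prompt and completion
        // phases are handled by different helpers.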
        let text_models_inputs_processor::InnerInputProcessorOutput {
            inputs:
                text_models_inputs_processor::InputMetadata {
                    input,
                    positions,
                    context_lens,
                    position_ids,
                    paged_attn_meta,
                    flash_meta,
                },
            seq_indices,
        } = if is_prompt {
            get_prompt_input(
                input_seqs
                    .iter()
                    .map(|seq| seq.get_toks())
                    .collect::<Vec<_>>(),
                input_seqs,
                device,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                mapper,
            )
            .unwrap()
        } else {
            get_completion_input(
                input_seqs
                    .iter()
                    .map(|seq| seq.get_toks())
                    .collect::<Vec<_>>(),
                input_seqs,
                device,
                no_kv_cache,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                mapper,
            )
            .unwrap()
        };
        let config = other_config.expect("Need a PreProcessorConfig config.");
        let config: &PreProcessorConfig = config.downcast_ref().expect("Downcast failed.");

        let has_images = input_seqs.iter().all(|seq| seq.has_images());

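        // Only build pixel values when every sequence in the batch has images;
        // otherwise this batch runs text-only. Images are taken out of each
        // sequence, preprocessed, and stacked along the batch dimension.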
        let (pixel_values, pixel_attention_mask) = if has_images {
            let mut pixel_values_accum = Vec::new();
            let mut pixel_attention_mask_accum = Vec::new();
            for seq in input_seqs.iter_mut() {
                let PreprocessedImages {
                    pixel_values,
                    pixel_attention_mask,
                    image_sizes: _,
                    num_img_tokens: _,
                    aspect_ratio_ids: _,
                    aspect_ratio_mask: _,
                    num_tiles: _,
                    image_grid_thw: _,
                    video_grid_thw: _,
                    rows: _,
                    cols: _,
                    pixel_values_list: _,
                    tgt_sizes: _,
                    image_sizes_all: _,
                    num_crops: _,
                } = self
                    .preprocess(
                        seq.take_images()
                            .expect("Need to have images by this point."),
                        vec![],
                        config,
                        device,
                        (usize::MAX, usize::MAX),
                    )
                    .expect("Preprocessing failed");
                pixel_values_accum.push(pixel_values.unsqueeze(0).unwrap());
                pixel_attention_mask_accum
                    .push(pixel_attention_mask.unwrap().unsqueeze(0).unwrap());
            }
            (
                Some(Tensor::cat(&pixel_values_accum, 0).unwrap()),
                Some(Tensor::cat(&pixel_attention_mask_accum, 0).unwrap()),
            )
        } else {
            (None, None)
        };

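        // The pixel attention mask is Idefics 2-specific, so it travels via
        // `model_specific_args`; everything else uses the shared `ModelInputs`
        // fields.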
        let inputs: Box<dyn Any> = Box::new(ModelInputs {
            input_ids: input,
            seqlen_offsets: positions,
            context_lens,
            position_ids,
            pixel_values,
            model_specific_args: Box::new(pixel_attention_mask),
            paged_attn_meta,
            flash_meta,
        });
        Ok(InputProcessorOutput {
            inputs,
            seq_indices,
        })
    }
}

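// The default mean/std below are the standard OpenAI CLIP normalization
// statistics; a preprocessor config may override them via `image_mean`/`image_std`.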
impl ImagePreProcessor for Idefics2ImageProcessor {
    #[allow(clippy::excessive_precision)]
    const DEFAULT_MEAN: [f64; 3] = [0.48145466, 0.4578275, 0.40821073];
    #[allow(clippy::excessive_precision)]
    const DEFAULT_STD: [f64; 3] = [0.26862954, 0.26130258, 0.27577711];

    fn preprocess(
        &self,
        mut images: Vec<DynamicImage>,
        videos: Vec<Vec<DynamicImage>>,
        config: &PreProcessorConfig,
        device: &Device,
        (_bs, _max_num_images): (usize, usize),
    ) -> Result<PreprocessedImages> {
        assert!(videos.is_empty());

        let mut patch_masks = Vec::new();
        let mut pixel_values = Vec::new();

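        // Optional 4+1 image splitting: the four quadrants of each image plus the
        // original, mirroring the 5x token expansion done in the prompt.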
        if config.do_image_splitting.is_some_and(|x| x) {
            let mut new_images = Vec::new();
            for image in images {
                let (w, h) = image.dimensions();
                let mid_w = w / 2;
                let mid_h = h / 2;
                // Crop out the four quadrants explicitly (`crop_imm` takes
                // x, y, width, height) rather than relying on it clamping
                // oversized rectangles to the image bounds.
                new_images.push(image.crop_imm(0, 0, mid_w, mid_h));
                new_images.push(image.crop_imm(mid_w, 0, w - mid_w, mid_h));
                new_images.push(image.crop_imm(0, mid_h, mid_w, h - mid_h));
                new_images.push(image.crop_imm(mid_w, mid_h, w - mid_w, h - mid_h));
                new_images.push(image);
            }
            images = new_images;
        }

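        // Resize per the preprocessor config: either fit within
        // `shortest_edge`/`longest_edge`, or force an exact `height` x `width`.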
        for image in images.iter_mut() {
            if config.do_resize.is_some_and(|x| x) {
                let size = config.size.as_ref().unwrap();
                let (h, w) = if size.contains_key("shortest_edge")
                    && size.contains_key("longest_edge")
                {
                    mistralrs_vision::get_resize_image_size(
                        (image.dimensions().1 as usize, image.dimensions().0 as usize),
                        (
                            size["shortest_edge"] as usize,
                            size["longest_edge"] as usize,
                        ),
                    )
                } else if size.contains_key("height") && size.contains_key("width") {
                    (size["height"] as usize, size["width"] as usize)
                } else {
                    candle_core::bail!("Size must be a map of `shortest_edge` and `longest_edge` or `height` and `width`.");
                };

                *image = image.resize_exact(w as u32, h as u32, config.resampling.to_filter()?);
            }
        }

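        // Apply the optional `max_edge` padding constraint before the common
        // batch shape is computed.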
        if let Some(max_edge) = self.max_edge {
            images = mistralrs_vision::pad_to_max_edge(&images, max_edge);
        }

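        // Find the largest height and width in the batch; every image is padded
        // to this common shape so they can be stacked into a single tensor.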
        let mut max_h = 0;
        let mut max_w = 0;
        for image in &images {
            let (w, h) = image.dimensions();
            if w > max_w {
                max_w = w;
            }
            if h > max_h {
                max_h = h;
            }
        }

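        // Per-image pipeline: optional RGB conversion, to-tensor, optional rescale
        // and normalize (falling back to the CLIP defaults above), then padding to
        // the common shape with a pixel mask marking the valid region.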
        for image in images.iter_mut() {
            if config.do_convert_rgb.is_some_and(|x| x) {
                *image = DynamicImage::ImageRgb8(image.to_rgb8());
            }

            let transforms = Transforms {
                input: &ToTensorNoNorm,
                inner_transforms: &[
                    &config
                        .do_rescale
                        .is_some_and(|x| x)
                        .then_some(())
                        .map(|_| Rescale {
                            factor: config.rescale_factor,
                        }),
                    &config
                        .do_normalize
                        .is_some_and(|x| x)
                        .then_some(())
                        .map(|_| Normalize {
                            mean: config.image_mean.unwrap_or(Self::DEFAULT_MEAN).to_vec(),
                            std: config.image_std.unwrap_or(Self::DEFAULT_STD).to_vec(),
                        }),
                ],
            };

            let mut image = image.apply(transforms, device)?;
            if config.do_pad.is_some_and(|x| x) {
                let (_c, h, w) = image.dims3()?;
                let padded = mistralrs_vision::pad(&image, max_h as usize, max_w as usize)?;
                let mask = mistralrs_vision::make_pixel_mask(&padded, h, w)?;
                patch_masks.push(mask.unsqueeze(0)?);
                image = padded;
            }

            pixel_values.push(image.unsqueeze(0)?);
        }

        Ok(PreprocessedImages {
            pixel_values: Tensor::cat(&pixel_values, 0)?,
            pixel_attention_mask: Some(Tensor::cat(&patch_masks, 0)?),
            image_sizes: None,
            num_img_tokens: None,
            aspect_ratio_ids: None,
            aspect_ratio_mask: None,
            num_tiles: None,
            image_grid_thw: None,
            video_grid_thw: None,
            rows: None,
            cols: None,
            pixel_values_list: None,
            tgt_sizes: None,
            image_sizes_all: None,
            num_crops: None,
        })
    }
}