1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use std::{any::Any, num::NonZeroUsize, sync::Arc};
4
5use candle_core::{Device, Result, Tensor};
6use image::{DynamicImage, GenericImageView};
7use indexmap::IndexMap;
8use mistralrs_vision::{ApplyTransforms, Normalize, Rescale, ToTensorNoNorm, Transforms};
9use tokenizers::Tokenizer;
10use tracing::warn;
11
12use crate::{
13 device_map::DeviceMapper,
14 pipeline::{
15 apply_chat_template,
16 text_models_inputs_processor::{
17 self, get_completion_input, get_prompt_input, PagedAttentionMeta,
18 },
19 InputProcessorOutput, InputsProcessor, InputsProcessorType, MessagesAction, Processor,
20 },
21 sequence::Sequence,
22 vision_models::ModelInputs,
23 MessageContent, Pipeline, Tool,
24};
25
26use crate::vision_models::{
27 image_processor::{ImagePreProcessor, PreprocessedImages},
28 preprocessor_config::{PreProcessorConfig, ToFilter},
29 processor_config::ProcessorConfig,
30};
31
/// Image (pre)processor for Idefics 2, created per request by
/// `Idefics2Processor::inputs_processor`.
pub struct Idefics2ImageProcessor {
    // When set, images are padded via `mistralrs_vision::pad_to_max_edge`
    // before batching (see `preprocess`).
    max_edge: Option<u32>,
}
/// Prompt/template processor for Idefics 2: expands `<image>` placeholders in
/// the chat-templated prompt and tokenizes the result.
pub struct Idefics2Processor {
    // Must provide `image_seq_len` (asserted in `process`).
    config: ProcessorConfig,
    preprocessor_config: PreProcessorConfig,
    // Marker wrapped around each image token run: "<fake_token_around_image>".
    fake_image_token: &'static str,
    // Placeholder token substituted per image: "<image>".
    image_token: &'static str,
    // Forwarded to `Idefics2ImageProcessor`.
    max_edge: Option<u32>,
}
44
45impl Idefics2Processor {
46 pub fn new(
47 config: ProcessorConfig,
48 preprocessor_config: PreProcessorConfig,
49 max_edge: Option<u32>,
50 ) -> Self {
51 Self {
52 config,
53 preprocessor_config,
54 fake_image_token: "<fake_token_around_image>",
55 image_token: "<image>",
56 max_edge,
57 }
58 }
59}
60
impl Processor for Idefics2Processor {
    /// Render the chat template, expand every `<image>` placeholder into the
    /// full per-image marker/token sequence, then tokenize.
    ///
    /// Returns `(token_ids, expanded_prompt)`. Fails if the chat template
    /// cannot be applied or the pipeline has no tokenizer.
    fn process(
        &self,
        pipeline: &dyn Pipeline,
        messages: Vec<IndexMap<String, MessageContent>>,
        add_generation_prompt: bool,
        add_special_tokens: bool,
        tools: Vec<Tool>,
    ) -> anyhow::Result<(Vec<u32>, String)> {
        let mut prompt = apply_chat_template(
            pipeline,
            messages,
            add_generation_prompt,
            self.template_action(),
            tools,
        )?;

        // Per-image expansion: <fake>(<image> * image_seq_len)<fake>.
        let mut image_str = format!(
            "{}{}{}",
            self.fake_image_token,
            self.image_token.repeat(
                self.config
                    .image_seq_len
                    .expect("Idefics 2 model needs `image_seq_len`")
            ),
            self.fake_image_token
        );
        // With image splitting the preprocessor produces 4 crops plus the
        // original image (5 sub-images), so the expansion is repeated 5x.
        if self
            .preprocessor_config
            .do_image_splitting
            .is_some_and(|x| x)
        {
            image_str = image_str.repeat(5);
        }

        prompt = prompt.replace(self.image_token, &image_str);
        // Adjacent images yield back-to-back fake tokens; collapse them to one.
        prompt = prompt.replace(
            &format!("{}{}", self.fake_image_token, self.fake_image_token),
            self.fake_image_token,
        );

        let Some(tokenizer) = &pipeline.tokenizer() else {
            anyhow::bail!("Idefics2InputProcessor requires a specified tokenizer.",);
        };
        let encoding = tokenizer
            .encode_fast(prompt.clone(), add_special_tokens)
            .map_err(anyhow::Error::msg)?;
        Ok((encoding.get_ids().to_vec(), prompt))
    }

    /// Image-inputs processor paired with this prompt processor.
    fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
        Arc::new(Idefics2ImageProcessor {
            max_edge: self.max_edge,
        })
    }

    /// Special tokens this model relies on; mirrors the literals used above.
    fn get_special_tokens(&self) -> &[&'static str] {
        &["<fake_token_around_image>", "<image>", "<end_of_utterance>"]
    }

    /// Messages are passed through unchanged; the chat template itself
    /// handles image content entries.
    fn template_action(&self) -> MessagesAction {
        MessagesAction::Keep
    }
}
127
impl InputsProcessor for Idefics2ImageProcessor {
    fn get_type(&self) -> InputsProcessorType {
        InputsProcessorType::Vision
    }

    /// Build `ModelInputs` for a batch of sequences: text-side metadata via
    /// the shared text-model input helpers, plus — when every sequence still
    /// holds images — batched pixel values and a pixel attention mask.
    ///
    /// Yields exactly one `InputProcessorOutput` (prompt chunking is not
    /// supported for Idefics 2).
    fn process_inputs(
        &self,
        _: Option<Arc<Tokenizer>>,
        input_seqs: &mut [&mut Sequence],
        is_prompt: bool,
        is_xlora: bool,
        device: &Device,
        no_kv_cache: bool,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        other_config: Option<Arc<dyn Any>>,
        mut paged_attn_metadata: Option<PagedAttentionMeta<'_>>,
        prompt_chunksize: Option<NonZeroUsize>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Box<dyn Iterator<Item = anyhow::Result<InputProcessorOutput>>> {
        // Unsupported configurations are reported as a single Err item.
        if is_xlora {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Cannot make inputs for X-LoRA vision model.",
            ))));
        }
        if no_kv_cache {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Vision model must have kv cache.",
            ))));
        }
        // Chunked prefill is ignored rather than erroring out.
        if prompt_chunksize.is_some() {
            warn!("`prompt_chunksize` is set. Idefics 2 does not support prompt batching.");
        }

        // Compute text-side inputs (token ids, positions, paged-attn/flash
        // metadata). With no chunk size passed, the helper iterators produce a
        // single batch, hence the `.nth(0).unwrap()`.
        let text_models_inputs_processor::InnerInputProcessorOutput {
            inputs:
                text_models_inputs_processor::InputMetadata {
                    input,
                    positions,
                    context_lens,
                    position_ids,
                    paged_attn_meta,
                    flash_meta,
                },
            seq_indices,
        } = if is_prompt {
            get_prompt_input(
                input_seqs
                    .iter()
                    .map(|seq| seq.get_toks().to_vec())
                    .collect::<Vec<_>>(),
                input_seqs,
                device,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                None, mapper,
            )
            .nth(0)
            .unwrap()
            .unwrap()
        } else {
            get_completion_input(
                input_seqs
                    .iter()
                    .map(|seq| seq.get_toks().to_vec())
                    .collect::<Vec<_>>(),
                input_seqs,
                device,
                no_kv_cache,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                None, mapper,
            )
            .nth(0)
            .unwrap()
            .unwrap()
        };
        // The caller must supply a `PreProcessorConfig` through `other_config`.
        let config = other_config.expect("Need a PreProcessorConfig config.");
        let config: &PreProcessorConfig = config.downcast_ref().expect("Downcast failed.");

        // NOTE(review): `all` means pixel values are produced only when EVERY
        // sequence still has images; a mixed batch silently gets none.
        // Presumably batches are homogeneous by this point — confirm upstream.
        let has_images = input_seqs.iter().all(|seq| seq.has_images());

        let (pixel_values, pixel_attention_mask) = if has_images {
            let mut pixel_values_accum = Vec::new();
            let mut pixel_attention_mask_accum = Vec::new();
            for seq in input_seqs.iter_mut() {
                // Only pixel values and the attention mask are used; the other
                // `PreprocessedImages` fields are for other model families.
                let PreprocessedImages {
                    pixel_values,
                    pixel_attention_mask,
                    image_sizes: _,
                    num_img_tokens: _,
                    aspect_ratio_ids: _,
                    aspect_ratio_mask: _,
                    num_tiles: _,
                    image_grid_thw: _,
                    video_grid_thw: _,
                    rows: _,
                    cols: _,
                    pixel_values_list: _,
                    tgt_sizes: _,
                    image_sizes_all: _,
                    num_crops: _,
                } = self
                    .preprocess(
                        // `take_images` consumes the images, so later
                        // (completion) passes skip this branch.
                        seq.take_images()
                            .expect("Need to have images by this point."),
                        vec![],
                        config,
                        device,
                        // Batch-size/max-image limits are unused by this
                        // preprocessor.
                        (usize::MAX, usize::MAX), )
                    .expect("Preprocessing failed");
                pixel_values_accum.push(pixel_values.unsqueeze(0).unwrap());
                pixel_attention_mask_accum
                    .push(pixel_attention_mask.unwrap().unsqueeze(0).unwrap());
            }
            // Stack per-sequence tensors along a new batch dimension.
            (
                Some(Tensor::cat(&pixel_values_accum, 0).unwrap()),
                Some(Tensor::cat(&pixel_attention_mask_accum, 0).unwrap()),
            )
        } else {
            (None, None)
        };

        let inputs: Box<dyn Any> = Box::new(ModelInputs {
            input_ids: input,
            seqlen_offsets: positions,
            context_lens,
            position_ids,
            pixel_values,
            // The pixel attention mask travels as a model-specific argument.
            model_specific_args: Box::new(pixel_attention_mask),
            paged_attn_meta,
            flash_meta,
        });
        Box::new(std::iter::once(Ok(InputProcessorOutput {
            inputs,
            seq_indices,
        })))
    }
}
272
273impl ImagePreProcessor for Idefics2ImageProcessor {
274 #[allow(clippy::excessive_precision)]
275 const DEFAULT_MEAN: [f64; 3] = [0.48145466, 0.4578275, 0.40821073];
276 #[allow(clippy::excessive_precision)]
277 const DEFAULT_STD: [f64; 3] = [0.26862954, 0.26130258, 0.27577711];
278
279 fn preprocess(
280 &self,
281 mut images: Vec<DynamicImage>,
282 videos: Vec<Vec<DynamicImage>>,
283 config: &PreProcessorConfig,
284 device: &Device,
285 (_bs, _max_num_images): (usize, usize),
286 ) -> Result<PreprocessedImages> {
287 assert!(videos.is_empty());
288
289 let mut patch_masks = Vec::new();
290 let mut pixel_values = Vec::new();
291
292 if config.do_image_splitting.is_some_and(|x| x) {
294 let mut new_images = Vec::new();
295 for image in images {
296 let (w, h) = image.dimensions();
297 let mid_w = w / 2;
298 let mid_h = h / 2;
299 new_images.push(image.crop_imm(0, 0, mid_w, mid_h));
300 new_images.push(image.crop_imm(mid_w, 0, w, mid_h));
301 new_images.push(image.crop_imm(0, mid_h, mid_w, h));
302 new_images.push(image.crop_imm(mid_w, mid_h, w, h));
303 new_images.push(image);
304 }
305 images = new_images;
306 }
307
308 for image in images.iter_mut() {
309 if config.do_resize.is_some_and(|x| x) {
311 let size = config.size.as_ref().unwrap();
312 let (h, w) = if size.contains_key("shortest_edge")
313 && size.contains_key("longest_edge")
314 {
315 mistralrs_vision::get_resize_image_size(
316 (image.dimensions().1 as usize, image.dimensions().0 as usize),
317 (
318 size["shortest_edge"] as usize,
319 size["longest_edge"] as usize,
320 ),
321 )
322 } else if size.contains_key("height") && size.contains_key("width") {
323 (size["height"] as usize, size["width"] as usize)
324 } else {
325 candle_core::bail!("Size must be a map of `shortest_edge` and `longest_edge` or `height` and `width`.");
326 };
327
328 *image = image.resize_exact(w as u32, h as u32, config.resampling.to_filter()?);
329 }
330 }
331
332 if let Some(max_edge) = self.max_edge {
333 images = mistralrs_vision::pad_to_max_edge(&images, max_edge);
334 }
335
336 let mut max_h = 0;
337 let mut max_w = 0;
338 for image in &images {
339 let (w, h) = image.dimensions();
340 if w > max_w {
341 max_w = w;
342 }
343 if h > max_h {
344 max_h = h;
345 }
346 }
347
348 for image in images.iter_mut() {
349 if config.do_convert_rgb.is_some_and(|x| x) {
351 *image = DynamicImage::ImageRgb8(image.to_rgb8());
352 }
353
354 let transforms = Transforms {
355 input: &ToTensorNoNorm,
356 inner_transforms: &[
357 &config
358 .do_rescale
359 .is_some_and(|x| x)
360 .then_some(())
361 .map(|_| Rescale {
362 factor: config.rescale_factor,
363 }),
364 &config
365 .do_normalize
366 .is_some_and(|x| x)
367 .then_some(())
368 .map(|_| Normalize {
369 mean: config.image_mean.unwrap_or(Self::DEFAULT_MEAN).to_vec(),
370 std: config.image_std.unwrap_or(Self::DEFAULT_STD).to_vec(),
371 }),
372 ],
373 };
374
375 let mut image = image.apply(transforms, device)?;
376 if config.do_pad.is_some_and(|x| x) {
378 let (_c, h, w) = image.dims3()?;
379 let padded = mistralrs_vision::pad(&image, max_h as usize, max_w as usize)?;
380 let mask = mistralrs_vision::make_pixel_mask(&padded, h, w)?;
381 patch_masks.push(mask.unsqueeze(0)?);
382 image = padded;
383 }
384
385 pixel_values.push(image.unsqueeze(0)?)
387 }
388
389 Ok(PreprocessedImages {
390 pixel_values: Tensor::cat(&pixel_values, 0)?,
391 pixel_attention_mask: Some(Tensor::cat(&patch_masks, 0)?),
392 image_sizes: None,
393 num_img_tokens: None,
394 aspect_ratio_ids: None,
395 aspect_ratio_mask: None,
396 num_tiles: None,
397 image_grid_thw: None,
398 video_grid_thw: None,
399 rows: None,
400 cols: None,
401 pixel_values_list: None,
402 tgt_sizes: None,
403 image_sizes_all: None,
404 num_crops: None,
405 })
406 }
407}