mistralrs_core/vision_models/idefics2/idefics2_input_processor.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::{any::Any, sync::Arc};

use candle_core::{Device, Result, Tensor};
use image::{DynamicImage, GenericImageView};
use indexmap::IndexMap;
use mistralrs_vision::{ApplyTransforms, Normalize, Rescale, ToTensorNoNorm, Transforms};
use tokenizers::Tokenizer;

use crate::{
    device_map::DeviceMapper,
    pipeline::{
        apply_chat_template,
        text_models_inputs_processor::{
            self, get_completion_input, get_prompt_input, PagedAttentionMeta,
        },
        InputProcessorOutput, InputsProcessor, InputsProcessorType, MessagesAction, Processor,
    },
    sequence::Sequence,
    vision_models::ModelInputs,
    MessageContent, Pipeline, Tool,
};

use crate::vision_models::{
    image_processor::{ImagePreProcessor, PreprocessedImages},
    preprocessor_config::{PreProcessorConfig, ToFilter},
    processor_config::ProcessorConfig,
};

// Input processor
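/// Prepares pixel values and pixel attention masks for the Idefics 2 vision tower.
/// When `max_edge` is set, images are passed through `mistralrs_vision::pad_to_max_edge`
/// before tensorization.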
pub struct Idefics2ImageProcessor {
    max_edge: Option<u32>,
}
// Processor
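/// Expands `<image>` placeholders in the chat prompt into the Idefics 2 image token
/// sequence and tokenizes the result. Holds the processor and preprocessor configs plus
/// the literal image marker tokens.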
pub struct Idefics2Processor {
    config: ProcessorConfig,
    preprocessor_config: PreProcessorConfig,
    fake_image_token: &'static str,
    image_token: &'static str,
    max_edge: Option<u32>,
}

impl Idefics2Processor {
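    /// Builds a processor with the fixed Idefics 2 marker tokens
    /// (`<fake_token_around_image>` and `<image>`).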
    pub fn new(
        config: ProcessorConfig,
        preprocessor_config: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Self {
        Self {
            config,
            preprocessor_config,
            fake_image_token: "<fake_token_around_image>",
            image_token: "<image>",
            max_edge,
        }
    }
}

impl Processor for Idefics2Processor {
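    /// Applies the chat template, expands each `<image>` placeholder into the full image
    /// token sequence, and tokenizes the final prompt.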
    fn process(
        &self,
        pipeline: &dyn Pipeline,
        messages: Vec<IndexMap<String, MessageContent>>,
        add_generation_prompt: bool,
        add_special_tokens: bool,
        enable_thinking: Option<bool>,
        tools: Vec<Tool>,
    ) -> anyhow::Result<(Vec<u32>, String)> {
        let mut prompt = apply_chat_template(
            pipeline,
            messages,
            add_generation_prompt,
            enable_thinking,
            self.template_action(),
            tools,
        )?;

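        // Each `<image>` placeholder in the prompt expands to `<fake_token_around_image>`,
        // followed by `image_seq_len` copies of `<image>`, followed by a closing
        // `<fake_token_around_image>`.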
        let mut image_str = format!(
            "{}{}{}",
            self.fake_image_token,
            self.image_token.repeat(
                self.config
                    .image_seq_len
                    .expect("Idefics 2 model needs `image_seq_len`")
            ),
            self.fake_image_token
        );
        if self
            .preprocessor_config
            .do_image_splitting
            .is_some_and(|x| x)
        {
            // 4 patches + 1 original
            image_str = image_str.repeat(5);
        }

        prompt = prompt.replace(self.image_token, &image_str);
        // Deal with any adjacent images.
        prompt = prompt.replace(
            &format!("{}{}", self.fake_image_token, self.fake_image_token),
            self.fake_image_token,
        );

        let Some(tokenizer) = &pipeline.tokenizer() else {
            anyhow::bail!("Idefics2InputProcessor requires a specified tokenizer.");
        };
        let encoding = tokenizer
            .encode_fast(prompt.clone(), add_special_tokens)
            .map_err(anyhow::Error::msg)?;
        Ok((encoding.get_ids().to_vec(), prompt))
    }

    fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
        Arc::new(Idefics2ImageProcessor {
            max_edge: self.max_edge,
        })
    }

    fn get_special_tokens(&self) -> &[&'static str] {
        &["<fake_token_around_image>", "<image>", "<end_of_utterance>"]
    }

    fn template_action(&self) -> MessagesAction {
        MessagesAction::Keep
    }
}

impl InputsProcessor for Idefics2ImageProcessor {
    fn get_type(&self) -> InputsProcessorType {
        InputsProcessorType::Vision
    }
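    /// Builds `ModelInputs` for a batch of sequences: token metadata comes from the
    /// standard text input processors, while pixel values and pixel attention masks are
    /// produced per sequence and stacked along the batch dimension.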
    fn process_inputs(
        &self,
        _: Option<Arc<Tokenizer>>,
        input_seqs: &mut [&mut Sequence],
        is_prompt: bool,
        is_xlora: bool,
        device: &Device,
        no_kv_cache: bool,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        other_config: Option<Arc<dyn Any>>,
        mut paged_attn_metadata: Option<PagedAttentionMeta>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> anyhow::Result<InputProcessorOutput> {
        if is_xlora {
            return Err(anyhow::Error::msg(
                "Cannot make inputs for X-LoRA vision model.",
            ));
        }
        if no_kv_cache {
            return Err(anyhow::Error::msg("Vision model must have kv cache."));
        }

        let text_models_inputs_processor::InnerInputProcessorOutput {
            inputs:
                text_models_inputs_processor::InputMetadata {
                    input,
                    positions,
                    context_lens,
                    position_ids,
                    paged_attn_meta,
                    flash_meta,
                },
            seq_indices,
        } = if is_prompt {
            get_prompt_input(
                input_seqs
                    .iter()
                    .map(|seq| seq.get_toks())
                    .collect::<Vec<_>>(),
                input_seqs,
                device,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                mapper,
            )
            .unwrap()
        } else {
            get_completion_input(
                input_seqs
                    .iter()
                    .map(|seq| seq.get_toks())
                    .collect::<Vec<_>>(),
                input_seqs,
                device,
                no_kv_cache,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                mapper,
            )
            .unwrap()
        };
        let config = other_config.expect("Need a PreProcessorConfig config.");
        let config: &PreProcessorConfig = config.downcast_ref().expect("Downcast failed.");

        let has_images = input_seqs.iter().all(|seq| seq.has_images());

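        // Preprocess each sequence's images and stack the results into batch tensors.
        // Pixel values are only built when every sequence in the batch still has images
        // to consume; otherwise both fields stay `None`.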
        let (pixel_values, pixel_attention_mask) = if has_images {
            let mut pixel_values_accum = Vec::new();
            let mut pixel_attention_mask_accum = Vec::new();
            for seq in input_seqs.iter_mut() {
                let PreprocessedImages {
                    pixel_values,
                    pixel_attention_mask,
                    image_sizes: _,
                    num_img_tokens: _,
                    aspect_ratio_ids: _,
                    aspect_ratio_mask: _,
                    num_tiles: _,
                    image_grid_thw: _,
                    video_grid_thw: _,
                    rows: _,
                    cols: _,
                    pixel_values_list: _,
                    tgt_sizes: _,
                    image_sizes_all: _,
                    num_crops: _,
                } = self
                    .preprocess(
                        seq.take_images()
                            .expect("Need to have images by this point."),
                        vec![],
                        config,
                        device,
                        (usize::MAX, usize::MAX), // Don't use it here...
                    )
                    .expect("Preprocessing failed");
                pixel_values_accum.push(pixel_values.unsqueeze(0).unwrap());
                pixel_attention_mask_accum
                    .push(pixel_attention_mask.unwrap().unsqueeze(0).unwrap());
            }
            (
                Some(Tensor::cat(&pixel_values_accum, 0).unwrap()),
                Some(Tensor::cat(&pixel_attention_mask_accum, 0).unwrap()),
            )
        } else {
            (None, None)
        };

        let inputs: Box<dyn Any> = Box::new(ModelInputs {
            input_ids: input,
            seqlen_offsets: positions,
            context_lens,
            position_ids,
            pixel_values,
            model_specific_args: Box::new(pixel_attention_mask),
            paged_attn_meta,
            flash_meta,
        });
        Ok(InputProcessorOutput {
            inputs,
            seq_indices,
        })
    }
}

impl ImagePreProcessor for Idefics2ImageProcessor {
    #[allow(clippy::excessive_precision)]
    const DEFAULT_MEAN: [f64; 3] = [0.48145466, 0.4578275, 0.40821073];
    #[allow(clippy::excessive_precision)]
    const DEFAULT_STD: [f64; 3] = [0.26862954, 0.26130258, 0.27577711];

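    /// Preprocessing pipeline: optional 4-crop image splitting, resize, RGB conversion,
    /// rescale/normalize, then padding to the batch-wide maximum size with a matching
    /// pixel attention mask. Assumes `do_pad` is enabled, since the attention mask is
    /// only produced during padding.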
    fn preprocess(
        &self,
        mut images: Vec<DynamicImage>,
        videos: Vec<Vec<DynamicImage>>,
        config: &PreProcessorConfig,
        device: &Device,
        (_bs, _max_num_images): (usize, usize),
    ) -> Result<PreprocessedImages> {
        assert!(videos.is_empty());

        let mut patch_masks = Vec::new();
        let mut pixel_values = Vec::new();

        // Image splitting
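        // Each image is split into its four quadrants plus the original, mirroring the
        // 5x repetition of the image token sequence in `Idefics2Processor::process`.
        // `crop_imm` clamps the crop rectangle to the image bounds, so passing `w`/`h`
        // for the right/bottom crops is safe.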
        if config.do_image_splitting.is_some_and(|x| x) {
            let mut new_images = Vec::new();
            for image in images {
                let (w, h) = image.dimensions();
                let mid_w = w / 2;
                let mid_h = h / 2;
                new_images.push(image.crop_imm(0, 0, mid_w, mid_h));
                new_images.push(image.crop_imm(mid_w, 0, w, mid_h));
                new_images.push(image.crop_imm(0, mid_h, mid_w, h));
                new_images.push(image.crop_imm(mid_w, mid_h, w, h));
                new_images.push(image);
            }
            images = new_images;
        }

        for image in images.iter_mut() {
            // Resize
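            // `size` may specify `shortest_edge`/`longest_edge` bounds or an exact
            // `height`/`width` pair.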
            if config.do_resize.is_some_and(|x| x) {
                let size = config.size.as_ref().unwrap();
                let (h, w) = if size.contains_key("shortest_edge")
                    && size.contains_key("longest_edge")
                {
                    mistralrs_vision::get_resize_image_size(
                        (image.dimensions().1 as usize, image.dimensions().0 as usize),
                        (
                            size["shortest_edge"] as usize,
                            size["longest_edge"] as usize,
                        ),
                    )
                } else if size.contains_key("height") && size.contains_key("width") {
                    (size["height"] as usize, size["width"] as usize)
                } else {
                    candle_core::bail!("Size must be a map of `shortest_edge` and `longest_edge` or `height` and `width`.");
                };

                *image = image.resize_exact(w as u32, h as u32, config.resampling.to_filter()?);
            }
        }

        if let Some(max_edge) = self.max_edge {
            images = mistralrs_vision::pad_to_max_edge(&images, max_edge);
        }

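        // Track the batch-wide maximum height and width; every image is padded to this
        // size below so the results can be stacked into a single tensor.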
        let mut max_h = 0;
        let mut max_w = 0;
        for image in &images {
            let (w, h) = image.dimensions();
            if w > max_w {
                max_w = w;
            }
            if h > max_h {
                max_h = h;
            }
        }

        for image in images.iter_mut() {
            // Convert to rgb
            if config.do_convert_rgb.is_some_and(|x| x) {
                *image = DynamicImage::ImageRgb8(image.to_rgb8());
            }

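            // Rescale and normalization are each applied only when enabled in the config;
            // disabled steps yield `None` and are skipped by `ApplyTransforms`.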
            let transforms = Transforms {
                input: &ToTensorNoNorm,
                inner_transforms: &[
                    &config
                        .do_rescale
                        .is_some_and(|x| x)
                        .then_some(())
                        .map(|_| Rescale {
                            factor: config.rescale_factor,
                        }),
                    &config
                        .do_normalize
                        .is_some_and(|x| x)
                        .then_some(())
                        .map(|_| Normalize {
                            mean: config.image_mean.unwrap_or(Self::DEFAULT_MEAN).to_vec(),
                            std: config.image_std.unwrap_or(Self::DEFAULT_STD).to_vec(),
                        }),
                ],
            };

            let mut image = image.apply(transforms, device)?;
            // Pad images, calculating attention mask.
            if config.do_pad.is_some_and(|x| x) {
                let (_c, h, w) = image.dims3()?;
                let padded = mistralrs_vision::pad(&image, max_h as usize, max_w as usize)?;
                let mask = mistralrs_vision::make_pixel_mask(&padded, h, w)?;
                patch_masks.push(mask.unsqueeze(0)?);
                image = padded;
            }

            // Get pixel values
            pixel_values.push(image.unsqueeze(0)?);
        }

        Ok(PreprocessedImages {
            pixel_values: Tensor::cat(&pixel_values, 0)?,
            pixel_attention_mask: Some(Tensor::cat(&patch_masks, 0)?),
            image_sizes: None,
            num_img_tokens: None,
            aspect_ratio_ids: None,
            aspect_ratio_mask: None,
            num_tiles: None,
            image_grid_thw: None,
            video_grid_thw: None,
            rows: None,
            cols: None,
            pixel_values_list: None,
            tgt_sizes: None,
            image_sizes_all: None,
            num_crops: None,
        })
    }
}