mistralrs_core/vision_models/idefics2/idefics2_input_processor.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::{any::Any, num::NonZeroUsize, sync::Arc};

use candle_core::{Device, Result, Tensor};
use image::{DynamicImage, GenericImageView};
use indexmap::IndexMap;
use mistralrs_vision::{ApplyTransforms, Normalize, Rescale, ToTensorNoNorm, Transforms};
use tokenizers::Tokenizer;
use tracing::warn;

use crate::{
    device_map::DeviceMapper,
    pipeline::{
        apply_chat_template,
        text_models_inputs_processor::{
            self, get_completion_input, get_prompt_input, PagedAttentionMeta,
        },
        InputProcessorOutput, InputsProcessor, InputsProcessorType, MessagesAction, Processor,
    },
    sequence::Sequence,
    vision_models::ModelInputs,
    MessageContent, Pipeline, Tool,
};

use crate::vision_models::{
    image_processor::{ImagePreProcessor, PreprocessedImages},
    preprocessor_config::{PreProcessorConfig, ToFilter},
    processor_config::ProcessorConfig,
};

// Input processor: builds the per-batch model inputs (text tokens plus pixel values).
pub struct Idefics2ImageProcessor {
    max_edge: Option<u32>,
}
// Processor: applies the chat template and expands image tokens in the prompt.
pub struct Idefics2Processor {
    config: ProcessorConfig,
    preprocessor_config: PreProcessorConfig,
    fake_image_token: &'static str,
    image_token: &'static str,
    max_edge: Option<u32>,
}

impl Idefics2Processor {
    pub fn new(
        config: ProcessorConfig,
        preprocessor_config: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Self {
        Self {
            config,
            preprocessor_config,
            fake_image_token: "<fake_token_around_image>",
            image_token: "<image>",
            max_edge,
        }
    }
}

impl Processor for Idefics2Processor {
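    /// Applies the chat template, then expands each `<image>` placeholder into
    /// `<fake_token_around_image>` + `image_seq_len` copies of `<image>` +
    /// `<fake_token_around_image>` (repeated 5x when image splitting is enabled:
    /// 4 crops plus the original), collapses adjacent fake-image tokens, and
    /// tokenizes the resulting prompt.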
    fn process(
        &self,
        pipeline: &dyn Pipeline,
        messages: Vec<IndexMap<String, MessageContent>>,
        add_generation_prompt: bool,
        add_special_tokens: bool,
        enable_thinking: Option<bool>,
        tools: Vec<Tool>,
    ) -> anyhow::Result<(Vec<u32>, String)> {
        let mut prompt = apply_chat_template(
            pipeline,
            messages,
            add_generation_prompt,
            enable_thinking,
            self.template_action(),
            tools,
        )?;

        let mut image_str = format!(
            "{}{}{}",
            self.fake_image_token,
            self.image_token.repeat(
                self.config
                    .image_seq_len
                    .expect("Idefics 2 model needs `image_seq_len`")
            ),
            self.fake_image_token
        );
        if self
            .preprocessor_config
            .do_image_splitting
            .is_some_and(|x| x)
        {
            // 4 patches + 1 original
            image_str = image_str.repeat(5);
        }

        prompt = prompt.replace(self.image_token, &image_str);
        // Deal with any adjacent images.
        prompt = prompt.replace(
            &format!("{}{}", self.fake_image_token, self.fake_image_token),
            self.fake_image_token,
        );

        let Some(tokenizer) = &pipeline.tokenizer() else {
            anyhow::bail!("Idefics2InputProcessor requires a specified tokenizer.");
        };
        let encoding = tokenizer
            .encode_fast(prompt.clone(), add_special_tokens)
            .map_err(anyhow::Error::msg)?;
        Ok((encoding.get_ids().to_vec(), prompt))
    }

    fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
        Arc::new(Idefics2ImageProcessor {
            max_edge: self.max_edge,
        })
    }

    fn get_special_tokens(&self) -> &[&'static str] {
        &["<fake_token_around_image>", "<image>", "<end_of_utterance>"]
    }

    fn template_action(&self) -> MessagesAction {
        MessagesAction::Keep
    }
}

impl InputsProcessor for Idefics2ImageProcessor {
    fn get_type(&self) -> InputsProcessorType {
        InputsProcessorType::Vision
    }
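    /// Builds the `ModelInputs` for a batch of sequences: token-level metadata from
    /// the shared text input processors, plus batched `pixel_values` and a
    /// `pixel_attention_mask` when the sequences carry images.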
    fn process_inputs(
        &self,
        _: Option<Arc<Tokenizer>>,
        input_seqs: &mut [&mut Sequence],
        is_prompt: bool,
        is_xlora: bool,
        device: &Device,
        no_kv_cache: bool,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        other_config: Option<Arc<dyn Any>>,
        mut paged_attn_metadata: Option<PagedAttentionMeta>,
        prompt_chunksize: Option<NonZeroUsize>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Box<dyn Iterator<Item = anyhow::Result<InputProcessorOutput>>> {
        if is_xlora {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Cannot make inputs for X-LoRA vision model.",
            ))));
        }
        if no_kv_cache {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Vision model must have kv cache.",
            ))));
        }
        // TODO(EricLBuehler): support this? Would require some handling of image tokens.
        if prompt_chunksize.is_some() {
            warn!("`prompt_chunksize` is set. Idefics 2 does not support prompt batching.");
        }

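        // Token-level inputs (ids, positions, context lengths, paged-attention and
        // flash-attention metadata) come from the shared text input processors;
        // prompt and completion phases use different builders.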
        let text_models_inputs_processor::InnerInputProcessorOutput {
            inputs:
                text_models_inputs_processor::InputMetadata {
                    input,
                    positions,
                    context_lens,
                    position_ids,
                    paged_attn_meta,
                    flash_meta,
                },
            seq_indices,
        } = if is_prompt {
            get_prompt_input(
                input_seqs
                    .iter()
                    .map(|seq| seq.get_toks())
                    .collect::<Vec<_>>(),
                input_seqs,
                device,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                None, // TODO: evaluate if it is possible to batch this
                mapper,
            )
            .nth(0)
            .unwrap()
            .unwrap()
        } else {
            get_completion_input(
                input_seqs
                    .iter()
                    .map(|seq| seq.get_toks())
                    .collect::<Vec<_>>(),
                input_seqs,
                device,
                no_kv_cache,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                None, // TODO: evaluate if it is possible to batch this
                mapper,
            )
            .nth(0)
            .unwrap()
            .unwrap()
        };
        let config = other_config.expect("Need a PreProcessorConfig.");
        let config: &PreProcessorConfig = config.downcast_ref().expect("Downcast failed.");

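        // When every sequence in the batch has images, preprocess them per sequence
        // and stack the results along a new batch dimension; otherwise no pixel
        // values are passed to the model.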
        let has_images = input_seqs.iter().all(|seq| seq.has_images());

        let (pixel_values, pixel_attention_mask) = if has_images {
            let mut pixel_values_accum = Vec::new();
            let mut pixel_attention_mask_accum = Vec::new();
            for seq in input_seqs.iter_mut() {
                let PreprocessedImages {
                    pixel_values,
                    pixel_attention_mask,
                    image_sizes: _,
                    num_img_tokens: _,
                    aspect_ratio_ids: _,
                    aspect_ratio_mask: _,
                    num_tiles: _,
                    image_grid_thw: _,
                    video_grid_thw: _,
                    rows: _,
                    cols: _,
                    pixel_values_list: _,
                    tgt_sizes: _,
                    image_sizes_all: _,
                    num_crops: _,
                } = self
                    .preprocess(
                        seq.take_images()
                            .expect("Need to have images by this point."),
                        vec![],
                        config,
                        device,
                        (usize::MAX, usize::MAX), // Don't use it here...
                    )
                    .expect("Preprocessing failed");
                pixel_values_accum.push(pixel_values.unsqueeze(0).unwrap());
                pixel_attention_mask_accum
                    .push(pixel_attention_mask.unwrap().unsqueeze(0).unwrap());
            }
            (
                Some(Tensor::cat(&pixel_values_accum, 0).unwrap()),
                Some(Tensor::cat(&pixel_attention_mask_accum, 0).unwrap()),
            )
        } else {
            (None, None)
        };

        let inputs: Box<dyn Any> = Box::new(ModelInputs {
            input_ids: input,
            seqlen_offsets: positions,
            context_lens,
            position_ids,
            pixel_values,
            model_specific_args: Box::new(pixel_attention_mask),
            paged_attn_meta,
            flash_meta,
        });
        Box::new(std::iter::once(Ok(InputProcessorOutput {
            inputs,
            seq_indices,
        })))
    }
}

impl ImagePreProcessor for Idefics2ImageProcessor {
    #[allow(clippy::excessive_precision)]
    const DEFAULT_MEAN: [f64; 3] = [0.48145466, 0.4578275, 0.40821073];
    #[allow(clippy::excessive_precision)]
    const DEFAULT_STD: [f64; 3] = [0.26862954, 0.26130258, 0.27577711];

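    /// Preprocessing pipeline: optionally split each image into 4 crops plus the
    /// original, resize per the config, apply `pad_to_max_edge` when `max_edge` is
    /// set, convert to RGB, rescale/normalize, then pad to the batch's max height
    /// and width while building per-image pixel attention masks.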
    fn preprocess(
        &self,
        mut images: Vec<DynamicImage>,
        videos: Vec<Vec<DynamicImage>>,
        config: &PreProcessorConfig,
        device: &Device,
        (_bs, _max_num_images): (usize, usize),
    ) -> Result<PreprocessedImages> {
        assert!(videos.is_empty());

        let mut patch_masks = Vec::new();
        let mut pixel_values = Vec::new();

        // Image splitting
        if config.do_image_splitting.is_some_and(|x| x) {
            let mut new_images = Vec::new();
            for image in images {
                let (w, h) = image.dimensions();
                let mid_w = w / 2;
                let mid_h = h / 2;
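                // `crop_imm(x, y, width, height)` clamps the rectangle to the image
                // bounds, so passing the full `w`/`h` for the right and bottom crops
                // yields the remaining half in each direction.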
                new_images.push(image.crop_imm(0, 0, mid_w, mid_h));
                new_images.push(image.crop_imm(mid_w, 0, w, mid_h));
                new_images.push(image.crop_imm(0, mid_h, mid_w, h));
                new_images.push(image.crop_imm(mid_w, mid_h, w, h));
                new_images.push(image);
            }
            images = new_images;
        }

        for image in images.iter_mut() {
            // Resize
            if config.do_resize.is_some_and(|x| x) {
                let size = config.size.as_ref().unwrap();
                let (h, w) = if size.contains_key("shortest_edge")
                    && size.contains_key("longest_edge")
                {
                    mistralrs_vision::get_resize_image_size(
                        (image.dimensions().1 as usize, image.dimensions().0 as usize),
                        (
                            size["shortest_edge"] as usize,
                            size["longest_edge"] as usize,
                        ),
                    )
                } else if size.contains_key("height") && size.contains_key("width") {
                    (size["height"] as usize, size["width"] as usize)
                } else {
                    candle_core::bail!("Size must be a map of `shortest_edge` and `longest_edge` or `height` and `width`.");
                };

                *image = image.resize_exact(w as u32, h as u32, config.resampling.to_filter()?);
            }
        }

        if let Some(max_edge) = self.max_edge {
            images = mistralrs_vision::pad_to_max_edge(&images, max_edge);
        }

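        // Track the largest height and width in the batch; every image is padded to
        // this size so the per-image tensors can be stacked into a single tensor.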
        let mut max_h = 0;
        let mut max_w = 0;
        for image in &images {
            let (w, h) = image.dimensions();
            if w > max_w {
                max_w = w;
            }
            if h > max_h {
                max_h = h;
            }
        }

        for image in images.iter_mut() {
            // Convert to RGB
            if config.do_convert_rgb.is_some_and(|x| x) {
                *image = DynamicImage::ImageRgb8(image.to_rgb8());
            }

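            // Rescale and normalization are applied only when enabled in the
            // preprocessor config; a disabled step becomes `None` and is skipped.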
            let transforms = Transforms {
                input: &ToTensorNoNorm,
                inner_transforms: &[
                    &config
                        .do_rescale
                        .is_some_and(|x| x)
                        .then_some(())
                        .map(|_| Rescale {
                            factor: config.rescale_factor,
                        }),
                    &config
                        .do_normalize
                        .is_some_and(|x| x)
                        .then_some(())
                        .map(|_| Normalize {
                            mean: config.image_mean.unwrap_or(Self::DEFAULT_MEAN).to_vec(),
                            std: config.image_std.unwrap_or(Self::DEFAULT_STD).to_vec(),
                        }),
                ],
            };

            let mut image = image.apply(transforms, device)?;
            // Pad images, calculating attention mask.
            if config.do_pad.is_some_and(|x| x) {
                let (_c, h, w) = image.dims3()?;
                let padded = mistralrs_vision::pad(&image, max_h as usize, max_w as usize)?;
                let mask = mistralrs_vision::make_pixel_mask(&padded, h, w)?;
                patch_masks.push(mask.unsqueeze(0)?);
                image = padded;
            }

            // Get pixel values
            pixel_values.push(image.unsqueeze(0)?);
        }

        Ok(PreprocessedImages {
            pixel_values: Tensor::cat(&pixel_values, 0)?,
            pixel_attention_mask: Some(Tensor::cat(&patch_masks, 0)?),
            image_sizes: None,
            num_img_tokens: None,
            aspect_ratio_ids: None,
            aspect_ratio_mask: None,
            num_tiles: None,
            image_grid_thw: None,
            video_grid_thw: None,
            rows: None,
            cols: None,
            pixel_values_list: None,
            tgt_sizes: None,
            image_sizes_all: None,
            num_crops: None,
        })
    }
}