//! mistralrs_core/vision_models/idefics2/idefics2_input_processor.rs

1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use std::{any::Any, num::NonZeroUsize, sync::Arc};
4
5use candle_core::{Device, Result, Tensor};
6use image::{DynamicImage, GenericImageView};
7use indexmap::IndexMap;
8use mistralrs_vision::{ApplyTransforms, Normalize, Rescale, ToTensorNoNorm, Transforms};
9use tokenizers::Tokenizer;
10use tracing::warn;
11
12use crate::{
13    device_map::DeviceMapper,
14    pipeline::{
15        apply_chat_template,
16        text_models_inputs_processor::{
17            self, get_completion_input, get_prompt_input, PagedAttentionMeta,
18        },
19        InputProcessorOutput, InputsProcessor, InputsProcessorType, MessagesAction, Processor,
20    },
21    sequence::Sequence,
22    vision_models::ModelInputs,
23    MessageContent, Pipeline, Tool,
24};
25
26use crate::vision_models::{
27    image_processor::{ImagePreProcessor, PreprocessedImages},
28    preprocessor_config::{PreProcessorConfig, ToFilter},
29    processor_config::ProcessorConfig,
30};
31
32// Input processor
/// Per-request inputs processor for Idefics 2: turns a batch of sequences
/// (token ids plus any attached images) into `ModelInputs`.
pub struct Idefics2ImageProcessor {
    // If set, images are padded so their longest edge does not exceed this
    // value before preprocessing (applied in `preprocess`).
    max_edge: Option<u32>,
}
36// Processor
/// Prompt-level processor for Idefics 2. Expands each `<image>` placeholder in
/// the chat-templated prompt into the full per-image token sequence and
/// tokenizes the result.
pub struct Idefics2Processor {
    config: ProcessorConfig,
    preprocessor_config: PreProcessorConfig,
    // Literal "<fake_token_around_image>" marker surrounding each image region.
    fake_image_token: &'static str,
    // Literal "<image>" placeholder token.
    image_token: &'static str,
    // Forwarded to the `Idefics2ImageProcessor` built by `inputs_processor`.
    max_edge: Option<u32>,
}
44
45impl Idefics2Processor {
46    pub fn new(
47        config: ProcessorConfig,
48        preprocessor_config: PreProcessorConfig,
49        max_edge: Option<u32>,
50    ) -> Self {
51        Self {
52            config,
53            preprocessor_config,
54            fake_image_token: "<fake_token_around_image>",
55            image_token: "<image>",
56            max_edge,
57        }
58    }
59}
60
61impl Processor for Idefics2Processor {
62    fn process(
63        &self,
64        pipeline: &dyn Pipeline,
65        messages: Vec<IndexMap<String, MessageContent>>,
66        add_generation_prompt: bool,
67        add_special_tokens: bool,
68        tools: Vec<Tool>,
69    ) -> anyhow::Result<(Vec<u32>, String)> {
70        let mut prompt = apply_chat_template(
71            pipeline,
72            messages,
73            add_generation_prompt,
74            self.template_action(),
75            tools,
76        )?;
77
78        let mut image_str = format!(
79            "{}{}{}",
80            self.fake_image_token,
81            self.image_token.repeat(
82                self.config
83                    .image_seq_len
84                    .expect("Idefics 2 model needs `image_seq_len`")
85            ),
86            self.fake_image_token
87        );
88        if self
89            .preprocessor_config
90            .do_image_splitting
91            .is_some_and(|x| x)
92        {
93            // 4 patches + 1 original
94            image_str = image_str.repeat(5);
95        }
96
97        prompt = prompt.replace(self.image_token, &image_str);
98        // Deal with any adjacent images.
99        prompt = prompt.replace(
100            &format!("{}{}", self.fake_image_token, self.fake_image_token),
101            self.fake_image_token,
102        );
103
104        let Some(tokenizer) = &pipeline.tokenizer() else {
105            anyhow::bail!("Idefics2InputProcessor requires a specified tokenizer.",);
106        };
107        let encoding = tokenizer
108            .encode_fast(prompt.clone(), add_special_tokens)
109            .map_err(anyhow::Error::msg)?;
110        Ok((encoding.get_ids().to_vec(), prompt))
111    }
112
113    fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
114        Arc::new(Idefics2ImageProcessor {
115            max_edge: self.max_edge,
116        })
117    }
118
119    fn get_special_tokens(&self) -> &[&'static str] {
120        &["<fake_token_around_image>", "<image>", "<end_of_utterance>"]
121    }
122
123    fn template_action(&self) -> MessagesAction {
124        MessagesAction::Keep
125    }
126}
127
128impl InputsProcessor for Idefics2ImageProcessor {
129    fn get_type(&self) -> InputsProcessorType {
130        InputsProcessorType::Vision
131    }
132    fn process_inputs(
133        &self,
134        _: Option<Arc<Tokenizer>>,
135        input_seqs: &mut [&mut Sequence],
136        is_prompt: bool,
137        is_xlora: bool,
138        device: &Device,
139        no_kv_cache: bool,
140        last_n_context_len: Option<(usize, usize)>,
141        return_raw_logits: bool,
142        other_config: Option<Arc<dyn Any>>,
143        mut paged_attn_metadata: Option<PagedAttentionMeta<'_>>,
144        prompt_chunksize: Option<NonZeroUsize>,
145        mapper: Option<&dyn DeviceMapper>,
146    ) -> Box<dyn Iterator<Item = anyhow::Result<InputProcessorOutput>>> {
147        if is_xlora {
148            return Box::new(std::iter::once(Err(anyhow::Error::msg(
149                "Cannot make inputs for X-LoRA vision model.",
150            ))));
151        }
152        if no_kv_cache {
153            return Box::new(std::iter::once(Err(anyhow::Error::msg(
154                "Vision model must have kv cache.",
155            ))));
156        }
157        // TODO(EricLBuehler): support this? Would require some handling of image tokens.
158        if prompt_chunksize.is_some() {
159            warn!("`prompt_chunksize` is set. Idefics 2 does not support prompt batching.");
160        }
161
162        let text_models_inputs_processor::InnerInputProcessorOutput {
163            inputs:
164                text_models_inputs_processor::InputMetadata {
165                    input,
166                    positions,
167                    context_lens,
168                    position_ids,
169                    paged_attn_meta,
170                    flash_meta,
171                },
172            seq_indices,
173        } = if is_prompt {
174            get_prompt_input(
175                input_seqs
176                    .iter()
177                    .map(|seq| seq.get_toks().to_vec())
178                    .collect::<Vec<_>>(),
179                input_seqs,
180                device,
181                last_n_context_len,
182                return_raw_logits,
183                paged_attn_metadata.as_mut(),
184                None, // TODO: evaluate if it is possible to batch this
185                mapper,
186            )
187            .nth(0)
188            .unwrap()
189            .unwrap()
190        } else {
191            get_completion_input(
192                input_seqs
193                    .iter()
194                    .map(|seq| seq.get_toks().to_vec())
195                    .collect::<Vec<_>>(),
196                input_seqs,
197                device,
198                no_kv_cache,
199                last_n_context_len,
200                return_raw_logits,
201                paged_attn_metadata.as_mut(),
202                None, // TODO: evaluate if it is possible to batch this
203                mapper,
204            )
205            .nth(0)
206            .unwrap()
207            .unwrap()
208        };
209        let config = other_config.expect("Need a PreProcessorConfig config.");
210        let config: &PreProcessorConfig = config.downcast_ref().expect("Downcast failed.");
211
212        let has_images = input_seqs.iter().all(|seq| seq.has_images());
213
214        let (pixel_values, pixel_attention_mask) = if has_images {
215            let mut pixel_values_accum = Vec::new();
216            let mut pixel_attention_mask_accum = Vec::new();
217            for seq in input_seqs.iter_mut() {
218                let PreprocessedImages {
219                    pixel_values,
220                    pixel_attention_mask,
221                    image_sizes: _,
222                    num_img_tokens: _,
223                    aspect_ratio_ids: _,
224                    aspect_ratio_mask: _,
225                    num_tiles: _,
226                    image_grid_thw: _,
227                    video_grid_thw: _,
228                    rows: _,
229                    cols: _,
230                    pixel_values_list: _,
231                    tgt_sizes: _,
232                    image_sizes_all: _,
233                    num_crops: _,
234                } = self
235                    .preprocess(
236                        seq.take_images()
237                            .expect("Need to have images by this point."),
238                        vec![],
239                        config,
240                        device,
241                        (usize::MAX, usize::MAX), // Don't use it here...
242                    )
243                    .expect("Preprocessing failed");
244                pixel_values_accum.push(pixel_values.unsqueeze(0).unwrap());
245                pixel_attention_mask_accum
246                    .push(pixel_attention_mask.unwrap().unsqueeze(0).unwrap());
247            }
248            (
249                Some(Tensor::cat(&pixel_values_accum, 0).unwrap()),
250                Some(Tensor::cat(&pixel_attention_mask_accum, 0).unwrap()),
251            )
252        } else {
253            (None, None)
254        };
255
256        let inputs: Box<dyn Any> = Box::new(ModelInputs {
257            input_ids: input,
258            seqlen_offsets: positions,
259            context_lens,
260            position_ids,
261            pixel_values,
262            model_specific_args: Box::new(pixel_attention_mask),
263            paged_attn_meta,
264            flash_meta,
265        });
266        Box::new(std::iter::once(Ok(InputProcessorOutput {
267            inputs,
268            seq_indices,
269        })))
270    }
271}
272
273impl ImagePreProcessor for Idefics2ImageProcessor {
274    #[allow(clippy::excessive_precision)]
275    const DEFAULT_MEAN: [f64; 3] = [0.48145466, 0.4578275, 0.40821073];
276    #[allow(clippy::excessive_precision)]
277    const DEFAULT_STD: [f64; 3] = [0.26862954, 0.26130258, 0.27577711];
278
279    fn preprocess(
280        &self,
281        mut images: Vec<DynamicImage>,
282        videos: Vec<Vec<DynamicImage>>,
283        config: &PreProcessorConfig,
284        device: &Device,
285        (_bs, _max_num_images): (usize, usize),
286    ) -> Result<PreprocessedImages> {
287        assert!(videos.is_empty());
288
289        let mut patch_masks = Vec::new();
290        let mut pixel_values = Vec::new();
291
292        // Image splitting
293        if config.do_image_splitting.is_some_and(|x| x) {
294            let mut new_images = Vec::new();
295            for image in images {
296                let (w, h) = image.dimensions();
297                let mid_w = w / 2;
298                let mid_h = h / 2;
299                new_images.push(image.crop_imm(0, 0, mid_w, mid_h));
300                new_images.push(image.crop_imm(mid_w, 0, w, mid_h));
301                new_images.push(image.crop_imm(0, mid_h, mid_w, h));
302                new_images.push(image.crop_imm(mid_w, mid_h, w, h));
303                new_images.push(image);
304            }
305            images = new_images;
306        }
307
308        for image in images.iter_mut() {
309            // Resize
310            if config.do_resize.is_some_and(|x| x) {
311                let size = config.size.as_ref().unwrap();
312                let (h, w) = if size.contains_key("shortest_edge")
313                    && size.contains_key("longest_edge")
314                {
315                    mistralrs_vision::get_resize_image_size(
316                        (image.dimensions().1 as usize, image.dimensions().0 as usize),
317                        (
318                            size["shortest_edge"] as usize,
319                            size["longest_edge"] as usize,
320                        ),
321                    )
322                } else if size.contains_key("height") && size.contains_key("width") {
323                    (size["height"] as usize, size["width"] as usize)
324                } else {
325                    candle_core::bail!("Size must be a map of `shortest_edge` and `longest_edge` or `height` and `width`.");
326                };
327
328                *image = image.resize_exact(w as u32, h as u32, config.resampling.to_filter()?);
329            }
330        }
331
332        if let Some(max_edge) = self.max_edge {
333            images = mistralrs_vision::pad_to_max_edge(&images, max_edge);
334        }
335
336        let mut max_h = 0;
337        let mut max_w = 0;
338        for image in &images {
339            let (w, h) = image.dimensions();
340            if w > max_w {
341                max_w = w;
342            }
343            if h > max_h {
344                max_h = h;
345            }
346        }
347
348        for image in images.iter_mut() {
349            // Convert to rgb
350            if config.do_convert_rgb.is_some_and(|x| x) {
351                *image = DynamicImage::ImageRgb8(image.to_rgb8());
352            }
353
354            let transforms = Transforms {
355                input: &ToTensorNoNorm,
356                inner_transforms: &[
357                    &config
358                        .do_rescale
359                        .is_some_and(|x| x)
360                        .then_some(())
361                        .map(|_| Rescale {
362                            factor: config.rescale_factor,
363                        }),
364                    &config
365                        .do_normalize
366                        .is_some_and(|x| x)
367                        .then_some(())
368                        .map(|_| Normalize {
369                            mean: config.image_mean.unwrap_or(Self::DEFAULT_MEAN).to_vec(),
370                            std: config.image_std.unwrap_or(Self::DEFAULT_STD).to_vec(),
371                        }),
372                ],
373            };
374
375            let mut image = image.apply(transforms, device)?;
376            // Pad images, calculating attention mask.
377            if config.do_pad.is_some_and(|x| x) {
378                let (_c, h, w) = image.dims3()?;
379                let padded = mistralrs_vision::pad(&image, max_h as usize, max_w as usize)?;
380                let mask = mistralrs_vision::make_pixel_mask(&padded, h, w)?;
381                patch_masks.push(mask.unsqueeze(0)?);
382                image = padded;
383            }
384
385            // Get pixel values
386            pixel_values.push(image.unsqueeze(0)?)
387        }
388
389        Ok(PreprocessedImages {
390            pixel_values: Tensor::cat(&pixel_values, 0)?,
391            pixel_attention_mask: Some(Tensor::cat(&patch_masks, 0)?),
392            image_sizes: None,
393            num_img_tokens: None,
394            aspect_ratio_ids: None,
395            aspect_ratio_mask: None,
396            num_tiles: None,
397            image_grid_thw: None,
398            video_grid_thw: None,
399            rows: None,
400            cols: None,
401            pixel_values_list: None,
402            tgt_sizes: None,
403            image_sizes_all: None,
404            num_crops: None,
405        })
406    }
407}