mistralrs_core/vision_models/llava/llava_inputs_processor.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
use std::any::Any;
use std::num::NonZeroUsize;
use std::sync::Arc;

use candle_core::Result;
use candle_core::{DType, Device, Tensor};
use image::GenericImageView;
use image::Rgb;
use itertools::Itertools;
use regex_automata::meta::Regex;
use tokenizers::Tokenizer;
use tracing::warn;

use super::llava15::LLaVAVisionSpecificArgs;
use super::utils::{expand2square, LLaVAImageProcessor};
use crate::device_map::DeviceMapper;
use crate::pipeline::text_models_inputs_processor::{
    get_completion_input, get_prompt_input, PagedAttentionMeta,
};
use crate::pipeline::{
    text_models_inputs_processor, InputProcessorOutput, InputsProcessor, InputsProcessorType,
    MessagesAction, Processor,
};
use crate::sequence::Sequence;
use crate::vision_models::image_processor::{self, ImagePreProcessor, PreprocessedImages};
use crate::vision_models::llava::config::Config as LLaVAConfig;
use crate::vision_models::preprocessor_config::{PreProcessorConfig, ToFilter};
use crate::vision_models::{preprocessor_config, ModelInputs};

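/// Processor for LLaVA prompts. It wraps an [`LLaVAInputProcessor`], which splits
/// rendered prompts on `<image>` tags and expands each tag into the model's
/// per-image token placeholders.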
pub struct LLaVAProcessor {
    inputs_processor: Arc<LLaVAInputProcessor>,
}

impl Processor for LLaVAProcessor {
    fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
        self.inputs_processor.clone()
    }
    fn get_special_tokens(&self) -> &[&'static str] {
        &[]
    }
    fn template_action(&self) -> MessagesAction {
        MessagesAction::FlattenOnlyText
    }
}

impl LLaVAProcessor {
    pub fn new(config: &str) -> Self {
        let model_config =
            serde_json::from_str::<LLaVAConfig>(config).expect("Failed to parse model config.");
        let image_tag_splitter = Regex::new(r"<image>").expect("Failed to compile split regex.");
        let inputs_processor = Arc::new(LLaVAInputProcessor {
            image_tag_splitter,
            model_config: model_config.clone(),
        });
        Self { inputs_processor }
    }
}

pub struct LLaVAInputProcessor {
    image_tag_splitter: Regex,
    model_config: LLaVAConfig,
}

impl LLaVAInputProcessor {
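    /// Number of image-feature tokens a single image expands to: one per vision
    /// patch. E.g., with `image_size = 336` and `patch_size = 14` (the stock
    /// LLaVA 1.5 CLIP tower), this is `(336 / 14)^2 = 576` tokens.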
    pub fn get_num_image_tokens(cfg: &LLaVAConfig) -> usize {
        let patch_size = cfg.vision_config.patch_size;
        let patch_per_side = cfg.vision_config.image_size / patch_size;
        patch_per_side * patch_per_side
    }
}

// Copied from phi3_inputs_processor. The differences are: (1) the calculation of num_image_tokens, (2) process_anyres_image, (3) image_ids_pad.
impl InputsProcessor for LLaVAInputProcessor {
    fn get_type(&self) -> InputsProcessorType {
        InputsProcessorType::Vision
    }
    fn process_inputs(
        &self,
        tokenizer: Option<Arc<Tokenizer>>,
        input_seqs: &mut [&mut Sequence],
        is_prompt: bool,
        is_xlora: bool,
        device: &Device,
        no_kv_cache: bool,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        other_config: Option<Arc<dyn Any>>,
        mut paged_attn_metadata: Option<PagedAttentionMeta<'_>>,
        prompt_chunksize: Option<NonZeroUsize>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Box<dyn Iterator<Item = anyhow::Result<InputProcessorOutput>>> {
        if is_xlora {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Cannot make inputs for X-LoRA vision model.",
            ))));
        }
        if no_kv_cache {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Vision model must have kv cache.",
            ))));
        }
        // TODO(EricLBuehler): support this? Would require some handling of image tokens.
        if prompt_chunksize.is_some() {
            warn!("`prompt_chunksize` is set. LLaVA does not support prompt batching.");
        }
        let Some(tokenizer) = tokenizer else {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "LLaVAInputProcessor requires a specified tokenizer.",
            ))));
        };

        let config = other_config
            .clone()
            .expect("Need a PreProcessorConfig config.");
        let config: &PreProcessorConfig = config.downcast_ref().expect("Downcast failed.");

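        // The image path below is taken only when *every* sequence in the batch
        // carries images; a batch containing any text-only sequence falls back to
        // the plain text processor.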
        let has_images = input_seqs.iter().all(|seq| seq.has_images());

        let (pixel_values, num_img_tokens) = if has_images {
            let mut pixel_values_accum = Vec::new();
            let mut num_img_tokens_accum = Vec::new();
            for seq in input_seqs.iter_mut() {
                let imgs = seq
                    .take_images()
                    .expect("Need to have images by this point.");
                let PreprocessedImages {
                    pixel_values,
                    pixel_attention_mask: _,
                    image_sizes: _,
                    num_img_tokens,
                    aspect_ratio_ids: _,
                    aspect_ratio_mask: _,
                    num_tiles: _,
                    image_grid_thw: _,
                    video_grid_thw: _,
                    rows: _,
                    cols: _,
                    pixel_values_list: _,
                    tgt_sizes: _,
                    image_sizes_all: _,
                    num_crops: _,
                } = self
                    .preprocess(
                        imgs.clone(),
                        vec![],
                        config,
                        device,
                        (usize::MAX, usize::MAX),
                    )
                    .expect("Preprocessor failed");
                pixel_values_accum.push(pixel_values);
                num_img_tokens_accum.push(num_img_tokens.unwrap());
            }
            (
                Some(Tensor::cat(&pixel_values_accum, 0).unwrap()),
                Some(num_img_tokens_accum),
            )
        } else {
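            // Text-only batch: delegate to the generic text processor, then rewrap
            // its outputs as vision `ModelInputs` with `pixel_values: None` so the
            // model sees a uniform input type.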
            return Box::new(
                text_models_inputs_processor::TextInputsProcessor
                    .process_inputs(
                        Some(tokenizer),
                        input_seqs,
                        is_prompt,
                        is_xlora,
                        device,
                        no_kv_cache,
                        last_n_context_len,
                        return_raw_logits,
                        other_config,
                        paged_attn_metadata,
                        None, // TODO
                        mapper,
                    )
                    .map(|metadata| {
                        let InputProcessorOutput {
                            inputs,
                            seq_indices,
                        } = metadata?;

                        let text_models_inputs_processor::ModelInputs {
                            input_ids,
                            input_ids_full: _,
                            seqlen_offsets,
                            seqlen_offsets_full: _,
                            context_lens,
                            position_ids,
                            paged_attn_meta,
                            flash_meta,
                            flash_meta_full: _,
                        } = *inputs
                            .downcast::<text_models_inputs_processor::ModelInputs>()
                            .expect("Downcast failed.");

                        let inputs: Box<dyn Any> = Box::new(ModelInputs {
                            input_ids,
                            seqlen_offsets,
                            context_lens,
                            position_ids,
                            pixel_values: None,
                            model_specific_args: Box::new(LLaVAVisionSpecificArgs {}),
                            paged_attn_meta,
                            flash_meta,
                        });
                        Ok(InputProcessorOutput {
                            inputs,
                            seq_indices,
                        })
                    }),
            );
        };

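        // Decode each sequence's tokens back to text so the `<image>` tags can be
        // located with the regex splitter below.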
        let mut toks = Vec::new();
        let detokenized = tokenizer
            .decode_batch(
                &input_seqs
                    .iter()
                    .map(|seq| seq.get_toks())
                    .collect::<Vec<_>>(),
                false,
            )
            .expect("Decoding failed");

        for (detokenized, (seq, num_img_tokens)) in detokenized.into_iter().zip(
            input_seqs
                .iter_mut()
                .zip(num_img_tokens.unwrap().into_iter()),
        ) {
            let splits = self
                .image_tag_splitter
                .split(&detokenized)
                .map(|span| &detokenized[span.range()])
                .collect::<Vec<_>>();
            let prompt_chunks = splits
                .iter()
                .map(|s| {
                    // We don't use encode_batch here, because encode_batch pads the
                    // shorter sequences with trailing zeros, which would throw off
                    // the image_ids_pad offsets.
                    tokenizer
                        .encode_fast(*s, false)
                        .unwrap()
                        .get_ids()
                        .to_vec()
                        .iter()
                        .map(|x| *x as i64)
                        .collect()
                })
                .collect::<Vec<Vec<_>>>();
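            // Build one placeholder run per image: `num_img_token` zeros whose first
            // element is overwritten with the negative marker -(i + 1), so the image
            // positions can be recovered later when splicing in image features.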
            let mut image_ids_pad = Vec::new();
            for (i, num_img_token) in num_img_tokens.iter().enumerate() {
                let mut image_id_pad = vec![0; *num_img_token];
                image_id_pad[0] = -(i as i64 + 1);
                image_ids_pad.push(image_id_pad);
            }
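            // Interleave text chunks with the image placeholder runs, rebuilding the
            // prompt in order: chunk 0, image 0, chunk 1, image 1, ...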
            let mut input_ids: Vec<i64> = Vec::new();
            for item in prompt_chunks
                .iter()
                .map(|x| x.to_vec())
                .interleave(image_ids_pad)
            {
                input_ids.extend(item);
            }
            // NOTE(EricLBuehler): Casting to u32 is fine, we don't care about the other toks
            seq.set_toks_and_reallocate(
                input_ids
                    .iter()
                    .map(|x| if *x < 0 { 0u32 } else { *x as u32 })
                    .collect::<Vec<_>>(),
                paged_attn_metadata.as_mut(),
            );

            toks.push(input_ids);
        }

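        // Prefill (prompt) and decode (completion) steps are batched differently,
        // so dispatch to the matching input builder.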
        let iter = if is_prompt {
            get_prompt_input(
                toks,
                input_seqs,
                device,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                None, // TODO: evaluate if it is possible to batch this
                mapper,
            )
        } else {
            get_completion_input(
                toks,
                input_seqs,
                device,
                no_kv_cache,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                None, // TODO: evaluate if it is possible to batch this
                mapper,
            )
        };

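        // Rewrap the generic text-model metadata as vision `ModelInputs`, attaching
        // the batched pixel values to every produced chunk.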
        Box::new(iter.into_iter().map(move |metadata| {
            let text_models_inputs_processor::InnerInputProcessorOutput {
                inputs:
                    text_models_inputs_processor::InputMetadata {
                        input,
                        positions,
                        context_lens,
                        position_ids,
                        paged_attn_meta,
                        flash_meta,
                    },
                seq_indices,
            } = metadata?;
            let inputs: Box<dyn Any> = Box::new(ModelInputs {
                input_ids: input,
                seqlen_offsets: positions,
                context_lens,
                position_ids,
                pixel_values: pixel_values.clone(),
                model_specific_args: Box::new(LLaVAVisionSpecificArgs {}),
                paged_attn_meta,
                flash_meta,
            });
            Ok(InputProcessorOutput {
                inputs,
                seq_indices,
            })
        }))
    }
}

impl ImagePreProcessor for LLaVAInputProcessor {
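    // OpenAI CLIP's image normalization statistics, used as the fallback when the
    // preprocessor config does not supply its own mean/std.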
    #[allow(clippy::excessive_precision)]
    const DEFAULT_MEAN: [f64; 3] = [0.48145466, 0.4578275, 0.40821073];
    #[allow(clippy::excessive_precision)]
    const DEFAULT_STD: [f64; 3] = [0.26862954, 0.26130258, 0.27577711];
    fn preprocess(
        &self,
        images: Vec<image::DynamicImage>,
        videos: Vec<Vec<image::DynamicImage>>,
        config: &preprocessor_config::PreProcessorConfig,
        device: &candle_core::Device,
        (_, _): (usize, usize),
    ) -> candle_core::Result<image_processor::PreprocessedImages> {
        if images.len() > 1 {
            candle_core::bail!("Can only process one image per batch"); // Same restriction as phi3_inputs_processor.
        };
        assert!(videos.is_empty());
        let resized_size = *config.size.as_ref().unwrap().get("shortest_edge").unwrap() as usize;

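        // Pad the image to a square, filling with the per-channel mean color (as in
        // the reference LLaVA preprocessing) so the padding normalizes to roughly zero.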
        let original_size = images[0].dimensions();
        let filter = config.resampling.to_filter()?;
        let image_mean = config
            .image_mean
            .unwrap_or(Self::DEFAULT_MEAN)
            .map(|x| x as f32);
        let mean_color = image_mean
            .iter()
            .map(|x| ((*x) * 255.0) as u8)
            .collect::<Vec<u8>>();
        let mean_color = Rgb::from([mean_color[0], mean_color[1], mean_color[2]]);
        let image = expand2square(&images[0], mean_color);
        let image_std = config
            .image_std
            .unwrap_or(Self::DEFAULT_STD)
            .map(|x| x as f32);
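        // Resize with `filter` to `resized_size` and normalize with mean/std into a
        // BF16 tensor on the target device.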
        let pixel_values = [image]
            .iter()
            .map(|x| {
                LLaVAImageProcessor::process_one_image(
                    x,
                    config,
                    resized_size as u32,
                    filter,
                    DType::BF16,
                    device,
                    &image_mean,
                    &image_std,
                )
            })
            .collect::<Result<Vec<Tensor>>>()?;
        let pixel_values = Tensor::stack(&pixel_values, 0)?;

        Ok(image_processor::PreprocessedImages {
            pixel_values,
            pixel_attention_mask: None,
            image_sizes: Some((original_size.0 as usize, original_size.1 as usize)),
            num_img_tokens: Some(vec![LLaVAInputProcessor::get_num_image_tokens(
                &self.model_config,
            )]),
            aspect_ratio_ids: None,
            aspect_ratio_mask: None,
            num_tiles: None,
            image_grid_thw: None,
            video_grid_thw: None,
            rows: None,
            cols: None,
            pixel_values_list: None,
            tgt_sizes: None,
            image_sizes_all: None,
            num_crops: None,
        })
    }
}