mistralrs_core/vision_models/llava/llava_inputs_processor.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
use std::any::Any;
use std::sync::Arc;

use candle_core::Result;
use candle_core::{DType, Device, Tensor};
use image::GenericImageView;
use image::Rgb;
use itertools::Itertools;
use regex_automata::meta::Regex;
use tokenizers::Tokenizer;

use super::llava15::LLaVAVisionSpecificArgs;
use super::utils::{expand2square, LLaVAImageProcessor};
use crate::device_map::DeviceMapper;
use crate::pipeline::text_models_inputs_processor::{
    get_completion_input, get_prompt_input, PagedAttentionMeta,
};
use crate::pipeline::{
    text_models_inputs_processor, InputProcessorOutput, InputsProcessor, InputsProcessorType,
    MessagesAction, Processor,
};
use crate::sequence::Sequence;
use crate::vision_models::image_processor::{self, ImagePreProcessor, PreprocessedImages};
use crate::vision_models::llava::config::Config as LLaVAConfig;
use crate::vision_models::preprocessor_config::{PreProcessorConfig, ToFilter};
use crate::vision_models::{preprocessor_config, ModelInputs};

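/// Top-level `Processor` for LLaVA: it just hands out the shared
/// `LLaVAInputProcessor`, which does the actual prompt and image preparation.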
pub struct LLaVAProcessor {
    inputs_processor: Arc<LLaVAInputProcessor>,
}

impl Processor for LLaVAProcessor {
    fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
        self.inputs_processor.clone()
    }
    fn get_special_tokens(&self) -> &[&'static str] {
        &[]
    }
    fn template_action(&self) -> MessagesAction {
        MessagesAction::FlattenOnlyText
    }
}

impl LLaVAProcessor {
    pub fn new(config: &str) -> Self {
        let model_config =
            serde_json::from_str::<LLaVAConfig>(config).expect("Failed to parse model config.");
        let image_tag_splitter = Regex::new(r"<image>").expect("Failed to compile split regex.");
        let inputs_processor = Arc::new(LLaVAInputProcessor {
            image_tag_splitter,
            model_config: model_config.clone(),
        });
        Self { inputs_processor }
    }
}

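/// Inputs processor for LLaVA. `image_tag_splitter` matches the literal
/// `<image>` placeholder in prompts, and `model_config` supplies the vision
/// tower dimensions used to size the image token blocks.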
pub struct LLaVAInputProcessor {
    image_tag_splitter: Regex,
    model_config: LLaVAConfig,
}

impl LLaVAInputProcessor {
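    /// Number of placeholder tokens a single image expands to: one token per
    /// vision patch, i.e. (image_size / patch_size)^2. For example, with the
    /// typical LLaVA 1.5 CLIP settings (image_size = 336, patch_size = 14)
    /// this is 24 * 24 = 576 tokens.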
    pub fn get_num_image_tokens(cfg: &LLaVAConfig) -> usize {
        let patch_size = cfg.vision_config.patch_size;
        let patch_per_side = cfg.vision_config.image_size / patch_size;
        patch_per_side * patch_per_side
    }
}

// Copied from phi3_inputs_processor. Differences: (1) the calculation of num_image_tokens, (2) process_anyres_image, (3) image_ids_pad.
impl InputsProcessor for LLaVAInputProcessor {
    fn get_type(&self) -> InputsProcessorType {
        InputsProcessorType::Vision
    }
    fn process_inputs(
        &self,
        tokenizer: Option<Arc<Tokenizer>>,
        input_seqs: &mut [&mut Sequence],
        is_prompt: bool,
        is_xlora: bool,
        device: &Device,
        no_kv_cache: bool,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        other_config: Option<Arc<dyn Any>>,
        mut paged_attn_metadata: Option<PagedAttentionMeta>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> anyhow::Result<InputProcessorOutput> {
        if is_xlora {
            return Err(anyhow::Error::msg(
                "Cannot make inputs for X-LoRA vision model.",
            ));
        }
        if no_kv_cache {
            return Err(anyhow::Error::msg("Vision model must have kv cache."));
        }
        let Some(tokenizer) = tokenizer else {
            return Err(anyhow::Error::msg(
                "LLaVAInputProcessor requires a specified tokenizer.",
            ));
        };

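        // `other_config` carries the image pre-processor settings as `dyn Any`;
        // recover the concrete `PreProcessorConfig` before using it.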
        let config = other_config
            .clone()
            .expect("Need a PreProcessorConfig config.");
        let config: &PreProcessorConfig = config.downcast_ref().expect("Downcast failed.");

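        // Take the vision path only when every sequence in the batch still has
        // images attached; otherwise fall back to the plain text processor below.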
        let has_images = input_seqs.iter().all(|seq| seq.has_images());

        let (pixel_values, num_img_tokens) = if has_images {
            let mut pixel_values_accum = Vec::new();
            let mut num_img_tokens_accum = Vec::new();
            for seq in input_seqs.iter_mut() {
                let imgs = seq
                    .take_images()
                    .expect("Need to have images by this point.");
                let PreprocessedImages {
                    pixel_values,
                    pixel_attention_mask: _,
                    image_sizes: _,
                    num_img_tokens,
                    aspect_ratio_ids: _,
                    aspect_ratio_mask: _,
                    num_tiles: _,
                    image_grid_thw: _,
                    video_grid_thw: _,
                    rows: _,
                    cols: _,
                    pixel_values_list: _,
                    tgt_sizes: _,
                    image_sizes_all: _,
                    num_crops: _,
                } = self
                    .preprocess(
                        imgs.clone(),
                        vec![],
                        config,
                        device,
                        (usize::MAX, usize::MAX),
                    )
                    .expect("Preprocessor failed");
                pixel_values_accum.push(pixel_values);
                num_img_tokens_accum.push(num_img_tokens.unwrap());
            }
            (
                Some(Tensor::cat(&pixel_values_accum, 0).unwrap()),
                Some(num_img_tokens_accum),
            )
        } else {
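            // No images in this batch: delegate to the text-only processor and just
            // repackage its output as vision `ModelInputs` without pixel values.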
            return text_models_inputs_processor::TextInputsProcessor
                .process_inputs(
                    Some(tokenizer),
                    input_seqs,
                    is_prompt,
                    is_xlora,
                    device,
                    no_kv_cache,
                    last_n_context_len,
                    return_raw_logits,
                    other_config,
                    paged_attn_metadata,
                    mapper,
                )
                .map(|metadata| {
                    let InputProcessorOutput {
                        inputs,
                        seq_indices,
                    } = metadata;

                    let text_models_inputs_processor::ModelInputs {
                        input_ids,
                        input_ids_full: _,
                        seqlen_offsets,
                        seqlen_offsets_full: _,
                        context_lens,
                        position_ids,
                        paged_attn_meta,
                        flash_meta,
                        flash_meta_full: _,
                    } = *inputs
                        .downcast::<text_models_inputs_processor::ModelInputs>()
                        .expect("Downcast failed.");

                    let inputs: Box<dyn Any> = Box::new(ModelInputs {
                        input_ids,
                        seqlen_offsets,
                        context_lens,
                        position_ids,
                        pixel_values: None,
                        model_specific_args: Box::new(LLaVAVisionSpecificArgs {}),
                        paged_attn_meta,
                        flash_meta,
                    });
                    InputProcessorOutput {
                        inputs,
                        seq_indices,
                    }
                });
        };

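        // Rebuild each prompt so that every `<image>` tag becomes a run of image
        // placeholder ids: decode the current tokens back to text, split on the
        // image tag, re-encode each text chunk, and interleave the chunks with
        // per-image padding blocks.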
        let mut toks = Vec::new();
        let detokenized = tokenizer
            .decode_batch(
                &input_seqs
                    .iter()
                    .map(|seq| seq.get_toks())
                    .collect::<Vec<_>>(),
                false,
            )
            .expect("Decoding failed");

        for (detokenized, (seq, num_img_tokens)) in detokenized.into_iter().zip(
            input_seqs
                .iter_mut()
                .zip(num_img_tokens.unwrap().into_iter()),
        ) {
            let splits = self
                .image_tag_splitter
                .split(&detokenized)
                .map(|span| &detokenized[span.range()])
                .collect::<Vec<_>>();
            let prompt_chunks = splits
                .iter()
                .map(|s| {
                    // We don't use encode_batch here because it pads the shorter
                    // sequences with 0s, which would make image_ids_pad wrong.
                    tokenizer
                        .encode_fast(*s, false)
                        .unwrap()
                        .get_ids()
                        .to_vec()
                        .iter()
                        .map(|x| *x as i64)
                        .collect()
                })
                .collect::<Vec<Vec<_>>>();
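            // One padding block per image, sized to the number of image tokens it
            // will occupy. The leading element is a negative marker (-1 for the
            // first image, -2 for the second, ...) distinguishing image slots from
            // text so the image features can later be placed correctly; the
            // remaining slots are zeros.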
            let mut image_ids_pad = Vec::new();
            for (i, num_img_token) in num_img_tokens.iter().enumerate() {
                let mut image_id_pad = vec![0; *num_img_token];
                image_id_pad[0] = -(i as i64 + 1);
                image_ids_pad.push(image_id_pad);
            }
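            // Interleave text chunks with the image blocks:
            // chunk 0, image 0, chunk 1, image 1, ..., final text chunk.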
            let mut input_ids: Vec<i64> = Vec::new();
            for item in prompt_chunks
                .iter()
                .map(|x| x.to_vec())
                .interleave(image_ids_pad)
            {
                input_ids.extend(item);
            }
            let new_ids = input_ids
                .iter()
                .map(|x| if *x < 0 { 0u32 } else { *x as u32 })
                .collect::<Vec<_>>();
            if !seq.multimodal.has_changed_prompt {
                let new_prompt = tokenizer.decode(&new_ids, false).unwrap();
                seq.set_initial_prompt(new_prompt);
                // NOTE(EricLBuehler): Casting to u32 is fine, we don't care about the other toks
                seq.set_toks_and_reallocate(new_ids, paged_attn_metadata.as_mut());
                seq.multimodal.has_changed_prompt = true;
            }

            toks.push(input_ids);
        }

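        // Build the usual positional metadata (prompt vs. completion phase) from
        // the expanded token ids, then wrap everything into vision `ModelInputs`.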
        let metadata = if is_prompt {
            get_prompt_input(
                toks.iter().map(Vec::as_slice).collect(),
                input_seqs,
                device,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                mapper,
            )
        } else {
            get_completion_input(
                toks.iter().map(Vec::as_slice).collect(),
                input_seqs,
                device,
                no_kv_cache,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                mapper,
            )
        };

        metadata.map(|metadata| {
            let text_models_inputs_processor::InnerInputProcessorOutput {
                inputs:
                    text_models_inputs_processor::InputMetadata {
                        input,
                        positions,
                        context_lens,
                        position_ids,
                        paged_attn_meta,
                        flash_meta,
                    },
                seq_indices,
            } = metadata;
            let inputs: Box<dyn Any> = Box::new(ModelInputs {
                input_ids: input,
                seqlen_offsets: positions,
                context_lens,
                position_ids,
                pixel_values: pixel_values.clone(),
                model_specific_args: Box::new(LLaVAVisionSpecificArgs {}),
                paged_attn_meta,
                flash_meta,
            });
            InputProcessorOutput {
                inputs,
                seq_indices,
            }
        })
    }
}

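// CLIP-style image preprocessing: pad the image to a square with the mean
// color, resize to the configured `shortest_edge`, and normalize with the
// CLIP mean/std defaults.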
impl ImagePreProcessor for LLaVAInputProcessor {
    #[allow(clippy::excessive_precision)]
    const DEFAULT_MEAN: [f64; 3] = [0.48145466, 0.4578275, 0.40821073];
    #[allow(clippy::excessive_precision)]
    const DEFAULT_STD: [f64; 3] = [0.26862954, 0.26130258, 0.27577711];
    fn preprocess(
        &self,
        images: Vec<image::DynamicImage>,
        videos: Vec<Vec<image::DynamicImage>>,
        config: &preprocessor_config::PreProcessorConfig,
        device: &candle_core::Device,
        (_, _): (usize, usize),
    ) -> candle_core::Result<image_processor::PreprocessedImages> {
        if images.len() > 1 {
            candle_core::bail!("Can only process one image per batch"); // This is no different from phi3_input_processor
        };
        assert!(videos.is_empty());
        let resized_size = *config.size.as_ref().unwrap().get("shortest_edge").unwrap() as usize;

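        // Derive the pad color from the per-channel mean (scaled back to 0..=255),
        // pad the image to a square with it, then resize and normalize in
        // `process_one_image`.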
        let original_size = images[0].dimensions();
        let filter = config.resampling.to_filter()?;
        let image_mean = config
            .image_mean
            .unwrap_or(Self::DEFAULT_MEAN)
            .map(|x| x as f32);
        let mean_color = image_mean
            .iter()
            .map(|x| ((*x) * 255.0) as u8)
            .collect::<Vec<u8>>();
        let mean_color = Rgb::from([mean_color[0], mean_color[1], mean_color[2]]);
        let image = expand2square(&images[0], mean_color);
        let image_std = config
            .image_std
            .unwrap_or(Self::DEFAULT_STD)
            .map(|x| x as f32);
        let pixel_values = [image]
            .iter()
            .map(|x| {
                LLaVAImageProcessor::process_one_image(
                    x,
                    config,
                    resized_size as u32,
                    filter,
                    DType::BF16,
                    device,
                    &image_mean,
                    &image_std,
                )
            })
            .collect::<Result<Vec<Tensor>>>()?;
        let pixel_values = Tensor::stack(&pixel_values, 0)?;

373            pixel_values,
374            pixel_attention_mask: None,
375            image_sizes: Some((original_size.0 as usize, original_size.1 as usize)),
376            num_img_tokens: Some(vec![LLaVAInputProcessor::get_num_image_tokens(
377                &self.model_config,
378            )]),
379            aspect_ratio_ids: None,
380            aspect_ratio_mask: None,
381            num_tiles: None,
382            image_grid_thw: None,
383            video_grid_thw: None,
384            rows: None,
385            cols: None,
386            pixel_values_list: None,
387            tgt_sizes: None,
388            image_sizes_all: None,
389            num_crops: None,
390        })
391    }
392}