#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
use std::any::Any;
use std::sync::Arc;

use candle_core::Result;
use candle_core::{DType, Device, Tensor};
use image::GenericImageView;
use image::Rgb;
use itertools::Itertools;
use regex_automata::meta::Regex;
use tokenizers::Tokenizer;

use super::llava15::LLaVAVisionSpecificArgs;
use super::utils::{expand2square, LLaVAImageProcessor};
use crate::device_map::DeviceMapper;
use crate::pipeline::text_models_inputs_processor::{
    get_completion_input, get_prompt_input, PagedAttentionMeta,
};
use crate::pipeline::{
    text_models_inputs_processor, InputProcessorOutput, InputsProcessor, InputsProcessorType,
    MessagesAction, Processor,
};
use crate::sequence::Sequence;
use crate::vision_models::image_processor::{self, ImagePreProcessor, PreprocessedImages};
use crate::vision_models::llava::config::Config as LLaVAConfig;
use crate::vision_models::preprocessor_config::{PreProcessorConfig, ToFilter};
use crate::vision_models::{preprocessor_config, ModelInputs};

pub struct LLaVAProcessor {
    inputs_processor: Arc<LLaVAInputProcessor>,
}

impl Processor for LLaVAProcessor {
    fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
        self.inputs_processor.clone()
    }
    fn get_special_tokens(&self) -> &[&'static str] {
        &[]
    }
    fn template_action(&self) -> MessagesAction {
        MessagesAction::FlattenOnlyText
    }
}

impl LLaVAProcessor {
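    /// Parse the LLaVA model config from its JSON string and build the inner input
    /// processor, pre-compiling the regex that splits prompts on `<image>` tags.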
    pub fn new(config: &str) -> Self {
        let model_config =
            serde_json::from_str::<LLaVAConfig>(config).expect("Failed to parse model config.");
        let image_tag_splitter = Regex::new(r"<image>").expect("Failed to compile split regex.");
        let inputs_processor = Arc::new(LLaVAInputProcessor {
            image_tag_splitter,
            model_config: model_config.clone(),
        });
        Self { inputs_processor }
    }
}

pub struct LLaVAInputProcessor {
    image_tag_splitter: Regex,
    model_config: LLaVAConfig,
}

impl LLaVAInputProcessor {
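    /// Number of vision tokens contributed by one image: the vision encoder yields
    /// `image_size / patch_size` patches per side, so the count is that value squared.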
    pub fn get_num_image_tokens(cfg: &LLaVAConfig) -> usize {
        let patch_size = cfg.vision_config.patch_size;
        let patch_per_side = cfg.vision_config.image_size / patch_size;
        patch_per_side * patch_per_side
    }
}

impl InputsProcessor for LLaVAInputProcessor {
    fn get_type(&self) -> InputsProcessorType {
        InputsProcessorType::Vision
    }
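    /// Build LLaVA `ModelInputs`: preprocess any attached images, expand each
    /// `<image>` tag in the prompt into a run of image-token placeholders, and
    /// re-tokenize the sequences before building prompt or completion inputs.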
    fn process_inputs(
        &self,
        tokenizer: Option<Arc<Tokenizer>>,
        input_seqs: &mut [&mut Sequence],
        is_prompt: bool,
        is_xlora: bool,
        device: &Device,
        no_kv_cache: bool,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        other_config: Option<Arc<dyn Any>>,
        mut paged_attn_metadata: Option<PagedAttentionMeta>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> anyhow::Result<InputProcessorOutput> {
        if is_xlora {
            return Err(anyhow::Error::msg(
                "Cannot make inputs for X-LoRA vision model.",
            ));
        }
        if no_kv_cache {
            return Err(anyhow::Error::msg("Vision model must have kv cache."));
        }
        let Some(tokenizer) = tokenizer else {
            return Err(anyhow::Error::msg(
                "LLaVAInputProcessor requires a specified tokenizer.",
            ));
        };

        let config = other_config
            .clone()
            .expect("Need a PreProcessorConfig config.");
        let config: &PreProcessorConfig = config.downcast_ref().expect("Downcast failed.");

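        // Only take the vision path when every sequence still has images attached;
        // otherwise defer to the plain text-model input processor below.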
        let has_images = input_seqs.iter().all(|seq| seq.has_images());

        let (pixel_values, num_img_tokens) = if has_images {
            let mut pixel_values_accum = Vec::new();
            let mut num_img_tokens_accum = Vec::new();
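            // Preprocess every sequence's images, collecting pixel tensors and
            // per-image token counts so they can be batched afterwards.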
            for seq in input_seqs.iter_mut() {
                let imgs = seq
                    .take_images()
                    .expect("Need to have images by this point.");
                let PreprocessedImages {
                    pixel_values,
                    pixel_attention_mask: _,
                    image_sizes: _,
                    num_img_tokens,
                    aspect_ratio_ids: _,
                    aspect_ratio_mask: _,
                    num_tiles: _,
                    image_grid_thw: _,
                    video_grid_thw: _,
                    rows: _,
                    cols: _,
                    pixel_values_list: _,
                    tgt_sizes: _,
                    image_sizes_all: _,
                    num_crops: _,
                } = self
                    .preprocess(
                        imgs.clone(),
                        vec![],
                        config,
                        device,
                        (usize::MAX, usize::MAX),
                    )
                    .expect("Preprocessor failed");
                pixel_values_accum.push(pixel_values);
                num_img_tokens_accum.push(num_img_tokens.unwrap());
            }
            (
                Some(Tensor::cat(&pixel_values_accum, 0).unwrap()),
                Some(num_img_tokens_accum),
            )
        } else {
            return text_models_inputs_processor::TextInputsProcessor
                .process_inputs(
                    Some(tokenizer),
                    input_seqs,
                    is_prompt,
                    is_xlora,
                    device,
                    no_kv_cache,
                    last_n_context_len,
                    return_raw_logits,
                    other_config,
                    paged_attn_metadata,
                    mapper,
                )
                .map(|metadata| {
                    let InputProcessorOutput {
                        inputs,
                        seq_indices,
                    } = metadata;

                    let text_models_inputs_processor::ModelInputs {
                        input_ids,
                        input_ids_full: _,
                        seqlen_offsets,
                        seqlen_offsets_full: _,
                        context_lens,
                        position_ids,
                        paged_attn_meta,
                        flash_meta,
                        flash_meta_full: _,
                    } = *inputs
                        .downcast::<text_models_inputs_processor::ModelInputs>()
                        .expect("Downcast failed.");

                    let inputs: Box<dyn Any> = Box::new(ModelInputs {
                        input_ids,
                        seqlen_offsets,
                        context_lens,
                        position_ids,
                        pixel_values: None,
                        model_specific_args: Box::new(LLaVAVisionSpecificArgs {}),
                        paged_attn_meta,
                        flash_meta,
                    });
                    InputProcessorOutput {
                        inputs,
                        seq_indices,
                    }
                });
        };

        let mut toks = Vec::new();
        let detokenized = tokenizer
            .decode_batch(
                &input_seqs
                    .iter()
                    .map(|seq| seq.get_toks())
                    .collect::<Vec<_>>(),
                false,
            )
            .expect("Decoding failed");

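        // Rebuild each prompt: split the detokenized text on `<image>` tags and
        // interleave the re-tokenized text chunks with per-image placeholder runs
        // (a negative image marker followed by zeros, one slot per image token),
        // then write the expanded ids back onto the sequence.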
        for (detokenized, (seq, num_img_tokens)) in detokenized.into_iter().zip(
            input_seqs
                .iter_mut()
                .zip(num_img_tokens.unwrap().into_iter()),
        ) {
            let splits = self
                .image_tag_splitter
                .split(&detokenized)
                .map(|span| &detokenized[span.range()])
                .collect::<Vec<_>>();
            let prompt_chunks = splits
                .iter()
                .map(|s| {
                    tokenizer
                        .encode_fast(*s, false)
                        .unwrap()
                        .get_ids()
                        .to_vec()
                        .iter()
                        .map(|x| *x as i64)
                        .collect()
                })
                .collect::<Vec<Vec<_>>>();
            let mut image_ids_pad = Vec::new();
            for (i, num_img_token) in num_img_tokens.iter().enumerate() {
                let mut image_id_pad = vec![0; *num_img_token];
                image_id_pad[0] = -(i as i64 + 1);
                image_ids_pad.push(image_id_pad);
            }
            let mut input_ids: Vec<i64> = Vec::new();
            for item in prompt_chunks
                .iter()
                .map(|x| x.to_vec())
                .interleave(image_ids_pad)
            {
                input_ids.extend(item);
            }
            let new_ids = input_ids
                .iter()
                .map(|x| if *x < 0 { 0u32 } else { *x as u32 })
                .collect::<Vec<_>>();
            if !seq.multimodal.has_changed_prompt {
                let new_prompt = tokenizer.decode(&new_ids, false).unwrap();
                seq.set_initial_prompt(new_prompt);
                seq.set_toks_and_reallocate(new_ids, paged_attn_metadata.as_mut());
                seq.multimodal.has_changed_prompt = true;
            }

            toks.push(input_ids);
        }

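        // Build prompt-phase or completion-phase inputs from the expanded token ids,
        // then attach the batched pixel values and the LLaVA-specific args.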
        let metadata = if is_prompt {
            get_prompt_input(
                toks.iter().map(Vec::as_slice).collect(),
                input_seqs,
                device,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                mapper,
            )
        } else {
            get_completion_input(
                toks.iter().map(Vec::as_slice).collect(),
                input_seqs,
                device,
                no_kv_cache,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                mapper,
            )
        };

        metadata.map(|metadata| {
            let text_models_inputs_processor::InnerInputProcessorOutput {
                inputs:
                    text_models_inputs_processor::InputMetadata {
                        input,
                        positions,
                        context_lens,
                        position_ids,
                        paged_attn_meta,
                        flash_meta,
                    },
                seq_indices,
            } = metadata;
            let inputs: Box<dyn Any> = Box::new(ModelInputs {
                input_ids: input,
                seqlen_offsets: positions,
                context_lens,
                position_ids,
                pixel_values: pixel_values.clone(),
                model_specific_args: Box::new(LLaVAVisionSpecificArgs {}),
                paged_attn_meta,
                flash_meta,
            });
            InputProcessorOutput {
                inputs,
                seq_indices,
            }
        })
    }
}

impl ImagePreProcessor for LLaVAInputProcessor {
    #[allow(clippy::excessive_precision)]
    const DEFAULT_MEAN: [f64; 3] = [0.48145466, 0.4578275, 0.40821073];
    #[allow(clippy::excessive_precision)]
    const DEFAULT_STD: [f64; 3] = [0.26862954, 0.26130258, 0.27577711];
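    /// Pad the single input image to a square using the per-channel mean color,
    /// resize it according to the configured `shortest_edge`, normalize with the
    /// configured (or default CLIP) mean/std, and return the stacked pixel tensor
    /// along with the fixed number of image tokens for this model.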
    fn preprocess(
        &self,
        images: Vec<image::DynamicImage>,
        videos: Vec<Vec<image::DynamicImage>>,
        config: &preprocessor_config::PreProcessorConfig,
        device: &candle_core::Device,
        (_, _): (usize, usize),
    ) -> candle_core::Result<image_processor::PreprocessedImages> {
        if images.len() > 1 {
            candle_core::bail!("Can only process one image per batch");
        }
        assert!(videos.is_empty());
        let resized_size = *config.size.as_ref().unwrap().get("shortest_edge").unwrap() as usize;

        let original_size = images[0].dimensions();
        let filter = config.resampling.to_filter()?;
        let image_mean = config
            .image_mean
            .unwrap_or(Self::DEFAULT_MEAN)
            .map(|x| x as f32);
        let mean_color = image_mean
            .iter()
            .map(|x| ((*x) * 255.0) as u8)
            .collect::<Vec<u8>>();
        let mean_color = Rgb::from([mean_color[0], mean_color[1], mean_color[2]]);
        let image = expand2square(&images[0], mean_color);
        let image_std = config
            .image_std
            .unwrap_or(Self::DEFAULT_STD)
            .map(|x| x as f32);
        let pixel_values = [image]
            .iter()
            .map(|x| {
                LLaVAImageProcessor::process_one_image(
                    x,
                    config,
                    resized_size as u32,
                    filter,
                    DType::BF16,
                    device,
                    &image_mean,
                    &image_std,
                )
            })
            .collect::<Result<Vec<Tensor>>>()?;
        let pixel_values = Tensor::stack(&pixel_values, 0)?;

        Ok(image_processor::PreprocessedImages {
            pixel_values,
            pixel_attention_mask: None,
            image_sizes: Some((original_size.0 as usize, original_size.1 as usize)),
            num_img_tokens: Some(vec![LLaVAInputProcessor::get_num_image_tokens(
                &self.model_config,
            )]),
            aspect_ratio_ids: None,
            aspect_ratio_mask: None,
            num_tiles: None,
            image_grid_thw: None,
            video_grid_thw: None,
            rows: None,
            cols: None,
            pixel_values_list: None,
            tgt_sizes: None,
            image_sizes_all: None,
            num_crops: None,
        })
    }
}