#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
use std::any::Any;
use std::num::NonZeroUsize;
use std::sync::Arc;

use candle_core::Result;
use candle_core::{DType, Device, Tensor};
use image::GenericImageView;
use image::Rgb;
use itertools::Itertools;
use regex_automata::meta::Regex;
use tokenizers::Tokenizer;
use tracing::warn;

use super::llava15::LLaVAVisionSpecificArgs;
use super::utils::{expand2square, LLaVAImageProcessor};
use crate::device_map::DeviceMapper;
use crate::pipeline::text_models_inputs_processor::{
    get_completion_input, get_prompt_input, PagedAttentionMeta,
};
use crate::pipeline::{
    text_models_inputs_processor, InputProcessorOutput, InputsProcessor, InputsProcessorType,
    MessagesAction, Processor,
};
use crate::sequence::Sequence;
use crate::vision_models::image_processor::{self, ImagePreProcessor, PreprocessedImages};
use crate::vision_models::llava::config::Config as LLaVAConfig;
use crate::vision_models::preprocessor_config::{PreProcessorConfig, ToFilter};
use crate::vision_models::{preprocessor_config, ModelInputs};

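/// [`Processor`] for LLaVA vision models: it hands out the shared
/// [`LLaVAInputProcessor`] and flattens chat messages to text-only content.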
pub struct LLaVAProcessor {
    inputs_processor: Arc<LLaVAInputProcessor>,
}

impl Processor for LLaVAProcessor {
    fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
        self.inputs_processor.clone()
    }
    fn get_special_tokens(&self) -> &[&'static str] {
        &[]
    }
    fn template_action(&self) -> MessagesAction {
        MessagesAction::FlattenOnlyText
    }
}

impl LLaVAProcessor {
    pub fn new(config: &str) -> Self {
        let model_config =
            serde_json::from_str::<LLaVAConfig>(config).expect("Failed to parse model config.");
        let image_tag_splitter = Regex::new(r"<image>").expect("Failed to compile split regex.");
        let inputs_processor = Arc::new(LLaVAInputProcessor {
            image_tag_splitter,
            model_config: model_config.clone(),
        });
        Self { inputs_processor }
    }
}

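/// Builds model inputs for LLaVA: each prompt is split on the `<image>` tag and
/// every tag is replaced by a run of per-patch placeholder token ids, reserving
/// the slots into which the projected image features are later spliced.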
pub struct LLaVAInputProcessor {
    image_tag_splitter: Regex,
    model_config: LLaVAConfig,
}

impl LLaVAInputProcessor {
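    /// Number of placeholder tokens one image expands to: one token per ViT patch,
    /// i.e. `(image_size / patch_size)^2`. With the usual LLaVA 1.5 vision settings
    /// (image_size = 336, patch_size = 14) this is 24 * 24 = 576 tokens.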
    pub fn get_num_image_tokens(cfg: &LLaVAConfig) -> usize {
        let patch_size = cfg.vision_config.patch_size;
        let patch_per_side = cfg.vision_config.image_size / patch_size;
        patch_per_side * patch_per_side
    }
}

impl InputsProcessor for LLaVAInputProcessor {
    fn get_type(&self) -> InputsProcessorType {
        InputsProcessorType::Vision
    }
    fn process_inputs(
        &self,
        tokenizer: Option<Arc<Tokenizer>>,
        input_seqs: &mut [&mut Sequence],
        is_prompt: bool,
        is_xlora: bool,
        device: &Device,
        no_kv_cache: bool,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        other_config: Option<Arc<dyn Any>>,
        mut paged_attn_metadata: Option<PagedAttentionMeta<'_>>,
        prompt_chunksize: Option<NonZeroUsize>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Box<dyn Iterator<Item = anyhow::Result<InputProcessorOutput>>> {
        if is_xlora {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Cannot make inputs for X-LoRA vision model.",
            ))));
        }
        if no_kv_cache {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Vision model must have kv cache.",
            ))));
        }
        if prompt_chunksize.is_some() {
            warn!("`prompt_chunksize` is set. LLaVA does not support prompt batching.");
        }
        let Some(tokenizer) = tokenizer else {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "LLaVAInputProcessor requires a specified tokenizer.",
            ))));
        };

        let config = other_config
            .clone()
            .expect("Need a PreProcessorConfig config.");
        let config: &PreProcessorConfig = config.downcast_ref().expect("Downcast failed.");

        let has_images = input_seqs.iter().all(|seq| seq.has_images());

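        // When every sequence carries images, preprocess them and record both the pixel
        // values and the number of placeholder tokens each image expands to; otherwise
        // defer entirely to the plain text inputs processor below.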
        let (pixel_values, num_img_tokens) = if has_images {
            let mut pixel_values_accum = Vec::new();
            let mut num_img_tokens_accum = Vec::new();
            for seq in input_seqs.iter_mut() {
                let imgs = seq
                    .take_images()
                    .expect("Need to have images by this point.");
                let PreprocessedImages {
                    pixel_values,
                    pixel_attention_mask: _,
                    image_sizes: _,
                    num_img_tokens,
                    aspect_ratio_ids: _,
                    aspect_ratio_mask: _,
                    num_tiles: _,
                    image_grid_thw: _,
                    video_grid_thw: _,
                    rows: _,
                    cols: _,
                    pixel_values_list: _,
                    tgt_sizes: _,
                    image_sizes_all: _,
                    num_crops: _,
                } = self
                    .preprocess(
                        imgs.clone(),
                        vec![],
                        config,
                        device,
                        (usize::MAX, usize::MAX),
                    )
                    .expect("Preprocessor failed");
                pixel_values_accum.push(pixel_values);
                num_img_tokens_accum.push(num_img_tokens.unwrap());
            }
            (
                Some(Tensor::cat(&pixel_values_accum, 0).unwrap()),
                Some(num_img_tokens_accum),
            )
        } else {
            return Box::new(
                text_models_inputs_processor::TextInputsProcessor
                    .process_inputs(
                        Some(tokenizer),
                        input_seqs,
                        is_prompt,
                        is_xlora,
                        device,
                        no_kv_cache,
                        last_n_context_len,
                        return_raw_logits,
                        other_config,
                        paged_attn_metadata,
                        None,
                        mapper,
                    )
                    .map(|metadata| {
                        let InputProcessorOutput {
                            inputs,
                            seq_indices,
                        } = metadata?;

                        let text_models_inputs_processor::ModelInputs {
                            input_ids,
                            input_ids_full: _,
                            seqlen_offsets,
                            seqlen_offsets_full: _,
                            context_lens,
                            position_ids,
                            paged_attn_meta,
                            flash_meta,
                            flash_meta_full: _,
                        } = *inputs
                            .downcast::<text_models_inputs_processor::ModelInputs>()
                            .expect("Downcast failed.");

                        let inputs: Box<dyn Any> = Box::new(ModelInputs {
                            input_ids,
                            seqlen_offsets,
                            context_lens,
                            position_ids,
                            pixel_values: None,
                            model_specific_args: Box::new(LLaVAVisionSpecificArgs {}),
                            paged_attn_meta,
                            flash_meta,
                        });
                        Ok(InputProcessorOutput {
                            inputs,
                            seq_indices,
                        })
                    }),
            );
        };

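        // Rebuild each prompt: decode the current tokens, split the text on the `<image>`
        // tag, re-encode the text chunks, and interleave them with per-image runs of
        // placeholder ids sized by the counts returned from preprocessing.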
        let mut toks = Vec::new();
        let detokenized = tokenizer
            .decode_batch(
                &input_seqs
                    .iter()
                    .map(|seq| seq.get_toks())
                    .collect::<Vec<_>>(),
                false,
            )
            .expect("Decoding failed");

        for (detokenized, (seq, num_img_tokens)) in detokenized.into_iter().zip(
            input_seqs
                .iter_mut()
                .zip(num_img_tokens.unwrap().into_iter()),
        ) {
            let splits = self
                .image_tag_splitter
                .split(&detokenized)
                .map(|span| &detokenized[span.range()])
                .collect::<Vec<_>>();
            let prompt_chunks = splits
                .iter()
                .map(|s| {
                    tokenizer
                        .encode_fast(*s, false)
                        .unwrap()
                        .get_ids()
                        .to_vec()
                        .iter()
                        .map(|x| *x as i64)
                        .collect()
                })
                .collect::<Vec<Vec<_>>>();
            // One run of placeholder ids per image; the first slot carries a negative
            // marker (-1, -2, ...) identifying which image fills that run.
            let mut image_ids_pad = Vec::new();
            for (i, num_img_token) in num_img_tokens.iter().enumerate() {
                let mut image_id_pad = vec![0; *num_img_token];
                image_id_pad[0] = -(i as i64 + 1);
                image_ids_pad.push(image_id_pad);
            }
            let mut input_ids: Vec<i64> = Vec::new();
            for item in prompt_chunks
                .iter()
                .map(|x| x.to_vec())
                .interleave(image_ids_pad)
            {
                input_ids.extend(item);
            }
            // The sequence itself stores only valid vocabulary ids, so clamp the negative
            // image markers to 0 before writing the tokens back.
            let new_ids = input_ids
                .iter()
                .map(|x| if *x < 0 { 0u32 } else { *x as u32 })
                .collect::<Vec<_>>();
            if !seq.has_changed_prompt {
                let new_prompt = tokenizer.decode(&new_ids, false).unwrap();
                seq.set_initial_prompt(new_prompt);
                seq.set_toks_and_reallocate(new_ids, paged_attn_metadata.as_mut());
                seq.has_changed_prompt = true;
            }

            toks.push(input_ids);
        }

        let iter = if is_prompt {
            get_prompt_input(
                toks,
                input_seqs,
                device,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                None,
                mapper,
            )
        } else {
            get_completion_input(
                toks,
                input_seqs,
                device,
                no_kv_cache,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                None,
                mapper,
            )
        };

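        // Attach the shared pixel values to every chunk of text inputs and wrap the
        // result in the vision `ModelInputs` the LLaVA forward pass expects.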
        Box::new(iter.into_iter().map(move |metadata| {
            let text_models_inputs_processor::InnerInputProcessorOutput {
                inputs:
                    text_models_inputs_processor::InputMetadata {
                        input,
                        positions,
                        context_lens,
                        position_ids,
                        paged_attn_meta,
                        flash_meta,
                    },
                seq_indices,
            } = metadata?;
            let inputs: Box<dyn Any> = Box::new(ModelInputs {
                input_ids: input,
                seqlen_offsets: positions,
                context_lens,
                position_ids,
                pixel_values: pixel_values.clone(),
                model_specific_args: Box::new(LLaVAVisionSpecificArgs {}),
                paged_attn_meta,
                flash_meta,
            });
            Ok(InputProcessorOutput {
                inputs,
                seq_indices,
            })
        }))
    }
}

impl ImagePreProcessor for LLaVAInputProcessor {
    // OpenAI CLIP image normalization constants.
    #[allow(clippy::excessive_precision)]
    const DEFAULT_MEAN: [f64; 3] = [0.48145466, 0.4578275, 0.40821073];
    #[allow(clippy::excessive_precision)]
    const DEFAULT_STD: [f64; 3] = [0.26862954, 0.26130258, 0.27577711];
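    // Preprocesses a single image: pad it to a square filled with the mean color
    // (`expand2square`), then resize to `shortest_edge` and normalize with the
    // CLIP mean/std before stacking into a BF16 pixel tensor.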
    fn preprocess(
        &self,
        images: Vec<image::DynamicImage>,
        videos: Vec<Vec<image::DynamicImage>>,
        config: &preprocessor_config::PreProcessorConfig,
        device: &candle_core::Device,
        (_, _): (usize, usize),
    ) -> candle_core::Result<image_processor::PreprocessedImages> {
        if images.len() > 1 {
            candle_core::bail!("Can only process one image per batch");
        }
        assert!(videos.is_empty());
        let resized_size = *config.size.as_ref().unwrap().get("shortest_edge").unwrap() as usize;

        let original_size = images[0].dimensions();
        let filter = config.resampling.to_filter()?;
        let image_mean = config
            .image_mean
            .unwrap_or(Self::DEFAULT_MEAN)
            .map(|x| x as f32);
        let mean_color = image_mean
            .iter()
            .map(|x| ((*x) * 255.0) as u8)
            .collect::<Vec<u8>>();
        let mean_color = Rgb::from([mean_color[0], mean_color[1], mean_color[2]]);
        let image = expand2square(&images[0], mean_color);
        let image_std = config
            .image_std
            .unwrap_or(Self::DEFAULT_STD)
            .map(|x| x as f32);
        let pixel_values = [image]
            .iter()
            .map(|x| {
                LLaVAImageProcessor::process_one_image(
                    x,
                    config,
                    resized_size as u32,
                    filter,
                    DType::BF16,
                    device,
                    &image_mean,
                    &image_std,
                )
            })
            .collect::<Result<Vec<Tensor>>>()?;
        let pixel_values = Tensor::stack(&pixel_values, 0)?;

        Ok(image_processor::PreprocessedImages {
            pixel_values,
            pixel_attention_mask: None,
            image_sizes: Some((original_size.0 as usize, original_size.1 as usize)),
            num_img_tokens: Some(vec![LLaVAInputProcessor::get_num_image_tokens(
                &self.model_config,
            )]),
            aspect_ratio_ids: None,
            aspect_ratio_mask: None,
            num_tiles: None,
            image_grid_thw: None,
            video_grid_thw: None,
            rows: None,
            cols: None,
            pixel_values_list: None,
            tgt_sizes: None,
            image_sizes_all: None,
            num_crops: None,
        })
    }
}