#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
use std::any::Any;
use std::num::NonZeroUsize;
use std::sync::Arc;

use candle_core::Result;
use candle_core::{DType, Device, Tensor};
use image::GenericImageView;
use image::Rgb;
use itertools::Itertools;
use regex_automata::meta::Regex;
use tokenizers::Tokenizer;
use tracing::warn;

use super::llava15::LLaVAVisionSpecificArgs;
use super::utils::{expand2square, LLaVAImageProcessor};
use crate::device_map::DeviceMapper;
use crate::pipeline::text_models_inputs_processor::{
    get_completion_input, get_prompt_input, PagedAttentionMeta,
};
use crate::pipeline::{
    text_models_inputs_processor, InputProcessorOutput, InputsProcessor, InputsProcessorType,
    MessagesAction, Processor,
};
use crate::sequence::Sequence;
use crate::vision_models::image_processor::{self, ImagePreProcessor, PreprocessedImages};
use crate::vision_models::llava::config::Config as LLaVAConfig;
use crate::vision_models::preprocessor_config::{PreProcessorConfig, ToFilter};
use crate::vision_models::{preprocessor_config, ModelInputs};

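/// Processor for LLaVA-style chat inputs. Chat messages are flattened to plain text
/// (image content is handled separately), and input construction is delegated to
/// [`LLaVAInputProcessor`].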
pub struct LLaVAProcessor {
    inputs_processor: Arc<LLaVAInputProcessor>,
}

impl Processor for LLaVAProcessor {
    fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
        self.inputs_processor.clone()
    }
    fn get_special_tokens(&self) -> &[&'static str] {
        &[]
    }
    fn template_action(&self) -> MessagesAction {
        MessagesAction::FlattenOnlyText
    }
}

impl LLaVAProcessor {
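    /// Build the processor from the model's `config.json` contents, compiling the
    /// `<image>` tag splitter used to locate image positions in the prompt.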
    pub fn new(config: &str) -> Self {
        let model_config =
            serde_json::from_str::<LLaVAConfig>(config).expect("Failed to parse model config.");
        let image_tag_splitter = Regex::new(r"<image>").expect("Failed to compile split regex.");
        let inputs_processor = Arc::new(LLaVAInputProcessor {
            image_tag_splitter,
            model_config: model_config.clone(),
        });
        Self { inputs_processor }
    }
}

pub struct LLaVAInputProcessor {
    image_tag_splitter: Regex,
    model_config: LLaVAConfig,
}

impl LLaVAInputProcessor {
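    /// Number of placeholder tokens a single image expands to: `(image_size / patch_size)^2`.
    /// With the usual CLIP ViT-L/14 tower at 336x336 (image_size = 336, patch_size = 14),
    /// that is 24 * 24 = 576 tokens.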
    pub fn get_num_image_tokens(cfg: &LLaVAConfig) -> usize {
        let patch_size = cfg.vision_config.patch_size;
        let patch_per_side = cfg.vision_config.image_size / patch_size;
        patch_per_side * patch_per_side
    }
}

impl InputsProcessor for LLaVAInputProcessor {
    fn get_type(&self) -> InputsProcessorType {
        InputsProcessorType::Vision
    }
    fn process_inputs(
        &self,
        tokenizer: Option<Arc<Tokenizer>>,
        input_seqs: &mut [&mut Sequence],
        is_prompt: bool,
        is_xlora: bool,
        device: &Device,
        no_kv_cache: bool,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        other_config: Option<Arc<dyn Any>>,
        mut paged_attn_metadata: Option<PagedAttentionMeta<'_>>,
        prompt_chunksize: Option<NonZeroUsize>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Box<dyn Iterator<Item = anyhow::Result<InputProcessorOutput>>> {
        if is_xlora {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Cannot make inputs for X-LoRA vision model.",
            ))));
        }
        if no_kv_cache {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Vision model must have kv cache.",
            ))));
        }
        if prompt_chunksize.is_some() {
            warn!("`prompt_chunksize` is set. LLaVA does not support prompt batching.");
        }
        let Some(tokenizer) = tokenizer else {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "LLaVAInputProcessor requires a specified tokenizer.",
            ))));
        };

        let config = other_config
            .clone()
            .expect("Need a PreProcessorConfig config.");
        let config: &PreProcessorConfig = config.downcast_ref().expect("Downcast failed.");

        let has_images = input_seqs.iter().all(|seq| seq.has_images());

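        // All sequences carry images: preprocess every image, accumulating the pixel
        // values and the number of placeholder tokens each image expands to.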
        let (pixel_values, num_img_tokens) = if has_images {
            let mut pixel_values_accum = Vec::new();
            let mut num_img_tokens_accum = Vec::new();
            for seq in input_seqs.iter_mut() {
                let imgs = seq
                    .take_images()
                    .expect("Need to have images by this point.");
                let PreprocessedImages {
                    pixel_values,
                    pixel_attention_mask: _,
                    image_sizes: _,
                    num_img_tokens,
                    aspect_ratio_ids: _,
                    aspect_ratio_mask: _,
                    num_tiles: _,
                    image_grid_thw: _,
                    video_grid_thw: _,
                    rows: _,
                    cols: _,
                    pixel_values_list: _,
                    tgt_sizes: _,
                    image_sizes_all: _,
                    num_crops: _,
                } = self
                    .preprocess(
                        imgs.clone(),
                        vec![],
                        config,
                        device,
                        (usize::MAX, usize::MAX),
                    )
                    .expect("Preprocessor failed");
                pixel_values_accum.push(pixel_values);
                num_img_tokens_accum.push(num_img_tokens.unwrap());
            }
            (
                Some(Tensor::cat(&pixel_values_accum, 0).unwrap()),
                Some(num_img_tokens_accum),
            )
        } else {
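            // No images in this batch: fall back to the plain-text inputs processor and
            // repackage its output as vision `ModelInputs` with `pixel_values: None`.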
            return Box::new(
                text_models_inputs_processor::TextInputsProcessor
                    .process_inputs(
                        Some(tokenizer),
                        input_seqs,
                        is_prompt,
                        is_xlora,
                        device,
                        no_kv_cache,
                        last_n_context_len,
                        return_raw_logits,
                        other_config,
                        paged_attn_metadata,
                        None,
                        mapper,
                    )
                    .map(|metadata| {
                        let InputProcessorOutput {
                            inputs,
                            seq_indices,
                        } = metadata?;

                        let text_models_inputs_processor::ModelInputs {
                            input_ids,
                            input_ids_full: _,
                            seqlen_offsets,
                            seqlen_offsets_full: _,
                            context_lens,
                            position_ids,
                            paged_attn_meta,
                            flash_meta,
                            flash_meta_full: _,
                        } = *inputs
                            .downcast::<text_models_inputs_processor::ModelInputs>()
                            .expect("Downcast failed.");

                        let inputs: Box<dyn Any> = Box::new(ModelInputs {
                            input_ids,
                            seqlen_offsets,
                            context_lens,
                            position_ids,
                            pixel_values: None,
                            model_specific_args: Box::new(LLaVAVisionSpecificArgs {}),
                            paged_attn_meta,
                            flash_meta,
                        });
                        Ok(InputProcessorOutput {
                            inputs,
                            seq_indices,
                        })
                    }),
            );
        };

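        // Decode each sequence back to text so the `<image>` tags can be located and
        // replaced by runs of image placeholder tokens.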
        let mut toks = Vec::new();
        let detokenized = tokenizer
            .decode_batch(
                &input_seqs
                    .iter()
                    .map(|seq| seq.get_toks())
                    .collect::<Vec<_>>(),
                false,
            )
            .expect("Decoding failed");

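        // Per sequence: split the decoded text on `<image>`, re-encode the text chunks,
        // and interleave them with per-image placeholder runs.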
        for (detokenized, (seq, num_img_tokens)) in detokenized.into_iter().zip(
            input_seqs
                .iter_mut()
                .zip(num_img_tokens.unwrap().into_iter()),
        ) {
            let splits = self
                .image_tag_splitter
                .split(&detokenized)
                .map(|span| &detokenized[span.range()])
                .collect::<Vec<_>>();
            let prompt_chunks = splits
                .iter()
                .map(|s| {
                    tokenizer
                        .encode_fast(*s, false)
                        .unwrap()
                        .get_ids()
                        .to_vec()
                        .iter()
                        .map(|x| *x as i64)
                        .collect()
                })
                .collect::<Vec<Vec<_>>>();
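            // One placeholder run per image: the first slot stores a negative marker
            // (-1, -2, ...) identifying the image, the remaining slots are zeros.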
            let mut image_ids_pad = Vec::new();
            for (i, num_img_token) in num_img_tokens.iter().enumerate() {
                let mut image_id_pad = vec![0; *num_img_token];
                image_id_pad[0] = -(i as i64 + 1);
                image_ids_pad.push(image_id_pad);
            }
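            // Interleave text chunks with the placeholder runs to form the expanded ids.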
            let mut input_ids: Vec<i64> = Vec::new();
            for item in prompt_chunks
                .iter()
                .map(|x| x.to_vec())
                .interleave(image_ids_pad)
            {
                input_ids.extend(item);
            }
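            // Write the expanded tokens back onto the sequence (negative markers are
            // clamped to 0 for the u32 token buffer) and reallocate any paged-attention
            // metadata for the new length.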
            seq.set_toks_and_reallocate(
                input_ids
                    .iter()
                    .map(|x| if *x < 0 { 0u32 } else { *x as u32 })
                    .collect::<Vec<_>>(),
                paged_attn_metadata.as_mut(),
            );

            toks.push(input_ids);
        }

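        // Build the standard prompt/completion input metadata from the expanded ids.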
        let iter = if is_prompt {
            get_prompt_input(
                toks,
                input_seqs,
                device,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                None,
                mapper,
            )
        } else {
            get_completion_input(
                toks,
                input_seqs,
                device,
                no_kv_cache,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                None,
                mapper,
            )
        };

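        // Attach the batched pixel values to every produced chunk of model inputs.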
        Box::new(iter.into_iter().map(move |metadata| {
            let text_models_inputs_processor::InnerInputProcessorOutput {
                inputs:
                    text_models_inputs_processor::InputMetadata {
                        input,
                        positions,
                        context_lens,
                        position_ids,
                        paged_attn_meta,
                        flash_meta,
                    },
                seq_indices,
            } = metadata?;
            let inputs: Box<dyn Any> = Box::new(ModelInputs {
                input_ids: input,
                seqlen_offsets: positions,
                context_lens,
                position_ids,
                pixel_values: pixel_values.clone(),
                model_specific_args: Box::new(LLaVAVisionSpecificArgs {}),
                paged_attn_meta,
                flash_meta,
            });
            Ok(InputProcessorOutput {
                inputs,
                seq_indices,
            })
        }))
    }
}

impl ImagePreProcessor for LLaVAInputProcessor {
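    // OpenAI CLIP image normalization constants, used when the preprocessor config
    // does not supply its own mean/std.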
    #[allow(clippy::excessive_precision)]
    const DEFAULT_MEAN: [f64; 3] = [0.48145466, 0.4578275, 0.40821073];
    #[allow(clippy::excessive_precision)]
    const DEFAULT_STD: [f64; 3] = [0.26862954, 0.26130258, 0.27577711];
    fn preprocess(
        &self,
        images: Vec<image::DynamicImage>,
        videos: Vec<Vec<image::DynamicImage>>,
        config: &preprocessor_config::PreProcessorConfig,
        device: &candle_core::Device,
        (_, _): (usize, usize),
    ) -> candle_core::Result<image_processor::PreprocessedImages> {
        if images.len() > 1 {
            candle_core::bail!("Can only process one image per batch");
        };
        assert!(videos.is_empty());
        let resized_size = *config.size.as_ref().unwrap().get("shortest_edge").unwrap() as usize;

        let original_size = images[0].dimensions();
        let filter = config.resampling.to_filter()?;
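        // The per-channel mean doubles as the padding color when the image is squared below.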
        let image_mean = config
            .image_mean
            .unwrap_or(Self::DEFAULT_MEAN)
            .map(|x| x as f32);
        let mean_color = image_mean
            .iter()
            .map(|x| ((*x) * 255.0) as u8)
            .collect::<Vec<u8>>();
        let mean_color = Rgb::from([mean_color[0], mean_color[1], mean_color[2]]);
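        // Pad the image to a square using the mean color (LLaVA's "pad" aspect-ratio
        // strategy), so content is not distorted before resizing.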
        let image = expand2square(&images[0], mean_color);
        let image_std = config
            .image_std
            .unwrap_or(Self::DEFAULT_STD)
            .map(|x| x as f32);
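        // Resize and normalize the (single) image into a BF16 pixel tensor on the target device.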
        let pixel_values = [image]
            .iter()
            .map(|x| {
                LLaVAImageProcessor::process_one_image(
                    x,
                    config,
                    resized_size as u32,
                    filter,
                    DType::BF16,
                    device,
                    &image_mean,
                    &image_std,
                )
            })
            .collect::<Result<Vec<Tensor>>>()?;
        let pixel_values = Tensor::stack(&pixel_values, 0)?;

        Ok(image_processor::PreprocessedImages {
            pixel_values,
            pixel_attention_mask: None,
            image_sizes: Some((original_size.0 as usize, original_size.1 as usize)),
            num_img_tokens: Some(vec![LLaVAInputProcessor::get_num_image_tokens(
                &self.model_config,
            )]),
            aspect_ratio_ids: None,
            aspect_ratio_mask: None,
            num_tiles: None,
            image_grid_thw: None,
            video_grid_thw: None,
            rows: None,
            cols: None,
            pixel_values_list: None,
            tgt_sizes: None,
            image_sizes_all: None,
            num_crops: None,
        })
    }
}