1#![allow(clippy::cast_precision_loss, clippy::cast_possible_truncation)]
2
3use std::{any::Any, num::NonZeroUsize, sync::Arc};
4
5use candle_core::{Device, Result, Tensor};
6use image::{imageops::FilterType, DynamicImage, GenericImage, GenericImageView, Rgba};
7use itertools::Itertools;
8use mistralrs_vision::{ApplyTransforms, Normalize, ToTensor, Transforms};
9use regex_automata::meta::Regex;
10use tokenizers::Tokenizer;
11use tracing::warn;
12
13use crate::{
14 device_map::DeviceMapper,
15 pipeline::{
16 text_models_inputs_processor::{
17 self, get_completion_input, get_prompt_input, PagedAttentionMeta,
18 },
19 InputProcessorOutput, InputsProcessor, InputsProcessorType, MessagesAction, Processor,
20 ProcessorCreator,
21 },
22 sequence::Sequence,
23};
24
25use crate::vision_models::{
26 image_processor::{ImagePreProcessor, PreprocessedImages},
27 phi3::Phi3VisionSpecificArgs,
28 preprocessor_config::PreProcessorConfig,
29 processor_config::ProcessorConfig,
30 ModelInputs,
31};
32
33pub struct Phi3InputsProcessor {
35 image_tag_splitter: Regex,
36}
37pub struct Phi3Processor {
39 inputs_processor: Arc<Phi3InputsProcessor>,
40}
41
42impl ProcessorCreator for Phi3Processor {
43 fn new_processor(
44 _: Option<ProcessorConfig>,
45 _: PreProcessorConfig,
46 ) -> Arc<dyn Processor + Send + Sync> {
47 Arc::new(Self {
48 inputs_processor: Arc::new(Phi3InputsProcessor {
49 image_tag_splitter: Regex::new(r"<\|image_\d+\|>")
50 .expect("Failed to compile split regex."),
51 }),
52 })
53 }
54}
55
56impl Processor for Phi3Processor {
57 fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
58 self.inputs_processor.clone()
59 }
60 fn get_special_tokens(&self) -> &[&'static str] {
61 &[]
62 }
63 fn template_action(&self) -> MessagesAction {
64 MessagesAction::FlattenOnlyText
65 }
66}
67
68impl InputsProcessor for Phi3InputsProcessor {
69 fn get_type(&self) -> InputsProcessorType {
70 InputsProcessorType::Vision
71 }
72 fn process_inputs(
73 &self,
74 tokenizer: Option<Arc<Tokenizer>>,
75 input_seqs: &mut [&mut Sequence],
76 is_prompt: bool,
77 is_xlora: bool,
78 device: &Device,
79 no_kv_cache: bool,
80 last_n_context_len: Option<(usize, usize)>,
81 return_raw_logits: bool,
82 other_config: Option<Arc<dyn Any>>,
83 mut paged_attn_metadata: Option<PagedAttentionMeta<'_>>,
84 prompt_chunksize: Option<NonZeroUsize>,
85 mapper: Option<&dyn DeviceMapper>,
86 ) -> Box<dyn Iterator<Item = anyhow::Result<InputProcessorOutput>>> {
87 if is_xlora {
88 return Box::new(std::iter::once(Err(anyhow::Error::msg(
89 "Cannot make inputs for X-LoRA vision model.",
90 ))));
91 }
92 if no_kv_cache {
93 return Box::new(std::iter::once(Err(anyhow::Error::msg(
94 "Vision model must have kv cache.",
95 ))));
96 }
97 if prompt_chunksize.is_some() {
99 warn!("`prompt_chunksize` is set. Idefics 2 does not support prompt batching.");
100 }
101 let Some(tokenizer) = tokenizer else {
102 return Box::new(std::iter::once(Err(anyhow::Error::msg(
103 "Phi3InputProcessor requires a specified tokenizer.",
104 ))));
105 };
106
107 let config = other_config
108 .clone()
109 .expect("Need a PreProcessorConfig config.");
110 let config: &PreProcessorConfig = config.downcast_ref().expect("Downcast failed.");
111
112 let has_images = input_seqs.iter().all(|seq| seq.has_images());
113
114 let (pixel_values, image_sizes, num_img_tokens, n_images) = if has_images {
115 let mut pixel_values_accum = Vec::new();
116 let mut image_sizes_accum = Vec::new();
117 let mut num_img_tokens_accum = Vec::new();
118 let mut n_images = Vec::new();
119 for seq in input_seqs.iter_mut() {
120 let imgs = seq
121 .take_images()
122 .expect("Need to have images by this point.");
123 let imgs_len = imgs.len();
124 n_images.push(imgs_len);
125 let PreprocessedImages {
126 pixel_values,
127 pixel_attention_mask: _,
128 image_sizes,
129 num_img_tokens,
130 aspect_ratio_ids: _,
131 aspect_ratio_mask: _,
132 num_tiles: _,
133 image_grid_thw: _,
134 video_grid_thw: _,
135 rows: _,
136 cols: _,
137 pixel_values_list: _,
138 tgt_sizes: _,
139 image_sizes_all: _,
140 num_crops: _,
141 } = self
142 .preprocess(
143 imgs,
144 vec![],
145 config,
146 device,
147 (usize::MAX, usize::MAX), )
149 .expect("Preprocessor failed");
150 let image_sizes = image_sizes.unwrap();
151 pixel_values_accum.push(pixel_values);
152 image_sizes_accum.push(image_sizes);
153 num_img_tokens_accum.push(num_img_tokens.unwrap());
154 }
155 (
156 Some(Tensor::cat(&pixel_values_accum, 0).unwrap()),
157 Some(image_sizes_accum),
158 Some(num_img_tokens_accum),
159 n_images,
160 )
161 } else {
162 return Box::new(
163 text_models_inputs_processor::TextInputsProcessor
164 .process_inputs(
165 Some(tokenizer),
166 input_seqs,
167 is_prompt,
168 is_xlora,
169 device,
170 no_kv_cache,
171 last_n_context_len,
172 return_raw_logits,
173 other_config,
174 paged_attn_metadata,
175 None, mapper,
177 )
178 .map(|metadata| {
179 let InputProcessorOutput {
180 inputs,
181 seq_indices,
182 } = metadata?;
183
184 let text_models_inputs_processor::ModelInputs {
185 input_ids,
186 input_ids_full: _,
187 seqlen_offsets,
188 seqlen_offsets_full: _,
189 context_lens,
190 position_ids,
191 paged_attn_meta,
192 flash_meta,
193 flash_meta_full: _,
194 } = *inputs
195 .downcast::<text_models_inputs_processor::ModelInputs>()
196 .expect("Downcast failed.");
197
198 let inputs: Box<dyn Any> = Box::new(ModelInputs {
199 input_ids,
200 seqlen_offsets,
201 context_lens,
202 position_ids,
203 pixel_values: None,
204 model_specific_args: Box::new(Phi3VisionSpecificArgs {
205 image_sizes: None,
206 }),
207 paged_attn_meta,
208 flash_meta,
209 });
210 Ok(InputProcessorOutput {
211 inputs,
212 seq_indices,
213 })
214 }),
215 );
216 };
217
218 let mut toks = Vec::new();
219 let detokenized = tokenizer
220 .decode_batch(
221 &input_seqs
222 .iter()
223 .map(|seq| seq.get_toks())
224 .collect::<Vec<_>>(),
225 false,
226 )
227 .expect("Decode failed");
228
229 for (detokenized, (seq, (num_img_tokens, n_images))) in detokenized.into_iter().zip(
230 input_seqs
231 .iter_mut()
232 .zip(num_img_tokens.unwrap().into_iter().zip(n_images)),
233 ) {
234 let splits = self
235 .image_tag_splitter
236 .split(&detokenized)
237 .map(|span| &detokenized[span.range()])
238 .collect::<Vec<_>>();
239 let prompt_chunks = tokenizer
240 .encode_batch(splits, true)
241 .expect("Encode failed")
242 .into_iter()
243 .map(|enc| enc.get_ids().to_vec())
244 .collect::<Vec<_>>();
245
246 let image_tags = self.image_tag_splitter.find_iter(&detokenized);
247 let image_ids = image_tags
248 .into_iter()
249 .map(|s| {
250 let s = &detokenized[s.range()];
251 s.split('|')
252 .nth(1)
253 .unwrap()
254 .split('_')
255 .nth(1)
256 .unwrap()
257 .parse::<u32>()
258 .expect("Failed to parse image id to u32")
259 })
260 .collect::<Vec<_>>();
261 let unique_image_ids = image_ids
262 .iter()
263 .copied()
264 .unique()
265 .sorted()
266 .collect::<Vec<_>>();
267 if unique_image_ids != (1u32..unique_image_ids.len() as u32 + 1).collect::<Vec<_>>() {
269 return Box::new(std::iter::once(Err(anyhow::Error::msg(
270 "`image_ids` must start from 1, and must be continuous, e.g. [1, 2, 3], cannot be [1, 4, 5].",
271 ))));
272 }
273 if unique_image_ids.len() != n_images {
275 return Box::new(std::iter::once(Err(anyhow::Error::msg(
276 "Total images must be the same as the number of image tags.",
277 ))));
278 }
279
280 let image_ids_pad = image_ids
282 .iter()
283 .map(|id| {
284 [-(*id as i64)].repeat(
285 num_img_tokens[TryInto::<usize>::try_into(*id as isize - 1)
286 .unwrap_or(num_img_tokens.len() - 1)],
287 )
288 })
289 .collect::<Vec<_>>();
290
291 let mut input_ids: Vec<i64> = Vec::new();
292 for item in prompt_chunks
293 .iter()
294 .map(|x| x.iter().map(|x| *x as i64).collect::<Vec<_>>())
295 .interleave(image_ids_pad)
296 {
297 input_ids.extend(item);
298 }
299
300 seq.set_toks_and_reallocate(
302 input_ids
303 .iter()
304 .map(|x| if *x < 0 { 0u32 } else { *x as u32 })
305 .collect::<Vec<_>>(),
306 paged_attn_metadata.as_mut(),
307 );
308
309 toks.push(input_ids);
310 }
311
312 let iter = if is_prompt {
313 get_prompt_input(
314 toks,
315 input_seqs,
316 device,
317 last_n_context_len,
318 return_raw_logits,
319 paged_attn_metadata.as_mut(),
320 None, mapper,
322 )
323 } else {
324 get_completion_input(
325 toks,
326 input_seqs,
327 device,
328 no_kv_cache,
329 last_n_context_len,
330 return_raw_logits,
331 paged_attn_metadata.as_mut(),
332 None, mapper,
334 )
335 };
336
337 Box::new(iter.into_iter().map(move |metadata| {
338 let text_models_inputs_processor::InnerInputProcessorOutput {
339 inputs:
340 text_models_inputs_processor::InputMetadata {
341 input,
342 positions,
343 context_lens,
344 position_ids,
345 paged_attn_meta,
346 flash_meta,
347 },
348 seq_indices,
349 } = metadata?;
350 let inputs: Box<dyn Any> = Box::new(ModelInputs {
351 input_ids: input,
352 seqlen_offsets: positions,
353 context_lens,
354 position_ids,
355 pixel_values: pixel_values.clone(),
356 model_specific_args: Box::new(Phi3VisionSpecificArgs {
357 image_sizes: image_sizes.clone(),
358 }),
359 paged_attn_meta,
360 flash_meta,
361 });
362 Ok(InputProcessorOutput {
363 inputs,
364 seq_indices,
365 })
366 }))
367 }
368}
369
370impl Phi3InputsProcessor {
371 fn pad_image(
372 image: &DynamicImage,
373 top: u32,
374 bottom: u32,
375 left: u32,
376 right: u32,
377 pad_color: Rgba<u8>,
378 ) -> DynamicImage {
379 let new_width = image.width() + left + right;
381 let new_height = image.height() + top + bottom;
382
383 let mut new_image = DynamicImage::new_rgb8(new_width, new_height);
385 for x in 0..new_width {
386 for y in 0..new_height {
387 new_image.put_pixel(x, y, pad_color);
388 }
389 }
390
391 new_image
393 .copy_from(image, left, top)
394 .expect("Failed to copy image");
395
396 new_image
397 }
398
399 fn padding_336(img: &DynamicImage) -> DynamicImage {
400 let (_width, height) = img.dimensions();
401 let tar = ((height as f64 / 336.0).ceil() * 336.0) as u32;
402 let top_padding = ((tar as f64 - height as f64 + 1.) / 2.) as u32;
403 let bottom_padding = tar - height - top_padding;
404 let left_padding = 0u32;
405 let right_padding = 0u32;
406 Self::pad_image(
407 img,
408 top_padding,
409 bottom_padding,
410 left_padding,
411 right_padding,
412 Rgba([255u8, 255, 255, 255]),
413 )
414 }
415
416 fn hd_transform(img: &DynamicImage, hd_num: usize) -> DynamicImage {
417 let (mut width, mut height) = img.dimensions();
418 let mut transposed = false;
419
420 let img = if width < height {
421 let img = img.rotate90();
422 transposed = true;
423 width = img.width();
424 height = img.height();
425 img
426 } else {
427 img.clone()
429 };
430
431 let ratio = width as f64 / height as f64;
432 let mut scale = 1.0;
433 while (scale * (scale / ratio).ceil()) <= hd_num as f64 {
434 scale += 1.0;
435 }
436 scale -= 1.0;
437
438 let new_width = (scale * 336.0) as u32;
439 let new_height = (new_width as f64 / ratio) as u32;
440
441 let resized_img = img.resize_exact(new_width, new_height, FilterType::Nearest);
442 let padded_img = Self::padding_336(&resized_img);
443
444 if transposed {
445 return padded_img.rotate270();
446 }
447
448 padded_img
449 }
450}
451
452fn pad_to_max_num_crops_tensor(image: &Tensor, max_crops: usize) -> Result<Tensor> {
453 let (b, _, h, w) = image.dims4()?;
454 if b < max_crops {
455 let pad = Tensor::zeros((max_crops - b, 3, h, w), image.dtype(), image.device())?;
456 Tensor::cat(&[image, &pad], 0)
457 } else {
458 Ok(image.clone())
459 }
460}
461
462impl ImagePreProcessor for Phi3InputsProcessor {
463 #[allow(clippy::excessive_precision)]
464 const DEFAULT_MEAN: [f64; 3] = [0.48145466, 0.4578275, 0.40821073];
465 #[allow(clippy::excessive_precision)]
466 const DEFAULT_STD: [f64; 3] = [0.26862954, 0.26130258, 0.27577711];
467
468 fn preprocess(
469 &self,
470 mut images: Vec<DynamicImage>,
471 videos: Vec<Vec<DynamicImage>>,
472 config: &PreProcessorConfig,
473 device: &Device,
474 (_, _): (usize, usize),
475 ) -> Result<PreprocessedImages> {
476 assert!(!images.is_empty());
478 assert!(videos.is_empty());
479
480 let mut image_sizes = Vec::new();
481 let mut padded_images = Vec::new();
482 let mut num_img_tokens = Vec::new();
483 let mut max_size = None;
485 for image in images.iter() {
486 if max_size.is_none() {
487 max_size = Some((image.dimensions().0 as usize, image.dimensions().1 as usize))
488 } else if max_size.is_some_and(|(x, _)| image.dimensions().0 as usize > x) {
489 max_size = Some((image.dimensions().0 as usize, max_size.unwrap().1));
490 } else if max_size.is_some_and(|(_, y)| image.dimensions().1 as usize > y) {
491 max_size = Some((max_size.unwrap().0, image.dimensions().1 as usize));
492 }
493 }
494 let (max_h, max_w) = max_size.unwrap();
495 for image in images.iter_mut() {
496 *image = image.resize_exact(max_w as u32, max_h as u32, FilterType::Nearest);
497 }
498
499 for image in images.iter_mut() {
500 if config.do_convert_rgb.unwrap_or(true) {
502 *image = DynamicImage::ImageRgb8(image.to_rgb8());
503 }
504
505 let hd_image = Self::hd_transform(image, config.num_crops.expect("Need `num_crops`"));
506
507 let transforms_hd = Transforms {
510 input: &ToTensor,
511 inner_transforms: &[&Normalize {
512 mean: config.image_mean.unwrap_or(Self::DEFAULT_MEAN).to_vec(),
513 std: config.image_std.unwrap_or(Self::DEFAULT_STD).to_vec(),
514 }],
515 };
516
517 let hd_image = hd_image.apply(transforms_hd, device)?;
519
520 let global_image = hd_image.unsqueeze(0)?.interpolate2d(336, 336)?;
523
524 let (_, h, w) = hd_image.dims3()?;
525 let num_image_tokens = ((h as f32 / 336. * w as f32 / 336. + 1.) * 144.
526 + ((h as f32 / 336.) + 1.) * 12.
527 + 1.) as usize;
528
529 let hd_image_reshape = hd_image
530 .reshape((
531 1,
532 3,
533 (h as f32 / 336.) as usize,
534 336,
535 (w as f32 / 336.) as usize,
536 336,
537 ))?
538 .permute((0, 2, 4, 1, 3, 5))?
539 .reshape(((), 3, 336, 336))?;
540 let hd_image_reshape = Tensor::cat(&[global_image, hd_image_reshape], 0)?;
541 let image_transformed = pad_to_max_num_crops_tensor(
542 &hd_image_reshape,
543 config.num_crops.expect("Need `num_crops`") + 1,
544 )?;
545 image_sizes.push((h, w));
546 padded_images.push(image_transformed);
547 num_img_tokens.push(num_image_tokens);
548 }
549 if padded_images.len() > 1 {
550 candle_core::bail!("Can only process one image per batch");
551 }
552 let image_sizes = image_sizes[0];
553
554 Ok(PreprocessedImages {
555 pixel_values: Tensor::stack(&padded_images, 0)?,
556 image_sizes: Some((image_sizes.0, image_sizes.1)),
557 pixel_attention_mask: None,
558 num_img_tokens: Some(num_img_tokens),
559 aspect_ratio_ids: None,
560 aspect_ratio_mask: None,
561 num_tiles: None,
562 image_grid_thw: None,
563 video_grid_thw: None,
564 rows: None,
565 cols: None,
566 pixel_values_list: None,
567 tgt_sizes: None,
568 image_sizes_all: None,
569 num_crops: None,
570 })
571 }
572}