#![allow(clippy::cast_precision_loss, clippy::cast_possible_truncation)]

use std::{any::Any, num::NonZeroUsize, sync::Arc};

use candle_core::{Device, Result, Tensor};
use image::{imageops::FilterType, DynamicImage, GenericImage, GenericImageView, Rgba};
use itertools::Itertools;
use mistralrs_vision::{ApplyTransforms, Normalize, ToTensor, Transforms};
use regex_automata::meta::Regex;
use tokenizers::Tokenizer;
use tracing::warn;

use crate::{
    device_map::DeviceMapper,
    pipeline::{
        text_models_inputs_processor::{
            self, get_completion_input, get_prompt_input, PagedAttentionMeta,
        },
        InputProcessorOutput, InputsProcessor, InputsProcessorType, MessagesAction, Processor,
        ProcessorCreator,
    },
    sequence::Sequence,
};

use crate::vision_models::{
    image_processor::{ImagePreProcessor, PreprocessedImages},
    phi3::Phi3VisionSpecificArgs,
    preprocessor_config::PreProcessorConfig,
    processor_config::ProcessorConfig,
    ModelInputs,
};

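/// Input processor for Phi 3 Vision. Rewrites prompts by expanding each
/// `<|image_N|>` tag into the correct number of image placeholder tokens and
/// preprocesses the attached images into crop tensors.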
pub struct Phi3InputsProcessor {
    image_tag_splitter: Regex,
}
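/// Processor for Phi 3 Vision, wrapping a shared [`Phi3InputsProcessor`].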
pub struct Phi3Processor {
    inputs_processor: Arc<Phi3InputsProcessor>,
}

impl ProcessorCreator for Phi3Processor {
    fn new_processor(
        _: Option<ProcessorConfig>,
        _: PreProcessorConfig,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(Self {
            inputs_processor: Arc::new(Phi3InputsProcessor {
                image_tag_splitter: Regex::new(r"<\|image_\d+\|>")
                    .expect("Failed to compile split regex."),
            }),
        })
    }
}

impl Processor for Phi3Processor {
    fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
        self.inputs_processor.clone()
    }
    fn get_special_tokens(&self) -> &[&'static str] {
        &[]
    }
    fn template_action(&self) -> MessagesAction {
        MessagesAction::FlattenOnlyText
    }
}

impl InputsProcessor for Phi3InputsProcessor {
    fn get_type(&self) -> InputsProcessorType {
        InputsProcessorType::Vision
    }
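    // Expands image tags into placeholder tokens and attaches pixel values;
    // falls back to the text-only processor when no images are present.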
    fn process_inputs(
        &self,
        tokenizer: Option<Arc<Tokenizer>>,
        input_seqs: &mut [&mut Sequence],
        is_prompt: bool,
        is_xlora: bool,
        device: &Device,
        no_kv_cache: bool,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        other_config: Option<Arc<dyn Any>>,
        mut paged_attn_metadata: Option<PagedAttentionMeta<'_>>,
        prompt_chunksize: Option<NonZeroUsize>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Box<dyn Iterator<Item = anyhow::Result<InputProcessorOutput>>> {
        if is_xlora {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Cannot make inputs for X-LoRA vision model.",
            ))));
        }
        if no_kv_cache {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Vision model must have kv cache.",
            ))));
        }
        if prompt_chunksize.is_some() {
            warn!("`prompt_chunksize` is set. Phi 3 does not support prompt batching.");
        }
        let Some(tokenizer) = tokenizer else {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Phi3InputProcessor requires a specified tokenizer.",
            ))));
        };

        let config = other_config
            .clone()
            .expect("Need a PreProcessorConfig config.");
        let config: &PreProcessorConfig = config.downcast_ref().expect("Downcast failed.");

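        // Only take the vision path when every sequence in the batch carries
        // images; otherwise delegate to the text-only input processor below.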
        let has_images = input_seqs.iter().all(|seq| seq.has_images());

        let (pixel_values, image_sizes, num_img_tokens, n_images) = if has_images {
            let mut pixel_values_accum = Vec::new();
            let mut image_sizes_accum = Vec::new();
            let mut num_img_tokens_accum = Vec::new();
            let mut n_images = Vec::new();
            for seq in input_seqs.iter_mut() {
                let imgs = seq
                    .take_images()
                    .expect("Need to have images by this point.");
                let imgs_len = imgs.len();
                n_images.push(imgs_len);
                let PreprocessedImages {
                    pixel_values,
                    pixel_attention_mask: _,
                    image_sizes,
                    num_img_tokens,
                    aspect_ratio_ids: _,
                    aspect_ratio_mask: _,
                    num_tiles: _,
                    image_grid_thw: _,
                    video_grid_thw: _,
                    rows: _,
                    cols: _,
                    pixel_values_list: _,
                    tgt_sizes: _,
                    image_sizes_all: _,
                    num_crops: _,
                } = self
                    .preprocess(
                        imgs,
                        vec![],
                        config,
                        device,
                        (usize::MAX, usize::MAX),
                    )
                    .expect("Preprocessor failed");
                let image_sizes = image_sizes.unwrap();
                pixel_values_accum.push(pixel_values);
                image_sizes_accum.push(image_sizes);
                num_img_tokens_accum.push(num_img_tokens.unwrap());
            }
            (
                Some(Tensor::cat(&pixel_values_accum, 0).unwrap()),
                Some(image_sizes_accum),
                Some(num_img_tokens_accum),
                n_images,
            )
        } else {
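            // No images: run the text-only processor and rewrap its output with
            // empty vision arguments so downstream sees uniform `ModelInputs`.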
            return Box::new(
                text_models_inputs_processor::TextInputsProcessor
                    .process_inputs(
                        Some(tokenizer),
                        input_seqs,
                        is_prompt,
                        is_xlora,
                        device,
                        no_kv_cache,
                        last_n_context_len,
                        return_raw_logits,
                        other_config,
                        paged_attn_metadata,
                        None,
                        mapper,
                    )
                    .map(|metadata| {
                        let InputProcessorOutput {
                            inputs,
                            seq_indices,
                        } = metadata?;

                        let text_models_inputs_processor::ModelInputs {
                            input_ids,
                            input_ids_full: _,
                            seqlen_offsets,
                            seqlen_offsets_full: _,
                            context_lens,
                            position_ids,
                            paged_attn_meta,
                            flash_meta,
                            flash_meta_full: _,
                        } = *inputs
                            .downcast::<text_models_inputs_processor::ModelInputs>()
                            .expect("Downcast failed.");

                        let inputs: Box<dyn Any> = Box::new(ModelInputs {
                            input_ids,
                            seqlen_offsets,
                            context_lens,
                            position_ids,
                            pixel_values: None,
                            model_specific_args: Box::new(Phi3VisionSpecificArgs {
                                image_sizes: None,
                            }),
                            paged_attn_meta,
                            flash_meta,
                        });
                        Ok(InputProcessorOutput {
                            inputs,
                            seq_indices,
                        })
                    }),
            );
        };

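        // Decode each sequence back to text, split it on the `<|image_N|>` tags,
        // and re-encode the text chunks so placeholder tokens can be interleaved.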
        let mut toks = Vec::new();
        let detokenized = tokenizer
            .decode_batch(
                &input_seqs
                    .iter()
                    .map(|seq| seq.get_toks())
                    .collect::<Vec<_>>(),
                false,
            )
            .expect("Decode failed");

        for (detokenized, (seq, (num_img_tokens, n_images))) in detokenized.into_iter().zip(
            input_seqs
                .iter_mut()
                .zip(num_img_tokens.unwrap().into_iter().zip(n_images)),
        ) {
            let splits = self
                .image_tag_splitter
                .split(&detokenized)
                .map(|span| &detokenized[span.range()])
                .collect::<Vec<_>>();
            let prompt_chunks = tokenizer
                .encode_batch(splits, true)
                .expect("Encode failed")
                .into_iter()
                .map(|enc| enc.get_ids().to_vec())
                .collect::<Vec<_>>();

            let image_tags = self.image_tag_splitter.find_iter(&detokenized);
            let image_ids = image_tags
                .into_iter()
                .map(|s| {
                    let s = &detokenized[s.range()];
                    s.split('|')
                        .nth(1)
                        .unwrap()
                        .split('_')
                        .nth(1)
                        .unwrap()
                        .parse::<u32>()
                        .expect("Failed to parse image id to u32")
                })
                .collect::<Vec<_>>();
            let unique_image_ids = image_ids
                .iter()
                .copied()
                .unique()
                .sorted()
                .collect::<Vec<_>>();
            if unique_image_ids != (1u32..unique_image_ids.len() as u32 + 1).collect::<Vec<_>>() {
                return Box::new(std::iter::once(Err(anyhow::Error::msg(
                    "`image_ids` must start from 1 and be contiguous, e.g. [1, 2, 3]; [1, 4, 5] is invalid.",
                ))));
            }
            if unique_image_ids.len() != n_images {
                return Box::new(std::iter::once(Err(anyhow::Error::msg(
                    "Total images must be the same as the number of image tags.",
                ))));
            }

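            // Expand image id N into `num_img_tokens[N - 1]` placeholder tokens,
            // encoded as -N; the negative ids mark where image features are
            // spliced in downstream.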
            let image_ids_pad = image_ids
                .iter()
                .map(|id| {
                    [-(*id as i64)].repeat(
                        num_img_tokens[TryInto::<usize>::try_into(*id as isize - 1)
                            .unwrap_or(num_img_tokens.len() - 1)],
                    )
                })
                .collect::<Vec<_>>();

            let mut input_ids: Vec<i64> = Vec::new();
            for item in prompt_chunks
                .iter()
                .map(|x| x.iter().map(|x| *x as i64).collect::<Vec<_>>())
                .interleave(image_ids_pad)
            {
                input_ids.extend(item);
            }

            let new_ids = input_ids
                .iter()
                .map(|x| if *x < 0 { 0u32 } else { *x as u32 })
                .collect::<Vec<_>>();
            if !seq.has_changed_prompt {
                let new_prompt = tokenizer.decode(&new_ids, false).unwrap();
                seq.set_initial_prompt(new_prompt);
                seq.set_toks_and_reallocate(new_ids, paged_attn_metadata.as_mut());
                seq.has_changed_prompt = true;
            }

            toks.push(input_ids);
        }

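        // Build prompt- or completion-phase input metadata from the expanded
        // token streams, then attach the batch's pixel values and image sizes.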
        let iter = if is_prompt {
            get_prompt_input(
                toks,
                input_seqs,
                device,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                None,
                mapper,
            )
        } else {
            get_completion_input(
                toks,
                input_seqs,
                device,
                no_kv_cache,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                None,
                mapper,
            )
        };

        Box::new(iter.into_iter().map(move |metadata| {
            let text_models_inputs_processor::InnerInputProcessorOutput {
                inputs:
                    text_models_inputs_processor::InputMetadata {
                        input,
                        positions,
                        context_lens,
                        position_ids,
                        paged_attn_meta,
                        flash_meta,
                    },
                seq_indices,
            } = metadata?;
            let inputs: Box<dyn Any> = Box::new(ModelInputs {
                input_ids: input,
                seqlen_offsets: positions,
                context_lens,
                position_ids,
                pixel_values: pixel_values.clone(),
                model_specific_args: Box::new(Phi3VisionSpecificArgs {
                    image_sizes: image_sizes.clone(),
                }),
                paged_attn_meta,
                flash_meta,
            });
            Ok(InputProcessorOutput {
                inputs,
                seq_indices,
            })
        }))
    }
}

impl Phi3InputsProcessor {
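    /// Pad `image` by the given amounts on each side, filling with `pad_color`.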
    fn pad_image(
        image: &DynamicImage,
        top: u32,
        bottom: u32,
        left: u32,
        right: u32,
        pad_color: Rgba<u8>,
    ) -> DynamicImage {
        let new_width = image.width() + left + right;
        let new_height = image.height() + top + bottom;

        let mut new_image = DynamicImage::new_rgb8(new_width, new_height);
        for x in 0..new_width {
            for y in 0..new_height {
                new_image.put_pixel(x, y, pad_color);
            }
        }

        new_image
            .copy_from(image, left, top)
            .expect("Failed to copy image");

        new_image
    }

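    /// Pad the image height up to the next multiple of 336 with white pixels,
    /// split roughly evenly between top and bottom.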
    fn padding_336(img: &DynamicImage) -> DynamicImage {
        let (_width, height) = img.dimensions();
        let tar = ((height as f64 / 336.0).ceil() * 336.0) as u32;
        let top_padding = ((tar as f64 - height as f64 + 1.) / 2.) as u32;
        let bottom_padding = tar - height - top_padding;
        let left_padding = 0u32;
        let right_padding = 0u32;
        Self::pad_image(
            img,
            top_padding,
            bottom_padding,
            left_padding,
            right_padding,
            Rgba([255u8, 255, 255, 255]),
        )
    }

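    /// The Phi 3 "HD transform": rotate portrait images to landscape, scale so
    /// the number of 336x336 crops fits within `hd_num`, pad the height to a
    /// multiple of 336, then undo the rotation.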
    fn hd_transform(img: &DynamicImage, hd_num: usize) -> DynamicImage {
        let (mut width, mut height) = img.dimensions();
        let mut transposed = false;

        let img = if width < height {
            let img = img.rotate90();
            transposed = true;
            width = img.width();
            height = img.height();
            img
        } else {
            img.clone()
        };

        let ratio = width as f64 / height as f64;
        let mut scale = 1.0;
        while (scale * (scale / ratio).ceil()) <= hd_num as f64 {
            scale += 1.0;
        }
        scale -= 1.0;

        let new_width = (scale * 336.0) as u32;
        let new_height = (new_width as f64 / ratio) as u32;

        let resized_img = img.resize_exact(new_width, new_height, FilterType::Nearest);
        let padded_img = Self::padding_336(&resized_img);

        if transposed {
            return padded_img.rotate270();
        }

        padded_img
    }
}

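/// Zero-pad the crop dimension of a `(crops, 3, h, w)` tensor up to `max_crops`.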
fn pad_to_max_num_crops_tensor(image: &Tensor, max_crops: usize) -> Result<Tensor> {
    let (b, _, h, w) = image.dims4()?;
    if b < max_crops {
        let pad = Tensor::zeros((max_crops - b, 3, h, w), image.dtype(), image.device())?;
        Tensor::cat(&[image, &pad], 0)
    } else {
        Ok(image.clone())
    }
}

impl ImagePreProcessor for Phi3InputsProcessor {
    #[allow(clippy::excessive_precision)]
    const DEFAULT_MEAN: [f64; 3] = [0.48145466, 0.4578275, 0.40821073];
    #[allow(clippy::excessive_precision)]
    const DEFAULT_STD: [f64; 3] = [0.26862954, 0.26130258, 0.27577711];

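    // DEFAULT_MEAN/DEFAULT_STD are the standard CLIP normalization constants.
    // Resizes all images to a shared size, applies the HD transform and
    // normalization, splits each into 336x336 crops plus a global crop, and
    // pads the crop dimension to `num_crops + 1`.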
    fn preprocess(
        &self,
        mut images: Vec<DynamicImage>,
        videos: Vec<Vec<DynamicImage>>,
        config: &PreProcessorConfig,
        device: &Device,
        (_, _): (usize, usize),
    ) -> Result<PreprocessedImages> {
        assert!(!images.is_empty());
        assert!(videos.is_empty());

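        // Resize every image in the batch to the largest observed width and
        // height so the resulting crop tensors can be stacked.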
        let mut image_sizes = Vec::new();
        let mut padded_images = Vec::new();
        let mut num_img_tokens = Vec::new();
        let mut max_size = None;
        for image in images.iter() {
            if max_size.is_none() {
                max_size = Some((image.dimensions().0 as usize, image.dimensions().1 as usize));
            } else if max_size.is_some_and(|(x, _)| image.dimensions().0 as usize > x) {
                max_size = Some((image.dimensions().0 as usize, max_size.unwrap().1));
            } else if max_size.is_some_and(|(_, y)| image.dimensions().1 as usize > y) {
                max_size = Some((max_size.unwrap().0, image.dimensions().1 as usize));
            }
        }
        let (max_w, max_h) = max_size.unwrap();
        for image in images.iter_mut() {
            *image = image.resize_exact(max_w as u32, max_h as u32, FilterType::Nearest);
        }

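        // Per image: convert to RGB, apply the HD transform, normalize with the
        // configured (or default) statistics, and build a 336x336 global crop.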
        for image in images.iter_mut() {
            if config.do_convert_rgb.unwrap_or(true) {
                *image = DynamicImage::ImageRgb8(image.to_rgb8());
            }

            let hd_image = Self::hd_transform(image, config.num_crops.expect("Need `num_crops`"));

            let transforms_hd = Transforms {
                input: &ToTensor,
                inner_transforms: &[&Normalize {
                    mean: config.image_mean.unwrap_or(Self::DEFAULT_MEAN).to_vec(),
                    std: config.image_std.unwrap_or(Self::DEFAULT_STD).to_vec(),
                }],
            };

            let hd_image = hd_image.apply(transforms_hd, device)?;

            let global_image = hd_image.unsqueeze(0)?.interpolate2d(336, 336)?;

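            // Token budget for an h x w image: 144 tokens per 336x336 crop
            // (including the global crop), plus what appear to be 12 separator
            // tokens per crop row and one trailing token.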
            let (_, h, w) = hd_image.dims3()?;
            let num_image_tokens = ((h as f32 / 336. * w as f32 / 336. + 1.) * 144.
                + ((h as f32 / 336.) + 1.) * 12.
                + 1.) as usize;

            let hd_image_reshape = hd_image
                .reshape((
                    1,
                    3,
                    (h as f32 / 336.) as usize,
                    336,
                    (w as f32 / 336.) as usize,
                    336,
                ))?
                .permute((0, 2, 4, 1, 3, 5))?
                .reshape(((), 3, 336, 336))?;
            let hd_image_reshape = Tensor::cat(&[global_image, hd_image_reshape], 0)?;
            let image_transformed = pad_to_max_num_crops_tensor(
                &hd_image_reshape,
                config.num_crops.expect("Need `num_crops`") + 1,
            )?;
            image_sizes.push((h, w));
            padded_images.push(image_transformed);
            num_img_tokens.push(num_image_tokens);
        }
        if padded_images.len() > 1 {
            candle_core::bail!("Can only process one image per batch");
        }
        let image_sizes = image_sizes[0];

        Ok(PreprocessedImages {
            pixel_values: Tensor::stack(&padded_images, 0)?,
            image_sizes: Some((image_sizes.0, image_sizes.1)),
            pixel_attention_mask: None,
            num_img_tokens: Some(num_img_tokens),
            aspect_ratio_ids: None,
            aspect_ratio_mask: None,
            num_tiles: None,
            image_grid_thw: None,
            video_grid_thw: None,
            rows: None,
            cols: None,
            pixel_values_list: None,
            tgt_sizes: None,
            image_sizes_all: None,
            num_crops: None,
        })
    }
}