#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::{any::Any, sync::Arc};

use candle_core::{Device, Result, Tensor};
use image::{imageops::FilterType, DynamicImage, GenericImageView};
use mistralrs_vision::{ApplyTransforms, Normalize, Rescale, ToTensorNoNorm, Transforms};
use tokenizers::Tokenizer;

use crate::{
    device_map::DeviceMapper,
    pipeline::{
        text_models_inputs_processor::{
            self, get_completion_input, get_prompt_input, PagedAttentionMeta,
        },
        InputProcessorOutput, InputsProcessor, InputsProcessorType, MessagesAction, Processor,
    },
    sequence::Sequence,
    vision_models::{
        image_processor::{ImagePreProcessor, PreprocessedImages},
        preprocessor_config::{PreProcessorConfig, ToFilter},
        processor_config::ProcessorConfig,
        ModelInputs,
    },
};

use super::Mistral3SpecificArgs;

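// Temporary marker substituted for the image token while per-image replacement
// strings are built, so already-expanded text is not matched again.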
const PLACEHOLDER: &str = "<placeholder>";

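/// Input processor for Mistral 3 vision inputs: preprocesses images and expands each
/// image token in the prompt into a grid of image, break, and end tokens.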
struct Mistral3ImageProcessor {
    image_break_token: String,
    image_end_token: String,
    image_token: String,
    patch_size: usize,
    spatial_merge_size: usize,
}

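/// Processor constructed from the model's `ProcessorConfig`; produces
/// `Mistral3ImageProcessor` instances via `inputs_processor`.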
pub struct Mistral3Processor {
    image_break_token: String,
    image_end_token: String,
    image_token: String,
    patch_size: usize,
    spatial_merge_size: usize,
}

impl Mistral3Processor {
    pub fn new(processor_config: ProcessorConfig) -> Self {
        Self {
            image_break_token: processor_config.image_break_token.unwrap().clone(),
            image_end_token: processor_config.image_end_token.unwrap().clone(),
            image_token: processor_config.image_token.unwrap().clone(),
            patch_size: processor_config.patch_size.unwrap(),
            spatial_merge_size: processor_config.spatial_merge_size.unwrap(),
        }
    }
}

impl Processor for Mistral3Processor {
    fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
        Arc::new(Mistral3ImageProcessor {
            image_break_token: self.image_break_token.clone(),
            image_end_token: self.image_end_token.clone(),
            image_token: self.image_token.clone(),
            patch_size: self.patch_size,
            spatial_merge_size: self.spatial_merge_size,
        })
    }

    fn get_special_tokens(&self) -> &[&'static str] {
        &[]
    }

    fn template_action(&self) -> MessagesAction {
        MessagesAction::Keep
    }
}

impl InputsProcessor for Mistral3ImageProcessor {
    fn get_type(&self) -> InputsProcessorType {
        InputsProcessorType::Vision
    }
    fn process_inputs(
        &self,
        tokenizer: Option<Arc<Tokenizer>>,
        input_seqs: &mut [&mut Sequence],
        is_prompt: bool,
        is_xlora: bool,
        device: &Device,
        no_kv_cache: bool,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        other_config: Option<Arc<dyn Any>>,
        mut paged_attn_metadata: Option<PagedAttentionMeta>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> anyhow::Result<InputProcessorOutput> {
        if is_xlora {
            return Err(anyhow::Error::msg(
                "Cannot make inputs for X-LoRA vision model.",
            ));
        }
        if no_kv_cache {
            return Err(anyhow::Error::msg("Vision model must have kv cache."));
        }
        let Some(tokenizer) = tokenizer else {
            return Err(anyhow::Error::msg(
                "Mistral3ImageProcessor requires a specified tokenizer.",
            ));
        };

        let config = other_config.expect("Need a PreProcessorConfig config.");
        let config: &PreProcessorConfig = config.downcast_ref().expect("Downcast failed.");

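        // Images are only preprocessed when every sequence in the batch carries them.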
        let has_images = input_seqs.iter().all(|seq| seq.has_images());

        let (pixel_values, image_sizes) = if has_images {
            let mut pixel_values_accum = Vec::new();
            let mut image_sizes_accum = Vec::new();

            for seq in input_seqs.iter_mut() {
                let PreprocessedImages {
                    pixel_values,
                    pixel_attention_mask: _,
                    image_sizes: _,
                    num_img_tokens: _,
                    aspect_ratio_ids: _,
                    aspect_ratio_mask: _,
                    num_tiles: _,
                    image_grid_thw: _,
                    video_grid_thw: _,
                    rows: _,
                    cols: _,
                    pixel_values_list: _,
                    tgt_sizes: _,
                    image_sizes_all,
                    num_crops: _,
                } = self
                    .preprocess(
                        seq.take_images()
                            .expect("Need to have images by this point."),
                        vec![],
                        config,
                        device,
                        (usize::MAX, usize::MAX),
                    )
                    .expect("Preprocessing failed");
                let image_sizes_all = image_sizes_all.unwrap();

                pixel_values_accum.push(pixel_values.clone());
                image_sizes_accum.extend_from_slice(&image_sizes_all);

                let mut prompt = tokenizer
                    .decode(seq.get_toks(), false)
                    .expect("Detokenization failed!");

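                // Expand each image token into a grid: `num_height_tokens` rows of
                // `num_width_tokens` image tokens, rows separated by the break token and
                // the final token replaced with the end token.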
                let mut image_sizes_all_iter = image_sizes_all.into_iter();
                let mut replace_strings = Vec::new();
                while prompt.contains(&self.image_token) {
                    let (height, width) = image_sizes_all_iter.next().unwrap();
                    let num_height_tokens =
                        (height as usize) / (self.patch_size * self.spatial_merge_size);
                    let num_width_tokens =
                        (width as usize) / (self.patch_size * self.spatial_merge_size);

                    let mut replace_tokens = vec![
                        [
                            vec![self.image_token.clone(); num_width_tokens],
                            vec![self.image_break_token.clone()],
                        ]
                        .concat();
                        num_height_tokens
                    ]
                    .into_iter()
                    .flatten()
                    .collect::<Vec<_>>();

                    *replace_tokens.last_mut().unwrap() = self.image_end_token.clone();

                    replace_strings.push(replace_tokens.join(""));
                    prompt = prompt.replace(&self.image_token, PLACEHOLDER);
                }

                while prompt.contains(PLACEHOLDER) {
                    let replace_str = replace_strings.pop().unwrap();
                    prompt = prompt.replace(PLACEHOLDER, &replace_str);
                }

                if !seq.multimodal.has_changed_prompt {
                    seq.set_initial_prompt(prompt.clone());
                    let toks = tokenizer
                        .encode_fast(prompt, false)
                        .expect("Tokenization failed!");

                    let ids = toks.get_ids().to_vec();
                    seq.set_toks_and_reallocate(ids, paged_attn_metadata.as_mut());
                    seq.multimodal.has_changed_prompt = true;
                }
            }

            (
                Some(Tensor::cat(&pixel_values_accum, 0).unwrap()),
                Some(image_sizes_accum),
            )
        } else {
            (None, None)
        };

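        // Build the underlying text-model inputs: the full prompt during prefill,
        // otherwise the incremental completion (decode) inputs.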
        let text_models_inputs_processor::InnerInputProcessorOutput {
            inputs:
                text_models_inputs_processor::InputMetadata {
                    input,
                    positions,
                    context_lens,
                    position_ids,
                    paged_attn_meta,
                    flash_meta,
                },
            seq_indices,
        } = if is_prompt {
            get_prompt_input(
                input_seqs
                    .iter()
                    .map(|seq| seq.get_toks())
                    .collect::<Vec<_>>(),
                input_seqs,
                device,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                mapper,
            )
            .unwrap()
        } else {
            get_completion_input(
                input_seqs
                    .iter()
                    .map(|seq| seq.get_toks())
                    .collect::<Vec<_>>(),
                input_seqs,
                device,
                no_kv_cache,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                mapper,
            )
            .unwrap()
        };

        let inputs: Box<dyn Any> = Box::new(ModelInputs {
            input_ids: input,
            seqlen_offsets: positions,
            context_lens,
            position_ids,
            pixel_values,
            model_specific_args: Box::new(Mistral3SpecificArgs { image_sizes }),
            paged_attn_meta,
            flash_meta,
        });
        Ok(InputProcessorOutput {
            inputs,
            seq_indices,
        })
    }
}

impl Mistral3ImageProcessor {
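    /// Shrink the image (preserving aspect ratio) so it fits within
    /// `(max_height, max_width)`, then round each side up to a whole number of patches.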
    #[allow(clippy::too_many_arguments)]
    fn resize(
        &self,
        image: &DynamicImage,
        mut height: usize,
        mut width: usize,
        max_height: usize,
        max_width: usize,
        patch_size: usize,
        filter: FilterType,
    ) -> DynamicImage {
        let ratio = (height as f64 / max_height as f64).max(width as f64 / max_width as f64);
        if ratio > 1. {
            height = (height as f64 / ratio).floor() as usize;
            width = (width as f64 / ratio).floor() as usize;
        }

        let num_height_tokens = (height - 1) / patch_size + 1;
        let num_width_tokens = (width - 1) / patch_size + 1;

        let resize_height = num_height_tokens * patch_size;
        let resize_width = num_width_tokens * patch_size;

        image.resize_exact(resize_width as u32, resize_height as u32, filter)
    }
}

impl ImagePreProcessor for Mistral3ImageProcessor {
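    // OpenAI CLIP image normalization constants.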
    #[allow(clippy::excessive_precision)]
    const DEFAULT_MEAN: [f64; 3] = [0.48145466, 0.4578275, 0.40821073];
    #[allow(clippy::excessive_precision)]
    const DEFAULT_STD: [f64; 3] = [0.26862954, 0.26130258, 0.27577711];

    fn preprocess(
        &self,
        mut images: Vec<DynamicImage>,
        videos: Vec<Vec<DynamicImage>>,
        config: &PreProcessorConfig,
        device: &Device,
        (_bs, _max_num_images): (usize, usize),
    ) -> Result<PreprocessedImages> {
        assert!(videos.is_empty());

        let do_resize = config.do_resize.unwrap();
        let do_rescale = config.do_rescale.unwrap();
        let rescale_factor = config.rescale_factor.unwrap();
        let do_normalize = config.do_normalize.unwrap();
        let image_mean = config.image_mean.unwrap_or(Self::DEFAULT_MEAN);
        let image_std = config.image_std.unwrap_or(Self::DEFAULT_STD);
        let do_convert_rgb = config.do_convert_rgb.unwrap_or(true);
        let patch_size = config.patch_size.unwrap();
        let size = config.size.as_ref().unwrap();
        let resample = config.resampling.to_filter()?;

        let default_to_square = config.default_to_square.unwrap();
        assert!(default_to_square);

        let mut pixel_values = Vec::new();
        let mut image_sizes = Vec::new();

        let (max_height, max_width) = if size.contains_key("longest_edge") {
            (size["longest_edge"] as usize, size["longest_edge"] as usize)
        } else if size.contains_key("height") && size.contains_key("width") {
            (size["height"] as usize, size["width"] as usize)
        } else {
            candle_core::bail!("Size must be a map of `longest_edge` or `height` and `width`.");
        };

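        // Convert each image to RGB and resize it, recording the post-resize
        // (height, width) used later for prompt token expansion.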
        for image in images.iter_mut() {
            let (width, height) = image.dimensions();

            if do_convert_rgb {
                *image = DynamicImage::ImageRgb8(image.to_rgb8());
            }

            if do_resize {
                *image = self.resize(
                    image,
                    height as usize,
                    width as usize,
                    max_height,
                    max_width,
                    patch_size,
                    resample,
                );
            }

            let (width, height) = image.dimensions();

            image_sizes.push((height, width));
        }

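        // Pad all images to a common size so their tensors can be concatenated below.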
        images = mistralrs_vision::pad_to_max_image_size(images);

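        // Convert each image to a tensor, then apply the optional rescale and
        // normalize transforms.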
        for image in images.iter_mut() {
            let transforms = Transforms {
                input: &ToTensorNoNorm,
                inner_transforms: &[
                    &do_rescale.then_some(Rescale {
                        factor: Some(rescale_factor),
                    }),
                    &do_normalize.then(|| Normalize {
                        mean: image_mean.to_vec(),
                        std: image_std.to_vec(),
                    }),
                ],
            };

            let image = image.apply(transforms, device)?;
            pixel_values.push(image.unsqueeze(0)?);
        }

        Ok(PreprocessedImages {
            pixel_values: Tensor::cat(&pixel_values, 0)?,
            pixel_attention_mask: None,
            image_sizes: None,
            num_img_tokens: None,
            aspect_ratio_ids: None,
            aspect_ratio_mask: None,
            num_tiles: None,
            image_grid_thw: None,
            video_grid_thw: None,
            rows: None,
            cols: None,
            pixel_values_list: None,
            tgt_sizes: None,
            image_sizes_all: Some(image_sizes),
            num_crops: None,
        })
    }
}