#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::{any::Any, num::NonZeroUsize, sync::Arc};

use candle_core::{Device, Result, Tensor};
use image::{imageops::FilterType, DynamicImage, GenericImageView};
use mistralrs_vision::{ApplyTransforms, Normalize, Rescale, ToTensorNoNorm, Transforms};
use tokenizers::Tokenizer;
use tracing::warn;

use crate::{
    device_map::DeviceMapper,
    pipeline::{
        text_models_inputs_processor::{
            self, get_completion_input, get_prompt_input, PagedAttentionMeta,
        },
        InputProcessorOutput, InputsProcessor, InputsProcessorType, MessagesAction, Processor,
    },
    sequence::Sequence,
    vision_models::{
        image_processor::{ImagePreProcessor, PreprocessedImages},
        preprocessor_config::{PreProcessorConfig, ToFilter},
        processor_config::ProcessorConfig,
        ModelInputs,
    },
};

use super::Mistral3SpecificArgs;

const PLACEHOLDER: &str = "<placeholder>";

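/// Runtime inputs processor for Mistral 3 vision: preprocesses the images of each sequence and
/// expands every image token in the prompt into its full grid of image/break/end tokens.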
struct Mistral3ImageProcessor {
    image_break_token: String,
    image_end_token: String,
    image_token: String,
    patch_size: usize,
    spatial_merge_size: usize,
}

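/// Processor built from the model's `ProcessorConfig`; it holds the image-token strings and
/// patch/merge sizes and constructs the [`Mistral3ImageProcessor`] used at inference time.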
pub struct Mistral3Processor {
    image_break_token: String,
    image_end_token: String,
    image_token: String,
    patch_size: usize,
    spatial_merge_size: usize,
}

impl Mistral3Processor {
    pub fn new(processor_config: ProcessorConfig) -> Self {
        Self {
            image_break_token: processor_config.image_break_token.unwrap().clone(),
            image_end_token: processor_config.image_end_token.unwrap().clone(),
            image_token: processor_config.image_token.unwrap().clone(),
            patch_size: processor_config.patch_size.unwrap(),
            spatial_merge_size: processor_config.spatial_merge_size.unwrap(),
        }
    }
}

impl Processor for Mistral3Processor {
    fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
        Arc::new(Mistral3ImageProcessor {
            image_break_token: self.image_break_token.clone(),
            image_end_token: self.image_end_token.clone(),
            image_token: self.image_token.clone(),
            patch_size: self.patch_size,
            spatial_merge_size: self.spatial_merge_size,
        })
    }

    fn get_special_tokens(&self) -> &[&'static str] {
        &[]
    }

    fn template_action(&self) -> MessagesAction {
        MessagesAction::Keep
    }
}

impl InputsProcessor for Mistral3ImageProcessor {
    fn get_type(&self) -> InputsProcessorType {
        InputsProcessorType::Vision
    }
    fn process_inputs(
        &self,
        tokenizer: Option<Arc<Tokenizer>>,
        input_seqs: &mut [&mut Sequence],
        is_prompt: bool,
        is_xlora: bool,
        device: &Device,
        no_kv_cache: bool,
        last_n_context_len: Option<(usize, usize)>,
        return_raw_logits: bool,
        other_config: Option<Arc<dyn Any>>,
        mut paged_attn_metadata: Option<PagedAttentionMeta<'_>>,
        prompt_chunksize: Option<NonZeroUsize>,
        mapper: Option<&dyn DeviceMapper>,
    ) -> Box<dyn Iterator<Item = anyhow::Result<InputProcessorOutput>>> {
        if is_xlora {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Cannot make inputs for X-LoRA vision model.",
            ))));
        }
        if no_kv_cache {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Vision model must have kv cache.",
            ))));
        }
        if prompt_chunksize.is_some() {
            warn!("`prompt_chunksize` is set. Mistral3 does not support prompt batching.");
        }
        let Some(tokenizer) = tokenizer else {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
115 "Idefics3ImageProcessor requires a specified tokenizer.",
            ))));
        };

        let config = other_config.expect("Need a PreProcessorConfig config.");
        let config: &PreProcessorConfig = config.downcast_ref().expect("Downcast failed.");

        let has_images = input_seqs.iter().all(|seq| seq.has_images());

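        // Preprocess the images for each sequence and rewrite its prompt so that every image
        // token is expanded into the per-image grid of image/break/end tokens.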
        let (pixel_values, image_sizes) = if has_images {
            let mut pixel_values_accum = Vec::new();
            let mut image_sizes_accum = Vec::new();

            for seq in input_seqs.iter_mut() {
                let PreprocessedImages {
                    pixel_values,
                    pixel_attention_mask: _,
                    image_sizes: _,
                    num_img_tokens: _,
                    aspect_ratio_ids: _,
                    aspect_ratio_mask: _,
                    num_tiles: _,
                    image_grid_thw: _,
                    video_grid_thw: _,
                    rows: _,
                    cols: _,
                    pixel_values_list: _,
                    tgt_sizes: _,
                    image_sizes_all,
                    num_crops: _,
                } = self
                    .preprocess(
                        seq.take_images()
                            .expect("Need to have images by this point."),
                        vec![],
                        config,
                        device,
                        (usize::MAX, usize::MAX),
                    )
                    .expect("Preprocessing failed");
                let image_sizes_all = image_sizes_all.unwrap();

                pixel_values_accum.push(pixel_values.clone());
                image_sizes_accum.extend_from_slice(&image_sizes_all);

                let mut prompt = tokenizer
                    .decode(seq.get_toks(), false)
                    .expect("Detokenization failed!");

                let mut image_sizes_all_iter = image_sizes_all.into_iter();
                let mut replace_strings = Vec::new();
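                // Each image expands to `num_height_tokens` rows of `num_width_tokens` image
                // tokens; rows are separated by the image break token and the final separator
                // is replaced by the image end token.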
                while prompt.contains(&self.image_token) {
                    let (height, width) = image_sizes_all_iter.next().unwrap();
                    let num_height_tokens =
                        (height as usize) / (self.patch_size * self.spatial_merge_size);
                    let num_width_tokens =
                        (width as usize) / (self.patch_size * self.spatial_merge_size);

                    let mut replace_tokens = vec![
                        [
                            vec![self.image_token.clone(); num_width_tokens],
                            vec![self.image_break_token.clone()],
                        ]
                        .concat();
                        num_height_tokens
                    ]
                    .into_iter()
                    .flatten()
                    .collect::<Vec<_>>();

                    *replace_tokens.last_mut().unwrap() = self.image_end_token.clone();

                    replace_strings.push(replace_tokens.join(""));
                    // Replace only the first occurrence so that each image token is paired with
                    // its own image size.
                    prompt = prompt.replacen(&self.image_token, PLACEHOLDER, 1);
                }

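                // Swap each placeholder back for its expanded grid, in prompt order.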
                while prompt.contains(PLACEHOLDER) {
                    let replace_str = replace_strings.remove(0);
                    prompt = prompt.replacen(PLACEHOLDER, &replace_str, 1);
                }

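                // Re-tokenize the expanded prompt once and swap the new token ids into the sequence.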
                if !seq.has_changed_prompt {
                    seq.set_initial_prompt(prompt.clone());
                    let toks = tokenizer
                        .encode_fast(prompt, false)
                        .expect("Tokenization failed!");

                    let ids = toks.get_ids().to_vec();
                    seq.set_toks_and_reallocate(ids, paged_attn_metadata.as_mut());
                    seq.has_changed_prompt = true;
                }
            }

            (
                Some(Tensor::cat(&pixel_values_accum, 0).unwrap()),
                Some(image_sizes_accum),
            )
        } else {
            (None, None)
        };

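        // Build the standard text-model inputs (token ids, positions, paged-attention metadata)
        // from the (possibly rewritten) sequences.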
        let text_models_inputs_processor::InnerInputProcessorOutput {
            inputs:
                text_models_inputs_processor::InputMetadata {
                    input,
                    positions,
                    context_lens,
                    position_ids,
                    paged_attn_meta,
                    flash_meta,
                },
            seq_indices,
        } = if is_prompt {
            get_prompt_input(
                input_seqs
                    .iter()
                    .map(|seq| seq.get_toks().to_vec())
                    .collect::<Vec<_>>(),
                input_seqs,
                device,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                None,
                mapper,
            )
            .nth(0)
            .unwrap()
            .unwrap()
        } else {
            get_completion_input(
                input_seqs
                    .iter()
                    .map(|seq| seq.get_toks().to_vec())
                    .collect::<Vec<_>>(),
                input_seqs,
                device,
                no_kv_cache,
                last_n_context_len,
                return_raw_logits,
                paged_attn_metadata.as_mut(),
                None,
                mapper,
            )
            .nth(0)
            .unwrap()
            .unwrap()
        };

        let inputs: Box<dyn Any> = Box::new(ModelInputs {
            input_ids: input,
            seqlen_offsets: positions,
            context_lens,
            position_ids,
            pixel_values,
            model_specific_args: Box::new(Mistral3SpecificArgs { image_sizes }),
            paged_attn_meta,
            flash_meta,
        });
        Box::new(std::iter::once(Ok(InputProcessorOutput {
            inputs,
            seq_indices,
        })))
    }
}

impl Mistral3ImageProcessor {
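    /// Scale the image down to fit within `max_height` x `max_width` (preserving aspect ratio),
    /// then round each side up to the next multiple of `patch_size`.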
    #[allow(clippy::too_many_arguments)]
    fn resize(
        &self,
        image: &DynamicImage,
        mut height: usize,
        mut width: usize,
        max_height: usize,
        max_width: usize,
        patch_size: usize,
        filter: FilterType,
    ) -> DynamicImage {
        let ratio = (height as f64 / max_height as f64).max(width as f64 / max_width as f64);
        if ratio > 1. {
            height = (height as f64 / ratio).floor() as usize;
            width = (width as f64 / ratio).floor() as usize;
        }

        let num_height_tokens = (height - 1) / patch_size + 1;
        let num_width_tokens = (width - 1) / patch_size + 1;

        let resize_height = num_height_tokens * patch_size;
        let resize_width = num_width_tokens * patch_size;

        image.resize_exact(resize_width as u32, resize_height as u32, filter)
    }
}

impl ImagePreProcessor for Mistral3ImageProcessor {
    #[allow(clippy::excessive_precision)]
    const DEFAULT_MEAN: [f64; 3] = [0.48145466, 0.4578275, 0.40821073];
    #[allow(clippy::excessive_precision)]
    const DEFAULT_STD: [f64; 3] = [0.26862954, 0.26130258, 0.27577711];

    fn preprocess(
        &self,
        mut images: Vec<DynamicImage>,
        videos: Vec<Vec<DynamicImage>>,
        config: &PreProcessorConfig,
        device: &Device,
        (_bs, _max_num_images): (usize, usize),
    ) -> Result<PreprocessedImages> {
        assert!(videos.is_empty());

        let do_resize = config.do_resize.unwrap();
        let do_rescale = config.do_rescale.unwrap();
        let rescale_factor = config.rescale_factor.unwrap();
        let do_normalize = config.do_normalize.unwrap();
        let image_mean = config.image_mean.unwrap_or(Self::DEFAULT_MEAN);
        let image_std = config.image_std.unwrap_or(Self::DEFAULT_STD);
        let do_convert_rgb = config.do_convert_rgb.unwrap_or(true);
        let patch_size = config.patch_size.unwrap();
        let size = config.size.as_ref().unwrap();
        let resample = config.resampling.to_filter()?;

        let default_to_square = config.default_to_square.unwrap();
        assert!(default_to_square);

        let mut pixel_values = Vec::new();
        let mut image_sizes = Vec::new();

        let (max_height, max_width) = if size.contains_key("longest_edge") {
            (size["longest_edge"] as usize, size["longest_edge"] as usize)
        } else if size.contains_key("height") && size.contains_key("width") {
            (size["height"] as usize, size["width"] as usize)
        } else {
            candle_core::bail!("Size must be a map of `longest_edge` or `height` and `width`.");
        };

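        // Resize each image to fit the configured maximum size and record its final
        // (height, width); these sizes drive the image-token expansion in `process_inputs`.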
        for image in images.iter_mut() {
            let (width, height) = image.dimensions();

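            // Convert to 8-bit RGB before any resizing, if requested.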
            if do_convert_rgb {
                *image = DynamicImage::ImageRgb8(image.to_rgb8());
            }

            if do_resize {
                *image = self.resize(
                    image,
                    height as usize,
                    width as usize,
                    max_height,
                    max_width,
                    patch_size,
                    resample,
                );
            }

            let (width, height) = image.dimensions();

            image_sizes.push((height, width));
        }

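        // Pad all images to a common size so the per-image tensors can be stacked below.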
        images = mistralrs_vision::pad_to_max_image_size(images);

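        // Convert each image to a tensor and apply the optional rescale and normalize transforms.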
        for image in images.iter_mut() {
            let transforms = Transforms {
                input: &ToTensorNoNorm,
                inner_transforms: &[
                    &do_rescale.then_some(Rescale {
                        factor: Some(rescale_factor),
                    }),
                    &do_normalize.then(|| Normalize {
                        mean: image_mean.to_vec(),
                        std: image_std.to_vec(),
                    }),
                ],
            };

            let image = image.apply(transforms, device)?;
            pixel_values.push(image.unsqueeze(0)?);
        }

        Ok(PreprocessedImages {
            pixel_values: Tensor::cat(&pixel_values, 0)?,
            pixel_attention_mask: None,
            image_sizes: None,
            num_img_tokens: None,
            aspect_ratio_ids: None,
            aspect_ratio_mask: None,
            num_tiles: None,
            image_grid_thw: None,
            video_grid_thw: None,
            rows: None,
            cols: None,
            pixel_values_list: None,
            tgt_sizes: None,
            image_sizes_all: Some(image_sizes),
            num_crops: None,
        })
    }
}