// mistralrs_core/vision_models/llava/utils.rs
#![allow(
2 clippy::cast_possible_truncation,
3 clippy::cast_precision_loss,
4 clippy::too_many_arguments
5)]
6use crate::vision_models::preprocessor_config::PreProcessorConfig;
7use candle_core::{DType, Device, Result, Tensor};
8use image::{
9 imageops::{overlay, FilterType},
10 DynamicImage, GenericImageView, Rgb, RgbImage,
11};
12use std::cmp::min;
13
14pub(crate) fn get_anyres_image_grid_shape(
15 image_size: (u32, u32),
16 grid_pinpoints: &[(u32, u32)],
17 patch_size: u32,
18) -> (u32, u32) {
19 let (width, height) = select_best_resolution(image_size, grid_pinpoints);
20 (width / patch_size, height / patch_size)
21}
22
23pub(crate) fn get_num_samples(
24 image_size: (u32, u32),
25 grid_pinpoints: &[(u32, u32)],
26 crop_size: (u32, u32),
27) -> u32 {
28 let (width, height) = select_best_resolution(image_size, grid_pinpoints);
29 width / crop_size.0 * height / crop_size.1 + 1
30}
31
/// Choose, from `possible_resolutions`, the resolution that preserves the
/// most of the original image when scaled (aspect-preserving) to fit inside
/// it, breaking ties by the least wasted (padding) area.
///
/// The effective resolution is capped at the original image's own area, as
/// in the reference implementation (HF transformers `select_best_resolution`):
/// upscaling cannot add information, so a candidate larger than the original
/// must not beat an exact fit. The previous cap against the *candidate* area
/// was a no-op (the downscaled box never exceeds the candidate box), which
/// made small images always pick the largest grid pinpoint.
///
/// Returns `(0, 0)` when `possible_resolutions` is empty.
pub(crate) fn select_best_resolution(
    original_size: (u32, u32),
    possible_resolutions: &[(u32, u32)],
) -> (u32, u32) {
    let (original_width, original_height) = original_size;
    let original_area = original_width * original_height;
    let original_width_f = original_width as f32;
    let original_height_f = original_height as f32;

    let mut best_fit = (0, 0);
    let mut max_effective_resolution = 0_u32;
    let mut min_wasted_resolution = u32::MAX;

    for &(width, height) in possible_resolutions {
        // Uniform scale that fits the original inside this candidate box.
        let scale = (width as f32 / original_width_f).min(height as f32 / original_height_f);
        let downscaled_width = (original_width_f * scale) as u32;
        let downscaled_height = (original_height_f * scale) as u32;

        // Cap at the original area: upscaled pixels carry no new information.
        let effective_resolution =
            std::cmp::min(downscaled_width * downscaled_height, original_area);
        let wasted_resolution = width * height - effective_resolution;

        // Strictly more effective pixels wins; on a tie, less padding wins.
        if effective_resolution > max_effective_resolution
            || (effective_resolution == max_effective_resolution
                && wasted_resolution < min_wasted_resolution)
        {
            best_fit = (width, height);
            max_effective_resolution = effective_resolution;
            min_wasted_resolution = wasted_resolution;
        }
    }
    best_fit
}
64
/// Size of the actual image content after stripping the symmetric padding
/// that an aspect-preserving resize-and-pad added around `original_size`
/// content inside a `size` canvas.
pub(crate) fn calculate_unpad(size: (u32, u32), original_size: (u32, u32)) -> (u32, u32) {
    let (orig_w, orig_h) = original_size;
    let (cur_w, cur_h) = size;

    let orig_aspect = orig_w as f32 / orig_h as f32;
    let cur_aspect = cur_w as f32 / cur_h as f32;

    if orig_aspect > cur_aspect {
        // Width-bound fit: padding was added above and below the content.
        let scale = cur_w as f32 / orig_w as f32;
        let content_h = (orig_h as f32 * scale).floor() as u32;
        let pad = (cur_h - content_h) / 2;
        (cur_w, cur_h - 2 * pad)
    } else {
        // Height-bound fit: padding was added left and right of the content.
        let scale = cur_h as f32 / orig_h as f32;
        let content_w = (orig_w as f32 * scale).floor() as u32;
        let pad = (cur_w - content_w) / 2;
        (cur_w - 2 * pad, cur_h)
    }
}
82
83pub(crate) fn resize_and_pad_image(
84 image: &DynamicImage,
85 target_resolution: (u32, u32),
86) -> DynamicImage {
87 let (original_width, original_height) = image.dimensions();
88 let original_width_f = original_width as f32;
89 let original_height_f = original_height as f32;
90 let (target_width, target_height) = target_resolution;
91 let target_width_f = target_width as f32;
92 let target_height_f = target_height as f32;
93 let scale_w = target_width_f / original_width_f;
94 let scale_h = target_height_f / original_height_f;
95 let (new_width, new_height) = if scale_w < scale_h {
96 (
97 target_width,
98 min((original_height_f * scale_w).ceil() as u32, target_height),
99 )
100 } else {
101 (
102 min((original_width_f * scale_h).ceil() as u32, target_width),
103 target_height,
104 )
105 };
106 let resized_image = image.resize_exact(
107 new_width,
108 new_height,
109 image::imageops::FilterType::CatmullRom,
110 );
111 let mut new_image = DynamicImage::new_rgb8(target_width, target_height);
112 let (paste_x, paste_y) =
113 calculate_middle((target_width, target_height), (new_width, new_height));
114 overlay(
115 &mut new_image,
116 &resized_image,
117 paste_x.into(),
118 paste_y.into(),
119 );
120 new_image
121}
122
123pub(crate) fn divide_to_samples(image: &DynamicImage, crop_size: (u32, u32)) -> Vec<DynamicImage> {
124 let (width, height) = image.dimensions();
125 let mut samples = Vec::new();
126 for y in (0..height).step_by(crop_size.1 as usize) {
127 for x in (0..width).step_by(crop_size.0 as usize) {
128 let patch = image.crop_imm(x, y, crop_size.0, crop_size.1);
129 samples.push(patch);
130 }
131 }
132 samples
133}
134
/// Top-left coordinate at which a `center_size` box should be placed so it
/// sits centred inside `image_size`; clamps to 0 when the box does not fit.
pub(crate) fn calculate_middle(image_size: (u32, u32), center_size: (u32, u32)) -> (u32, u32) {
    // Half the slack on one axis, rounded up; zero when there is no slack.
    let half_gap = |outer: u32, inner: u32| -> u32 {
        if outer <= inner {
            0
        } else {
            ((outer as f32 - inner as f32) / 2.0).ceil() as u32
        }
    };
    (
        half_gap(image_size.0, center_size.0),
        half_gap(image_size.1, center_size.1),
    )
}
150
151pub(crate) fn expand2square(image: &DynamicImage, background_color: Rgb<u8>) -> DynamicImage {
152 let (width, height) = image.dimensions();
153 match width.cmp(&height) {
154 std::cmp::Ordering::Less => {
155 let mut new_image =
156 DynamicImage::from(RgbImage::from_pixel(height, height, background_color));
157 overlay(&mut new_image, image, ((height - width) / 2) as i64, 0);
158 new_image
159 }
160 std::cmp::Ordering::Equal => image.clone(),
161 std::cmp::Ordering::Greater => {
162 let mut new_image =
163 DynamicImage::from(RgbImage::from_pixel(width, width, background_color));
164 overlay(&mut new_image, image, 0, ((width - height) / 2) as i64);
165 new_image
166 }
167 }
168}
169
170pub struct LLaVAImageProcessor;
171
172impl LLaVAImageProcessor {
173 fn resize(image: &DynamicImage, size: u32, filter: FilterType) -> DynamicImage {
174 let (width, height) = image.dimensions();
175 if width == size && height == size {
176 image.clone()
177 } else {
178 let (new_width, new_height) = if width < height {
179 (
180 size,
181 (((size * height) as f32) / width as f32).ceil() as u32,
182 )
183 } else {
184 (
185 (((size * width) as f32) / height as f32).ceil() as u32,
186 size,
187 )
188 };
189 image.resize(new_width, new_height, filter)
190 }
191 }
192
193 fn center_crop(image: &DynamicImage, crop_size: (u32, u32)) -> DynamicImage {
194 let (width, height) = image.dimensions();
195 let (left, top) = calculate_middle((width, height), crop_size);
196 image.crop_imm(left, top, crop_size.0, crop_size.1)
197 }
198
199 fn rescale(tensor: &Tensor, rescale_factor: f64) -> Result<Tensor> {
200 tensor.affine(rescale_factor, 0.0)
201 }
202
203 fn to_tensor(image: &DynamicImage, device: &Device) -> Result<Tensor> {
204 let img = image.to_rgb8().into_raw();
205 let (width, height) = image.dimensions();
206 Tensor::from_vec(img, (height as usize, width as usize, 3), device)?.to_dtype(DType::F32)
207 }
208
209 fn normalize(tensor: &Tensor, image_mean: &[f32], image_std: &[f32]) -> Result<Tensor> {
210 let mean = Tensor::from_slice(image_mean, (3,), &Device::Cpu)?;
211 let std = Tensor::from_slice(image_std, (3,), &Device::Cpu)?;
212 tensor.broadcast_sub(&mean)?.broadcast_div(&std)
213 }
214
215 fn to_channel_dimension_format(tensor: &Tensor) -> Result<Tensor> {
216 tensor.permute((2, 0, 1))
217 }
218 pub fn process_one_image(
219 image: &DynamicImage,
220 preprocessor_config: &PreProcessorConfig,
221 resize_size: u32,
222 filter: FilterType,
223 dtype: DType,
224 device: &Device,
225 image_mean: &[f32],
226 image_std: &[f32],
227 ) -> Result<Tensor> {
228 let mut image = if preprocessor_config.do_resize.unwrap_or(true) {
229 Self::resize(image, resize_size, filter)
230 } else {
231 image.clone()
232 };
233 image = if preprocessor_config.do_center_crop.unwrap_or(true) {
234 let crop_width = *preprocessor_config
235 .crop_size
236 .as_ref()
237 .unwrap()
238 .get("width")
239 .unwrap();
240 let crop_height = *preprocessor_config
241 .crop_size
242 .as_ref()
243 .unwrap()
244 .get("height")
245 .unwrap();
246 Self::center_crop(&image, (crop_width, crop_height))
247 } else {
248 image
249 };
250 let mut pixel_value = Self::to_tensor(&image, &Device::Cpu)?;
251 if preprocessor_config.do_rescale.unwrap_or(true) {
252 let rescale_factor = preprocessor_config.rescale_factor.unwrap();
253 pixel_value = Self::rescale(&pixel_value, rescale_factor)?;
254 }
255 if preprocessor_config.do_normalize.unwrap_or(true) {
256 pixel_value = Self::normalize(&pixel_value, image_mean, image_std)?;
257 }
258 Self::to_channel_dimension_format(&pixel_value)?
259 .to_dtype(dtype)?
260 .to_device(device)
261 }
262}