mistralrs_core/vision_models/llava/utils.rs

1#![allow(
2    clippy::cast_possible_truncation,
3    clippy::cast_precision_loss,
4    clippy::too_many_arguments
5)]
6use crate::vision_models::preprocessor_config::PreProcessorConfig;
7use candle_core::{DType, Device, Result, Tensor};
8use image::{
9    imageops::{overlay, FilterType},
10    DynamicImage, GenericImageView, Rgb, RgbImage,
11};
12use std::cmp::min;
13
14pub(crate) fn get_anyres_image_grid_shape(
15    image_size: (u32, u32),
16    grid_pinpoints: &[(u32, u32)],
17    patch_size: u32,
18) -> (u32, u32) {
19    let (width, height) = select_best_resolution(image_size, grid_pinpoints);
20    (width / patch_size, height / patch_size)
21}
22
23pub(crate) fn get_num_samples(
24    image_size: (u32, u32),
25    grid_pinpoints: &[(u32, u32)],
26    crop_size: (u32, u32),
27) -> u32 {
28    let (width, height) = select_best_resolution(image_size, grid_pinpoints);
29    width / crop_size.0 * height / crop_size.1 + 1
30}
31
/// Chooses the resolution from `possible_resolutions` that best fits
/// `original_size`: maximize the effective (aspect-preserving downscaled)
/// resolution, breaking ties by minimizing wasted candidate area.
/// Returns `(0, 0)` when `possible_resolutions` is empty.
pub(crate) fn select_best_resolution(
    original_size: (u32, u32),
    possible_resolutions: &[(u32, u32)],
) -> (u32, u32) {
    let (orig_w, orig_h) = original_size;
    let (orig_w_f, orig_h_f) = (orig_w as f32, orig_h as f32);

    let mut best = (0_u32, 0_u32);
    let mut best_effective = 0_u32;
    let mut best_wasted = u32::MAX;

    for &(cand_w, cand_h) in possible_resolutions {
        // Scale that fits the original inside the candidate, preserving aspect ratio.
        let scale = f32::min(cand_w as f32 / orig_w_f, cand_h as f32 / orig_h_f);
        let scaled_w = (orig_w_f * scale) as u32;
        let scaled_h = (orig_h_f * scale) as u32;

        // Usable pixels after downscaling, capped by the candidate's own area.
        let effective = min(cand_w * cand_h, scaled_w * scaled_h);
        let wasted = cand_w * cand_h - effective;

        let improves = effective > best_effective
            || (effective == best_effective && wasted < best_wasted);
        if improves {
            best = (cand_w, cand_h);
            best_effective = effective;
            best_wasted = wasted;
        }
    }
    best
}
64
/// Given a letterboxed `size` and the pre-padding `original_size`, returns
/// the dimensions of the content once the symmetric padding is stripped
/// (matching what `unpad_image` removes).
pub(crate) fn calculate_unpad(size: (u32, u32), original_size: (u32, u32)) -> (u32, u32) {
    let (orig_w, orig_h) = original_size;
    let (cur_w, cur_h) = size;

    let orig_aspect = orig_w as f32 / orig_h as f32;
    let cur_aspect = cur_w as f32 / cur_h as f32;

    if orig_aspect > cur_aspect {
        // Original is relatively wider: padding was added above and below.
        let scale = cur_w as f32 / orig_w as f32;
        let scaled_h = (orig_h as f32 * scale).floor() as u32;
        let pad = (cur_h - scaled_h) / 2;
        (cur_w, cur_h - 2 * pad)
    } else {
        // Original is relatively taller (or equal): padding was on the sides.
        let scale = cur_h as f32 / orig_h as f32;
        let scaled_w = (orig_w as f32 * scale).floor() as u32;
        let pad = (cur_w - scaled_w) / 2;
        (cur_w - 2 * pad, cur_h)
    }
}
82
83pub(crate) fn resize_and_pad_image(
84    image: &DynamicImage,
85    target_resolution: (u32, u32),
86) -> DynamicImage {
87    let (original_width, original_height) = image.dimensions();
88    let original_width_f = original_width as f32;
89    let original_height_f = original_height as f32;
90    let (target_width, target_height) = target_resolution;
91    let target_width_f = target_width as f32;
92    let target_height_f = target_height as f32;
93    let scale_w = target_width_f / original_width_f;
94    let scale_h = target_height_f / original_height_f;
95    let (new_width, new_height) = if scale_w < scale_h {
96        (
97            target_width,
98            min((original_height_f * scale_w).ceil() as u32, target_height),
99        )
100    } else {
101        (
102            min((original_width_f * scale_h).ceil() as u32, target_width),
103            target_height,
104        )
105    };
106    let resized_image = image.resize_exact(
107        new_width,
108        new_height,
109        image::imageops::FilterType::CatmullRom,
110    );
111    let mut new_image = DynamicImage::new_rgb8(target_width, target_height);
112    let (paste_x, paste_y) =
113        calculate_middle((target_width, target_height), (new_width, new_height));
114    overlay(
115        &mut new_image,
116        &resized_image,
117        paste_x.into(),
118        paste_y.into(),
119    );
120    new_image
121}
122
123pub(crate) fn divide_to_samples(image: &DynamicImage, crop_size: (u32, u32)) -> Vec<DynamicImage> {
124    let (width, height) = image.dimensions();
125    let mut samples = Vec::new();
126    for y in (0..height).step_by(crop_size.1 as usize) {
127        for x in (0..width).step_by(crop_size.0 as usize) {
128            let patch = image.crop_imm(x, y, crop_size.0, crop_size.1);
129            samples.push(patch);
130        }
131    }
132    samples
133}
134
/// Top-left offset `(left, top)` that centers a `center_size` region inside
/// `image_size`. An axis where the region is at least as large as the image
/// gets offset 0; otherwise the half-difference is rounded up.
pub(crate) fn calculate_middle(image_size: (u32, u32), center_size: (u32, u32)) -> (u32, u32) {
    // Centering offset along a single axis.
    let offset = |outer: u32, inner: u32| -> u32 {
        if outer <= inner {
            0
        } else {
            ((outer as f32 - inner as f32) / 2.0).ceil() as u32
        }
    };
    (
        offset(image_size.0, center_size.0),
        offset(image_size.1, center_size.1),
    )
}
150
151pub(crate) fn expand2square(image: &DynamicImage, background_color: Rgb<u8>) -> DynamicImage {
152    let (width, height) = image.dimensions();
153    match width.cmp(&height) {
154        std::cmp::Ordering::Less => {
155            let mut new_image =
156                DynamicImage::from(RgbImage::from_pixel(height, height, background_color));
157            overlay(&mut new_image, image, ((height - width) / 2) as i64, 0);
158            new_image
159        }
160        std::cmp::Ordering::Equal => image.clone(),
161        std::cmp::Ordering::Greater => {
162            let mut new_image =
163                DynamicImage::from(RgbImage::from_pixel(width, width, background_color));
164            overlay(&mut new_image, image, 0, ((width - height) / 2) as i64);
165            new_image
166        }
167    }
168}
169
170pub struct LLaVAImageProcessor;
171
172impl LLaVAImageProcessor {
173    fn resize(image: &DynamicImage, size: u32, filter: FilterType) -> DynamicImage {
174        let (width, height) = image.dimensions();
175        if width == size && height == size {
176            image.clone()
177        } else {
178            let (new_width, new_height) = if width < height {
179                (
180                    size,
181                    (((size * height) as f32) / width as f32).ceil() as u32,
182                )
183            } else {
184                (
185                    (((size * width) as f32) / height as f32).ceil() as u32,
186                    size,
187                )
188            };
189            image.resize(new_width, new_height, filter)
190        }
191    }
192
193    fn center_crop(image: &DynamicImage, crop_size: (u32, u32)) -> DynamicImage {
194        let (width, height) = image.dimensions();
195        let (left, top) = calculate_middle((width, height), crop_size);
196        image.crop_imm(left, top, crop_size.0, crop_size.1)
197    }
198
199    fn rescale(tensor: &Tensor, rescale_factor: f64) -> Result<Tensor> {
200        tensor.affine(rescale_factor, 0.0)
201    }
202
203    fn to_tensor(image: &DynamicImage, device: &Device) -> Result<Tensor> {
204        let img = image.to_rgb8().into_raw();
205        let (width, height) = image.dimensions();
206        Tensor::from_vec(img, (height as usize, width as usize, 3), device)?.to_dtype(DType::F32)
207    }
208
209    fn normalize(tensor: &Tensor, image_mean: &[f32], image_std: &[f32]) -> Result<Tensor> {
210        let mean = Tensor::from_slice(image_mean, (3,), &Device::Cpu)?;
211        let std = Tensor::from_slice(image_std, (3,), &Device::Cpu)?;
212        tensor.broadcast_sub(&mean)?.broadcast_div(&std)
213    }
214
215    fn to_channel_dimension_format(tensor: &Tensor) -> Result<Tensor> {
216        tensor.permute((2, 0, 1))
217    }
218    pub fn process_one_image(
219        image: &DynamicImage,
220        preprocessor_config: &PreProcessorConfig,
221        resize_size: u32,
222        filter: FilterType,
223        dtype: DType,
224        device: &Device,
225        image_mean: &[f32],
226        image_std: &[f32],
227    ) -> Result<Tensor> {
228        let mut image = if preprocessor_config.do_resize.unwrap_or(true) {
229            Self::resize(image, resize_size, filter)
230        } else {
231            image.clone()
232        };
233        image = if preprocessor_config.do_center_crop.unwrap_or(true) {
234            let crop_width = *preprocessor_config
235                .crop_size
236                .as_ref()
237                .unwrap()
238                .get("width")
239                .unwrap();
240            let crop_height = *preprocessor_config
241                .crop_size
242                .as_ref()
243                .unwrap()
244                .get("height")
245                .unwrap();
246            Self::center_crop(&image, (crop_width, crop_height))
247        } else {
248            image
249        };
250        let mut pixel_value = Self::to_tensor(&image, &Device::Cpu)?;
251        if preprocessor_config.do_rescale.unwrap_or(true) {
252            let rescale_factor = preprocessor_config.rescale_factor.unwrap();
253            pixel_value = Self::rescale(&pixel_value, rescale_factor)?;
254        }
255        if preprocessor_config.do_normalize.unwrap_or(true) {
256            pixel_value = Self::normalize(&pixel_value, image_mean, image_std)?;
257        }
258        Self::to_channel_dimension_format(&pixel_value)?
259            .to_dtype(dtype)?
260            .to_device(device)
261    }
262}