mistralrs_vision/
transforms.rs

1use crate::utils::image_to_pixels;
2use candle_core::{Device, Result, Tensor, D};
3use image::DynamicImage;
4
5use crate::ImageTransform;
6
7/// Convert an image to a tensor. This converts the data from being in `[0, 255]` to `[0.0, 1.0]`.
8/// The tensor's shape is (channels, height, width).
9pub struct ToTensor;
10
11impl ImageTransform for ToTensor {
12    type Input = DynamicImage;
13    type Output = Tensor;
14    fn map(&self, x: &Self::Input, device: &Device) -> Result<Self::Output> {
15        image_to_pixels(x, device)? / 255.
16    }
17}
18
19/// Convert an image to a tensor without normalizing to `[0.0, 1.0]`.
20/// The tensor's shape is (channels, height, width).
21pub struct ToTensorNoNorm;
22
23impl ImageTransform for ToTensorNoNorm {
24    type Input = DynamicImage;
25    type Output = Tensor;
26    fn map(&self, x: &Self::Input, device: &Device) -> Result<Self::Output> {
27        image_to_pixels(x, device)
28    }
29}
30
31/// Normalize the image data based on the mean and standard deviation.
32/// The value is computed as follows:
33/// `
34/// x[channel] = (x[channel] - mean[channel]) / std[channel]
35/// `
36///
37/// Expects an input tensor of shape (channels, height, width).
38pub struct Normalize {
39    pub mean: Vec<f64>,
40    pub std: Vec<f64>,
41}
42
43impl ImageTransform for Normalize {
44    type Input = Tensor;
45    type Output = Self::Input;
46
47    fn map(&self, x: &Self::Input, _: &Device) -> Result<Self::Output> {
48        let num_channels = x.dim(D::Minus(3))?;
49        if self.mean.len() != num_channels || self.std.len() != num_channels {
50            candle_core::bail!(
51                "Num channels ({}) must match number of mean ({}) and std ({}).",
52                num_channels,
53                self.mean.len(),
54                self.std.len()
55            );
56        }
57        let dtype = x.dtype();
58        let device = x.device();
59        let mean = Tensor::from_slice(
60            &self.mean.iter().map(|x| *x as f32).collect::<Vec<_>>(),
61            (num_channels,),
62            device,
63        )?
64        .to_dtype(dtype)?;
65        let std = Tensor::from_slice(
66            &self.std.iter().map(|x| *x as f32).collect::<Vec<_>>(),
67            (num_channels,),
68            device,
69        )?
70        .to_dtype(dtype)?;
71        let mean = mean.reshape((num_channels, 1, 1))?;
72        let std = std.reshape((num_channels, 1, 1))?;
73        x.broadcast_sub(&mean)?.broadcast_div(&std)
74    }
75}
76
77/// Resize the image via nearest interpolation.
78pub struct InterpolateResize {
79    pub target_w: usize,
80    pub target_h: usize,
81}
82
83impl ImageTransform for InterpolateResize {
84    type Input = Tensor;
85    type Output = Self::Input;
86
87    fn map(&self, x: &Self::Input, _: &Device) -> Result<Self::Output> {
88        x.unsqueeze(0)?
89            .interpolate2d(self.target_h, self.target_w)?
90            .squeeze(0)
91    }
92}
93
94impl<T: ImageTransform<Input = E, Output = E>, E: Clone> ImageTransform for Option<T> {
95    type Input = T::Input;
96    type Output = T::Output;
97
98    fn map(&self, x: &T::Input, dev: &Device) -> Result<T::Output> {
99        if let Some(this) = self {
100            this.map(x, dev)
101        } else {
102            Ok(x.clone())
103        }
104    }
105}
106
107/// Multiply the pixe values by the provided factor.
108///
109/// Each pixel value is calculated as follows: x = x * factor
110pub struct Rescale {
111    pub factor: Option<f64>,
112}
113
114impl ImageTransform for Rescale {
115    type Input = Tensor;
116    type Output = Self::Input;
117
118    fn map(&self, x: &Self::Input, _: &Device) -> Result<Self::Output> {
119        if let Some(factor) = self.factor {
120            x * factor
121        } else {
122            Ok(x.clone())
123        }
124    }
125}
126
// Unit tests for the transforms; gated so the module is only compiled for test builds.
#[cfg(test)]
mod tests {
    /// `ToTensor` must yield a (channels, height, width) tensor for a width-4, height-5 RGB image.
    #[test]
    fn test_to_tensor() {
        use candle_core::Device;
        use image::{ColorType, DynamicImage};

        use crate::ImageTransform;

        use super::ToTensor;

        let image = DynamicImage::new(4, 5, ColorType::Rgb8);
        let res = ToTensor.map(&image, &Device::Cpu).unwrap();
        assert_eq!(res.dims(), &[3, 5, 4]);
    }

    /// `Normalize` must preserve the shape of its input.
    #[test]
    fn test_normalize() {
        use crate::{ImageTransform, Normalize};
        use candle_core::{Device, Tensor};

        // `Tensor::randn` takes (mean, std, shape, device): draw from a standard normal
        // (mean 0, std 1) so the input actually varies instead of being constant.
        let image = Tensor::randn(0f32, 1f32, (3, 5, 4), &Device::Cpu).unwrap();
        let res = Normalize {
            mean: vec![0.5, 0.5, 0.5],
            std: vec![0.5, 0.5, 0.5],
        }
        .map(&image, &Device::Cpu)
        .unwrap();
        assert_eq!(res.dims(), &[3, 5, 4]);
    }
}
154        assert_eq!(res.dims(), &[3, 5, 4])
155    }
156}