mistralrs_vision/
transforms.rs

1use crate::utils::{get_pixel_data, n_channels};
2use candle_core::{DType, Device, Result, Tensor};
3use image::{DynamicImage, GenericImageView};
4
5use crate::ImageTransform;
6
7/// Convert an image to a tensor. This converts the data from being in `[0, 255]` to `[0.0, 1.0]`.
8/// The tensor's shape is (channels, height, width).
9pub struct ToTensor;
10
11impl ToTensor {
12    fn to_tensor(device: &Device, channels: usize, data: Vec<Vec<Vec<u8>>>) -> Result<Tensor> {
13        ToTensorNoNorm::to_tensor(device, channels, data)? / 255.0f64
14    }
15}
16
17impl ImageTransform for ToTensor {
18    type Input = DynamicImage;
19    type Output = Tensor;
20    fn map(&self, x: &Self::Input, device: &Device) -> Result<Self::Output> {
21        let num_channels = n_channels(x);
22        let data = get_pixel_data(
23            num_channels,
24            x.to_rgba8(),
25            x.dimensions().1 as usize,
26            x.dimensions().0 as usize,
27        );
28        Self::to_tensor(device, num_channels, data)
29    }
30}
31
32/// Convert an image to a tensor without normalizing to `[0.0, 1.0]`.
33/// The tensor's shape is (channels, height, width).
34pub struct ToTensorNoNorm;
35
36impl ToTensorNoNorm {
37    fn to_tensor(device: &Device, channels: usize, data: Vec<Vec<Vec<u8>>>) -> Result<Tensor> {
38        let mut accum = Vec::new();
39        for row in data {
40            let mut row_accum = Vec::new();
41            for item in row {
42                row_accum.push(
43                    Tensor::from_slice(&item[..channels], (1, channels), &Device::Cpu)?
44                        .to_dtype(DType::F32)?,
45                )
46            }
47            let row = Tensor::cat(&row_accum, 0)?;
48            accum.push(row.t()?.unsqueeze(1)?);
49        }
50        Tensor::cat(&accum, 1)?.to_device(device)
51    }
52}
53
54impl ImageTransform for ToTensorNoNorm {
55    type Input = DynamicImage;
56    type Output = Tensor;
57    fn map(&self, x: &Self::Input, device: &Device) -> Result<Self::Output> {
58        let num_channels = n_channels(x);
59        let data = get_pixel_data(
60            num_channels,
61            x.to_rgba8(),
62            x.dimensions().1 as usize,
63            x.dimensions().0 as usize,
64        );
65        Self::to_tensor(device, num_channels, data)
66    }
67}
68
69/// Normalize the image data based on the mean and standard deviation.
70/// The value is computed as follows:
71/// `
72/// x[channel]=(x[channel - mean[channel]) / std[channel]
73/// `
74///
75/// Expects an input tensor of shape (channels, height, width).
76pub struct Normalize {
77    pub mean: Vec<f64>,
78    pub std: Vec<f64>,
79}
80
81impl ImageTransform for Normalize {
82    type Input = Tensor;
83    type Output = Self::Input;
84
85    fn map(&self, x: &Self::Input, _: &Device) -> Result<Self::Output> {
86        let num_channels = x.dim(0)?;
87        if self.mean.len() != num_channels || self.std.len() != num_channels {
88            candle_core::bail!(
89                "Num channels ({}) must match number of mean ({}) and std ({}).",
90                num_channels,
91                self.mean.len(),
92                self.std.len()
93            );
94        }
95        let mut accum = Vec::new();
96        for (i, channel) in x.chunk(num_channels, 0)?.iter().enumerate() {
97            accum.push(((channel - self.mean[i])? / self.std[i])?);
98        }
99        Tensor::cat(&accum, 0)
100    }
101}
102
103/// Resize the image via nearest interpolation.
104pub struct InterpolateResize {
105    pub target_w: usize,
106    pub target_h: usize,
107}
108
109impl ImageTransform for InterpolateResize {
110    type Input = Tensor;
111    type Output = Self::Input;
112
113    fn map(&self, x: &Self::Input, _: &Device) -> Result<Self::Output> {
114        x.unsqueeze(0)?
115            .interpolate2d(self.target_h, self.target_w)?
116            .squeeze(0)
117    }
118}
119
120impl<T: ImageTransform<Input = E, Output = E>, E: Clone> ImageTransform for Option<T> {
121    type Input = T::Input;
122    type Output = T::Output;
123
124    fn map(&self, x: &T::Input, dev: &Device) -> Result<T::Output> {
125        if let Some(this) = self {
126            this.map(x, dev)
127        } else {
128            Ok(x.clone())
129        }
130    }
131}
132
133/// Multiply the pixe values by the provided factor.
134///
135/// Each pixel value is calculated as follows: x = x * factor
136pub struct Rescale {
137    pub factor: Option<f64>,
138}
139
140impl ImageTransform for Rescale {
141    type Input = Tensor;
142    type Output = Self::Input;
143
144    fn map(&self, x: &Self::Input, _: &Device) -> Result<Self::Output> {
145        if let Some(factor) = self.factor {
146            x * factor
147        } else {
148            Ok(x.clone())
149        }
150    }
151}
152
// Unit tests for the transforms; compiled only for test builds.
#[cfg(test)]
mod tests {
    use super::{Normalize, ToTensor};
    use crate::ImageTransform;
    use candle_core::{DType, Device, Tensor};
    use image::{ColorType, DynamicImage};

    /// A 4x5 (width x height) RGB image must become a (3, 5, 4) tensor.
    #[test]
    fn test_to_tensor() {
        let image = DynamicImage::new(4, 5, ColorType::Rgb8);
        let res = ToTensor.map(&image, &Device::Cpu).unwrap();
        assert_eq!(res.dims(), &[3, 5, 4]);
    }

    /// Normalizing keeps the shape, and an all-zero input maps to (0 - 0.5) / 0.5 = -1.
    #[test]
    fn test_normalize() {
        // F32 is the dtype produced by the image-to-tensor transforms in this file.
        let image = Tensor::zeros((3, 5, 4), DType::F32, &Device::Cpu).unwrap();
        let res = Normalize {
            mean: vec![0.5, 0.5, 0.5],
            std: vec![0.5, 0.5, 0.5],
        }
        .map(&image, &Device::Cpu)
        .unwrap();
        assert_eq!(res.dims(), &[3, 5, 4]);

        let values = res.flatten_all().unwrap().to_vec1::<f32>().unwrap();
        assert!(values.iter().all(|&v| (v + 1.0).abs() < 1e-6));
    }
}