mistralrs_vision/
transforms.rs

1use crate::utils::image_to_pixels;
2use candle_core::{Device, Result, Tensor, D};
3use image::DynamicImage;
4
5use crate::ImageTransform;
6
7/// Convert an image to a tensor. This converts the data from being in `[0, 255]` to `[0.0, 1.0]`.
8/// The tensor's shape is (channels, height, width).
9pub struct ToTensor;
10
11impl ImageTransform for ToTensor {
12    type Input = DynamicImage;
13    type Output = Tensor;
14    fn map(&self, x: &Self::Input, device: &Device) -> Result<Self::Output> {
15        image_to_pixels(x, device)? / 255.
16    }
17}
18
19/// Convert an image to a tensor without normalizing to `[0.0, 1.0]`.
20/// The tensor's shape is (channels, height, width).
21pub struct ToTensorNoNorm;
22
23impl ImageTransform for ToTensorNoNorm {
24    type Input = DynamicImage;
25    type Output = Tensor;
26    fn map(&self, x: &Self::Input, device: &Device) -> Result<Self::Output> {
27        image_to_pixels(x, device)
28    }
29}
30
31/// Normalize the image data based on the mean and standard deviation.
32/// The value is computed as follows:
33/// `
34/// x[channel] = (x[channel] - mean[channel]) / std[channel]
35/// `
36///
37/// Expects an input tensor of shape (channels, height, width).
38pub struct Normalize {
39    pub mean: Vec<f64>,
40    pub std: Vec<f64>,
41}
42
43impl ImageTransform for Normalize {
44    type Input = Tensor;
45    type Output = Self::Input;
46
47    fn map(&self, x: &Self::Input, _: &Device) -> Result<Self::Output> {
48        let num_channels = x.dim(D::Minus(3))?;
49        if self.mean.len() != num_channels || self.std.len() != num_channels {
50            candle_core::bail!(
51                "Num channels ({}) must match number of mean ({}) and std ({}).",
52                num_channels,
53                self.mean.len(),
54                self.std.len()
55            );
56        }
57        let mut accum = Vec::new();
58        for (i, channel) in x.chunk(num_channels, D::Minus(3))?.iter().enumerate() {
59            accum.push(((channel - self.mean[i])? / self.std[i])?);
60        }
61        Tensor::cat(&accum, D::Minus(3))
62    }
63}
64
65/// Resize the image via nearest interpolation.
66pub struct InterpolateResize {
67    pub target_w: usize,
68    pub target_h: usize,
69}
70
71impl ImageTransform for InterpolateResize {
72    type Input = Tensor;
73    type Output = Self::Input;
74
75    fn map(&self, x: &Self::Input, _: &Device) -> Result<Self::Output> {
76        x.unsqueeze(0)?
77            .interpolate2d(self.target_h, self.target_w)?
78            .squeeze(0)
79    }
80}
81
82impl<T: ImageTransform<Input = E, Output = E>, E: Clone> ImageTransform for Option<T> {
83    type Input = T::Input;
84    type Output = T::Output;
85
86    fn map(&self, x: &T::Input, dev: &Device) -> Result<T::Output> {
87        if let Some(this) = self {
88            this.map(x, dev)
89        } else {
90            Ok(x.clone())
91        }
92    }
93}
94
95/// Multiply the pixe values by the provided factor.
96///
97/// Each pixel value is calculated as follows: x = x * factor
98pub struct Rescale {
99    pub factor: Option<f64>,
100}
101
102impl ImageTransform for Rescale {
103    type Input = Tensor;
104    type Output = Self::Input;
105
106    fn map(&self, x: &Self::Input, _: &Device) -> Result<Self::Output> {
107        if let Some(factor) = self.factor {
108            x * factor
109        } else {
110            Ok(x.clone())
111        }
112    }
113}
114
// Unit tests for the transforms above. Gated on `cfg(test)` so the module (and its
// imports) is not compiled into regular builds.
#[cfg(test)]
mod tests {
    use candle_core::{DType, Device, Tensor};
    use image::{ColorType, DynamicImage};

    use super::{Normalize, ToTensor, ToTensorNoNorm};
    use crate::ImageTransform;

    #[test]
    fn test_to_tensor() {
        // `DynamicImage::new(width, height, ..)` -> tensor of (channels, height, width).
        let image = DynamicImage::new(4, 5, ColorType::Rgb8);
        let res = ToTensor.map(&image, &Device::Cpu).unwrap();
        assert_eq!(res.dims(), &[3, 5, 4]);
    }

    #[test]
    fn test_to_tensor_no_norm() {
        let image = DynamicImage::new(4, 5, ColorType::Rgb8);
        let res = ToTensorNoNorm.map(&image, &Device::Cpu).unwrap();
        assert_eq!(res.dims(), &[3, 5, 4]);
    }

    #[test]
    fn test_normalize() {
        // Use a float tensor: with an integer dtype the fractional mean/std cannot be
        // represented, so the test would not exercise real normalization.
        let image = Tensor::zeros((3, 5, 4), DType::F32, &Device::Cpu).unwrap();
        let res = Normalize {
            mean: vec![0.5, 0.5, 0.5],
            std: vec![0.5, 0.5, 0.5],
        }
        .map(&image, &Device::Cpu)
        .unwrap();
        assert_eq!(res.dims(), &[3, 5, 4]);

        // (0 - 0.5) / 0.5 == -1 for every element.
        let values = res.flatten_all().unwrap().to_vec1::<f32>().unwrap();
        assert!(values.iter().all(|&v| (v + 1.0).abs() < 1e-6));
    }

    #[test]
    fn test_normalize_channel_mismatch() {
        let image = Tensor::zeros((3, 5, 4), DType::F32, &Device::Cpu).unwrap();
        let res = Normalize {
            mean: vec![0.5, 0.5],
            std: vec![0.5, 0.5, 0.5],
        }
        .map(&image, &Device::Cpu);
        assert!(res.is_err());
    }
}