mistralrs_vision/transforms.rs

use crate::utils::{get_pixel_data, n_channels};
use candle_core::{DType, Device, Result, Tensor};
use image::{DynamicImage, GenericImageView};

use crate::ImageTransform;

/// Convert an image to a tensor. This converts the data from being in `[0, 255]` to `[0.0, 1.0]`.
/// The tensor's shape is (channels, height, width).
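///
/// # Example
///
/// An illustrative sketch, not a verified doctest; it assumes `ToTensor` and `ImageTransform`
/// are re-exported at the crate root (as `Normalize` is in the tests below).
///
/// ```ignore
/// use candle_core::Device;
/// use image::{ColorType, DynamicImage};
/// use mistralrs_vision::{ImageTransform, ToTensor};
///
/// // A 4x5 RGB image becomes a (3, 5, 4) tensor with values scaled into [0.0, 1.0].
/// let image = DynamicImage::new(4, 5, ColorType::Rgb8);
/// let tensor = ToTensor.map(&image, &Device::Cpu).unwrap();
/// assert_eq!(tensor.dims(), &[3, 5, 4]);
/// ```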
pub struct ToTensor;

impl ToTensor {
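    /// Delegate to [`ToTensorNoNorm::to_tensor`], then scale the result from `[0.0, 255.0]`
    /// into `[0.0, 1.0]`.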
    fn to_tensor(device: &Device, channels: usize, data: Vec<Vec<Vec<u8>>>) -> Result<Tensor> {
        ToTensorNoNorm::to_tensor(device, channels, data)? / 255.0f64
    }
}

impl ImageTransform for ToTensor {
    type Input = DynamicImage;
    type Output = Tensor;
    fn map(&self, x: &Self::Input, device: &Device) -> Result<Self::Output> {
        let num_channels = n_channels(x);
        let data = get_pixel_data(
            num_channels,
            x.to_rgba8(),
            x.dimensions().1 as usize,
            x.dimensions().0 as usize,
        );
        Self::to_tensor(device, num_channels, data)
    }
}

/// Convert an image to a tensor without normalizing to `[0.0, 1.0]`.
/// The tensor's shape is (channels, height, width).
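///
/// # Example
///
/// An illustrative sketch under the same crate-root re-export assumption as for [`ToTensor`]:
///
/// ```ignore
/// use candle_core::Device;
/// use image::{ColorType, DynamicImage};
/// use mistralrs_vision::{ImageTransform, ToTensorNoNorm};
///
/// // Pixel values stay in [0.0, 255.0]; the shape is still (channels, height, width).
/// let image = DynamicImage::new(4, 5, ColorType::Rgb8);
/// let tensor = ToTensorNoNorm.map(&image, &Device::Cpu).unwrap();
/// assert_eq!(tensor.dims(), &[3, 5, 4]);
/// ```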
pub struct ToTensorNoNorm;

impl ToTensorNoNorm {
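    /// Build a (channels, height, width) `f32` tensor from pixel data laid out as
    /// `[row][column][channel]`, keeping only the first `channels` values of each pixel.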
    fn to_tensor(device: &Device, channels: usize, data: Vec<Vec<Vec<u8>>>) -> Result<Tensor> {
        let mut accum = Vec::new();
        for row in data {
            let mut row_accum = Vec::new();
            for item in row {
                row_accum.push(
                    Tensor::from_slice(&item[..channels], (1, channels), &Device::Cpu)?
                        .to_dtype(DType::F32)?,
                )
            }
            let row = Tensor::cat(&row_accum, 0)?;
            accum.push(row.t()?.unsqueeze(1)?);
        }
        Tensor::cat(&accum, 1)?.to_device(device)
    }
}

impl ImageTransform for ToTensorNoNorm {
    type Input = DynamicImage;
    type Output = Tensor;
    fn map(&self, x: &Self::Input, device: &Device) -> Result<Self::Output> {
        let num_channels = n_channels(x);
        let data = get_pixel_data(
            num_channels,
            x.to_rgba8(),
            x.dimensions().1 as usize,
            x.dimensions().0 as usize,
        );
        Self::to_tensor(device, num_channels, data)
    }
}

/// Normalize the image data based on the per-channel mean and standard deviation.
/// Each value is computed as follows:
/// `x[channel] = (x[channel] - mean[channel]) / std[channel]`
///
/// Expects an input tensor of shape (channels, height, width).
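///
/// # Example
///
/// An illustrative sketch mirroring the test at the bottom of this file, but with an `f32`
/// input for clarity; `mean` and `std` must have one entry per channel.
///
/// ```ignore
/// use candle_core::{DType, Device, Tensor};
/// use mistralrs_vision::{ImageTransform, Normalize};
///
/// let image = Tensor::zeros((3, 5, 4), DType::F32, &Device::Cpu).unwrap();
/// let normalized = Normalize {
///     mean: vec![0.5, 0.5, 0.5],
///     std: vec![0.5, 0.5, 0.5],
/// }
/// .map(&image, &Device::Cpu)
/// .unwrap();
/// assert_eq!(normalized.dims(), &[3, 5, 4]);
/// ```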
pub struct Normalize {
    pub mean: Vec<f64>,
    pub std: Vec<f64>,
}

impl ImageTransform for Normalize {
    type Input = Tensor;
    type Output = Self::Input;

    fn map(&self, x: &Self::Input, _: &Device) -> Result<Self::Output> {
        let num_channels = x.dim(0)?;
        if self.mean.len() != num_channels || self.std.len() != num_channels {
            candle_core::bail!(
                "Num channels ({}) must match number of mean ({}) and std ({}).",
                num_channels,
                self.mean.len(),
                self.std.len()
            );
        }
        let mut accum = Vec::new();
        for (i, channel) in x.chunk(num_channels, 0)?.iter().enumerate() {
            accum.push(((channel - self.mean[i])? / self.std[i])?);
        }
        Tensor::cat(&accum, 0)
    }
}

/// Resize the image via nearest-neighbor interpolation.
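///
/// # Example
///
/// An illustrative sketch; the input is assumed to be a (channels, height, width) tensor, and
/// the crate-root re-export of `InterpolateResize` is an assumption:
///
/// ```ignore
/// use candle_core::{DType, Device, Tensor};
/// use mistralrs_vision::{ImageTransform, InterpolateResize};
///
/// // Resize a (3, 5, 4) image tensor up to 10x8 (height x width).
/// let image = Tensor::zeros((3, 5, 4), DType::F32, &Device::Cpu).unwrap();
/// let resized = InterpolateResize { target_w: 8, target_h: 10 }
///     .map(&image, &Device::Cpu)
///     .unwrap();
/// assert_eq!(resized.dims(), &[3, 10, 8]);
/// ```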
pub struct InterpolateResize {
    pub target_w: usize,
    pub target_h: usize,
}

impl ImageTransform for InterpolateResize {
    type Input = Tensor;
    type Output = Self::Input;

    fn map(&self, x: &Self::Input, _: &Device) -> Result<Self::Output> {
        x.unsqueeze(0)?
            .interpolate2d(self.target_h, self.target_w)?
            .squeeze(0)
    }
}

/// Applying an `Option` of a transform applies the inner transform when it is `Some`; otherwise
/// the input is returned unchanged.
impl<T: ImageTransform<Input = E, Output = E>, E: Clone> ImageTransform for Option<T> {
    type Input = T::Input;
    type Output = T::Output;

    fn map(&self, x: &T::Input, dev: &Device) -> Result<T::Output> {
        if let Some(this) = self {
            this.map(x, dev)
        } else {
            Ok(x.clone())
        }
    }
}

/// Multiply the pixel values by the provided factor, if one is given; otherwise the input is
/// returned unchanged.
///
/// Each pixel value is calculated as follows: `x = x * factor`
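///
/// # Example
///
/// An illustrative sketch (the crate-root re-export of `Rescale` is an assumption); a factor of
/// `1.0 / 255.0` scales `[0.0, 255.0]` data into `[0.0, 1.0]`:
///
/// ```ignore
/// use candle_core::{DType, Device, Tensor};
/// use mistralrs_vision::{ImageTransform, Rescale};
///
/// let image = Tensor::ones((3, 5, 4), DType::F32, &Device::Cpu).unwrap();
/// let rescaled = Rescale { factor: Some(1.0 / 255.0) }
///     .map(&image, &Device::Cpu)
///     .unwrap();
/// assert_eq!(rescaled.dims(), &[3, 5, 4]);
/// ```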
pub struct Rescale {
    pub factor: Option<f64>,
}

impl ImageTransform for Rescale {
    type Input = Tensor;
    type Output = Self::Input;

    fn map(&self, x: &Self::Input, _: &Device) -> Result<Self::Output> {
        if let Some(factor) = self.factor {
            x * factor
        } else {
            Ok(x.clone())
        }
    }
}

#[cfg(test)]
mod tests {
    #[test]
    fn test_to_tensor() {
        use candle_core::Device;
        use image::{ColorType, DynamicImage};

        use crate::ImageTransform;

        use super::ToTensor;

        let image = DynamicImage::new(4, 5, ColorType::Rgb8);
        let res = ToTensor.map(&image, &Device::Cpu).unwrap();
        assert_eq!(res.dims(), &[3, 5, 4])
    }

    #[test]
    fn test_normalize() {
        use crate::{ImageTransform, Normalize};
        use candle_core::{DType, Device, Tensor};

        let image = Tensor::zeros((3, 5, 4), DType::U8, &Device::Cpu).unwrap();
        let res = Normalize {
            mean: vec![0.5, 0.5, 0.5],
            std: vec![0.5, 0.5, 0.5],
        }
        .map(&image, &Device::Cpu)
        .unwrap();
        assert_eq!(res.dims(), &[3, 5, 4])
    }
}