mistralrs_vision/
transforms.rs1use crate::utils::{get_pixel_data, n_channels};
2use candle_core::{DType, Device, Result, Tensor};
3use image::{DynamicImage, GenericImageView};
4
5use crate::ImageTransform;
6
7pub struct ToTensor;
10
11impl ToTensor {
12 fn to_tensor(device: &Device, channels: usize, data: Vec<Vec<Vec<u8>>>) -> Result<Tensor> {
13 ToTensorNoNorm::to_tensor(device, channels, data)? / 255.0f64
14 }
15}
16
17impl ImageTransform for ToTensor {
18 type Input = DynamicImage;
19 type Output = Tensor;
20 fn map(&self, x: &Self::Input, device: &Device) -> Result<Self::Output> {
21 let num_channels = n_channels(x);
22 let data = get_pixel_data(
23 num_channels,
24 x.to_rgba8(),
25 x.dimensions().1 as usize,
26 x.dimensions().0 as usize,
27 );
28 Self::to_tensor(device, num_channels, data)
29 }
30}
31
32pub struct ToTensorNoNorm;
35
36impl ToTensorNoNorm {
37 fn to_tensor(device: &Device, channels: usize, data: Vec<Vec<Vec<u8>>>) -> Result<Tensor> {
38 let mut accum = Vec::new();
39 for row in data {
40 let mut row_accum = Vec::new();
41 for item in row {
42 row_accum.push(
43 Tensor::from_slice(&item[..channels], (1, channels), &Device::Cpu)?
44 .to_dtype(DType::F32)?,
45 )
46 }
47 let row = Tensor::cat(&row_accum, 0)?;
48 accum.push(row.t()?.unsqueeze(1)?);
49 }
50 Tensor::cat(&accum, 1)?.to_device(device)
51 }
52}
53
54impl ImageTransform for ToTensorNoNorm {
55 type Input = DynamicImage;
56 type Output = Tensor;
57 fn map(&self, x: &Self::Input, device: &Device) -> Result<Self::Output> {
58 let num_channels = n_channels(x);
59 let data = get_pixel_data(
60 num_channels,
61 x.to_rgba8(),
62 x.dimensions().1 as usize,
63 x.dimensions().0 as usize,
64 );
65 Self::to_tensor(device, num_channels, data)
66 }
67}
68
69pub struct Normalize {
77 pub mean: Vec<f64>,
78 pub std: Vec<f64>,
79}
80
81impl ImageTransform for Normalize {
82 type Input = Tensor;
83 type Output = Self::Input;
84
85 fn map(&self, x: &Self::Input, _: &Device) -> Result<Self::Output> {
86 let num_channels = x.dim(0)?;
87 if self.mean.len() != num_channels || self.std.len() != num_channels {
88 candle_core::bail!(
89 "Num channels ({}) must match number of mean ({}) and std ({}).",
90 num_channels,
91 self.mean.len(),
92 self.std.len()
93 );
94 }
95 let mut accum = Vec::new();
96 for (i, channel) in x.chunk(num_channels, 0)?.iter().enumerate() {
97 accum.push(((channel - self.mean[i])? / self.std[i])?);
98 }
99 Tensor::cat(&accum, 0)
100 }
101}
102
103pub struct InterpolateResize {
105 pub target_w: usize,
106 pub target_h: usize,
107}
108
109impl ImageTransform for InterpolateResize {
110 type Input = Tensor;
111 type Output = Self::Input;
112
113 fn map(&self, x: &Self::Input, _: &Device) -> Result<Self::Output> {
114 x.unsqueeze(0)?
115 .interpolate2d(self.target_h, self.target_w)?
116 .squeeze(0)
117 }
118}
119
120impl<T: ImageTransform<Input = E, Output = E>, E: Clone> ImageTransform for Option<T> {
121 type Input = T::Input;
122 type Output = T::Output;
123
124 fn map(&self, x: &T::Input, dev: &Device) -> Result<T::Output> {
125 if let Some(this) = self {
126 this.map(x, dev)
127 } else {
128 Ok(x.clone())
129 }
130 }
131}
132
133pub struct Rescale {
137 pub factor: Option<f64>,
138}
139
140impl ImageTransform for Rescale {
141 type Input = Tensor;
142 type Output = Self::Input;
143
144 fn map(&self, x: &Self::Input, _: &Device) -> Result<Self::Output> {
145 if let Some(factor) = self.factor {
146 x * factor
147 } else {
148 Ok(x.clone())
149 }
150 }
151}
152
// Unit tests for the transforms; only compiled for test builds.
#[cfg(test)]
mod tests {
    /// A 4x5 (width x height) RGB image must become a `(3, 5, 4)` channel-first tensor.
    #[test]
    fn test_to_tensor() {
        use candle_core::Device;
        use image::{ColorType, DynamicImage};

        use crate::ImageTransform;

        use super::ToTensor;

        let image = DynamicImage::new(4, 5, ColorType::Rgb8);
        let res = ToTensor.map(&image, &Device::Cpu).unwrap();
        assert_eq!(res.dims(), &[3, 5, 4])
    }

    /// Normalization must preserve the shape of its input.
    #[test]
    fn test_normalize() {
        use crate::{ImageTransform, Normalize};
        use candle_core::{DType, Device, Tensor};

        let image = Tensor::zeros((3, 5, 4), DType::U8, &Device::Cpu).unwrap();
        let res = Normalize {
            mean: vec![0.5, 0.5, 0.5],
            std: vec![0.5, 0.5, 0.5],
        }
        .map(&image, &Device::Cpu)
        .unwrap();
        assert_eq!(res.dims(), &[3, 5, 4])
    }

    /// With a float input, every zero must map to `(0 - mean) / std`, which is -1 whenever
    /// `mean == std`.
    #[test]
    fn test_normalize_values() {
        use crate::{ImageTransform, Normalize};
        use candle_core::{DType, Device, Tensor};

        let image = Tensor::zeros((2, 3, 3), DType::F32, &Device::Cpu).unwrap();
        let res = Normalize {
            mean: vec![0.5, 0.25],
            std: vec![0.5, 0.25],
        }
        .map(&image, &Device::Cpu)
        .unwrap();
        assert_eq!(res.dims(), &[2, 3, 3]);

        let values = res.flatten_all().unwrap().to_vec1::<f32>().unwrap();
        assert!(values.iter().all(|&v| v == -1.0));
    }
}