mistralrs_vision/
transforms.rs1use crate::utils::image_to_pixels;
2use candle_core::{Device, Result, Tensor, D};
3use image::DynamicImage;
4
5use crate::ImageTransform;
6
7pub struct ToTensor;
10
11impl ImageTransform for ToTensor {
12 type Input = DynamicImage;
13 type Output = Tensor;
14 fn map(&self, x: &Self::Input, device: &Device) -> Result<Self::Output> {
15 image_to_pixels(x, device)? / 255.
16 }
17}
18
19pub struct ToTensorNoNorm;
22
23impl ImageTransform for ToTensorNoNorm {
24 type Input = DynamicImage;
25 type Output = Tensor;
26 fn map(&self, x: &Self::Input, device: &Device) -> Result<Self::Output> {
27 image_to_pixels(x, device)
28 }
29}
30
31pub struct Normalize {
39 pub mean: Vec<f64>,
40 pub std: Vec<f64>,
41}
42
43impl ImageTransform for Normalize {
44 type Input = Tensor;
45 type Output = Self::Input;
46
47 fn map(&self, x: &Self::Input, _: &Device) -> Result<Self::Output> {
48 let num_channels = x.dim(D::Minus(3))?;
49 if self.mean.len() != num_channels || self.std.len() != num_channels {
50 candle_core::bail!(
51 "Num channels ({}) must match number of mean ({}) and std ({}).",
52 num_channels,
53 self.mean.len(),
54 self.std.len()
55 );
56 }
57 let dtype = x.dtype();
58 let device = x.device();
59 let mean = Tensor::from_slice(
60 &self.mean.iter().map(|x| *x as f32).collect::<Vec<_>>(),
61 (num_channels,),
62 device,
63 )?
64 .to_dtype(dtype)?;
65 let std = Tensor::from_slice(
66 &self.std.iter().map(|x| *x as f32).collect::<Vec<_>>(),
67 (num_channels,),
68 device,
69 )?
70 .to_dtype(dtype)?;
71 let mean = mean.reshape((num_channels, 1, 1))?;
72 let std = std.reshape((num_channels, 1, 1))?;
73 x.broadcast_sub(&mean)?.broadcast_div(&std)
74 }
75}
76
77pub struct InterpolateResize {
79 pub target_w: usize,
80 pub target_h: usize,
81}
82
83impl ImageTransform for InterpolateResize {
84 type Input = Tensor;
85 type Output = Self::Input;
86
87 fn map(&self, x: &Self::Input, _: &Device) -> Result<Self::Output> {
88 x.unsqueeze(0)?
89 .interpolate2d(self.target_h, self.target_w)?
90 .squeeze(0)
91 }
92}
93
94impl<T: ImageTransform<Input = E, Output = E>, E: Clone> ImageTransform for Option<T> {
95 type Input = T::Input;
96 type Output = T::Output;
97
98 fn map(&self, x: &T::Input, dev: &Device) -> Result<T::Output> {
99 if let Some(this) = self {
100 this.map(x, dev)
101 } else {
102 Ok(x.clone())
103 }
104 }
105}
106
107pub struct Rescale {
111 pub factor: Option<f64>,
112}
113
114impl ImageTransform for Rescale {
115 type Input = Tensor;
116 type Output = Self::Input;
117
118 fn map(&self, x: &Self::Input, _: &Device) -> Result<Self::Output> {
119 if let Some(factor) = self.factor {
120 x * factor
121 } else {
122 Ok(x.clone())
123 }
124 }
125}
126
// Gate the module so it is only compiled for test builds.
#[cfg(test)]
mod tests {
    /// `ToTensor` on a 4x5 (width x height) RGB8 image must yield dims `[3, 5, 4]`.
    #[test]
    fn test_to_tensor() {
        use candle_core::Device;
        use image::{ColorType, DynamicImage};

        use crate::ImageTransform;

        use super::ToTensor;

        let image = DynamicImage::new(4, 5, ColorType::Rgb8);
        let res = ToTensor.map(&image, &Device::Cpu).unwrap();
        assert_eq!(res.dims(), &[3, 5, 4])
    }

    /// `Normalize` must preserve the shape of its input.
    #[test]
    fn test_normalize() {
        use crate::{ImageTransform, Normalize};
        use candle_core::{Device, Tensor};

        // `Tensor::randn` takes (mean, std, shape, device): use mean 0 and std 1 so the input
        // is actually random rather than a constant tensor.
        let image = Tensor::randn(0f32, 1f32, (3, 5, 4), &Device::Cpu).unwrap();
        let res = Normalize {
            mean: vec![0.5, 0.5, 0.5],
            std: vec![0.5, 0.5, 0.5],
        }
        .map(&image, &Device::Cpu)
        .unwrap();
        assert_eq!(res.dims(), &[3, 5, 4])
    }
}