mistralrs_vision/
transforms.rs1use crate::utils::image_to_pixels;
2use candle_core::{Device, Result, Tensor, D};
3use image::DynamicImage;
4
5use crate::ImageTransform;
6
/// Transform that turns a decoded image into a tensor and divides it by 255.
///
/// The `ImageTransform` implementation for this type converts the image with
/// `image_to_pixels` and then divides the resulting tensor by `255.`
/// (presumably mapping 8-bit channel values into `[0.0, 1.0]`; the dtype
/// produced by `image_to_pixels` is not visible in this file).
///
/// The unit test in this file expects a 4-wide, 5-tall RGB8 image to produce
/// a tensor with dims `[3, 5, 4]`, i.e. (channels, height, width).
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct ToTensor;
10
11impl ImageTransform for ToTensor {
12 type Input = DynamicImage;
13 type Output = Tensor;
14 fn map(&self, x: &Self::Input, device: &Device) -> Result<Self::Output> {
15 image_to_pixels(x, device)? / 255.
16 }
17}
18
/// Transform that turns a decoded image into a tensor without rescaling it.
///
/// The `ImageTransform` implementation for this type returns the output of
/// `image_to_pixels` unchanged, in contrast to `ToTensor`, which additionally
/// divides by 255.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct ToTensorNoNorm;
22
23impl ImageTransform for ToTensorNoNorm {
24 type Input = DynamicImage;
25 type Output = Tensor;
26 fn map(&self, x: &Self::Input, device: &Device) -> Result<Self::Output> {
27 image_to_pixels(x, device)
28 }
29}
30
/// Per-channel normalization parameters.
///
/// The `ImageTransform` implementation for this type computes, for every
/// channel `c` along the third-from-last dimension of the input tensor:
///
/// ```text
/// out[c] = (x[c] - mean[c]) / std[c]
/// ```
///
/// `mean` and `std` must each contain exactly one entry per channel;
/// otherwise `map` returns an error.
#[derive(Debug, Clone, PartialEq)]
pub struct Normalize {
    /// Value subtracted from each channel, indexed by channel.
    pub mean: Vec<f64>,
    /// Value each channel is divided by after the subtraction, indexed by channel.
    pub std: Vec<f64>,
}
42
43impl ImageTransform for Normalize {
44 type Input = Tensor;
45 type Output = Self::Input;
46
47 fn map(&self, x: &Self::Input, _: &Device) -> Result<Self::Output> {
48 let num_channels = x.dim(D::Minus(3))?;
49 if self.mean.len() != num_channels || self.std.len() != num_channels {
50 candle_core::bail!(
51 "Num channels ({}) must match number of mean ({}) and std ({}).",
52 num_channels,
53 self.mean.len(),
54 self.std.len()
55 );
56 }
57 let mut accum = Vec::new();
58 for (i, channel) in x.chunk(num_channels, D::Minus(3))?.iter().enumerate() {
59 accum.push(((channel - self.mean[i])? / self.std[i])?);
60 }
61 Tensor::cat(&accum, D::Minus(3))
62 }
63}
64
/// Target size for resizing a tensor with `Tensor::interpolate2d`.
///
/// The `ImageTransform` implementation for this type adds a leading dimension
/// to the input, calls `interpolate2d(target_h, target_w)`, and removes the
/// leading dimension again.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct InterpolateResize {
    /// Width passed as the second argument to `interpolate2d`.
    pub target_w: usize,
    /// Height passed as the first argument to `interpolate2d`.
    pub target_h: usize,
}
70
71impl ImageTransform for InterpolateResize {
72 type Input = Tensor;
73 type Output = Self::Input;
74
75 fn map(&self, x: &Self::Input, _: &Device) -> Result<Self::Output> {
76 x.unsqueeze(0)?
77 .interpolate2d(self.target_h, self.target_w)?
78 .squeeze(0)
79 }
80}
81
82impl<T: ImageTransform<Input = E, Output = E>, E: Clone> ImageTransform for Option<T> {
83 type Input = T::Input;
84 type Output = T::Output;
85
86 fn map(&self, x: &T::Input, dev: &Device) -> Result<T::Output> {
87 if let Some(this) = self {
88 this.map(x, dev)
89 } else {
90 Ok(x.clone())
91 }
92 }
93}
94
/// Optional scalar multiplier for a tensor.
///
/// The `ImageTransform` implementation for this type multiplies the input by
/// `factor` when it is `Some`, and returns a clone of the input when it is
/// `None`. The derived `Default` has `factor: None`, i.e. the pass-through
/// behavior.
#[derive(Debug, Clone, Copy, Default, PartialEq)]
pub struct Rescale {
    /// Multiplier applied to every element; `None` disables rescaling.
    pub factor: Option<f64>,
}
101
102impl ImageTransform for Rescale {
103 type Input = Tensor;
104 type Output = Self::Input;
105
106 fn map(&self, x: &Self::Input, _: &Device) -> Result<Self::Output> {
107 if let Some(factor) = self.factor {
108 x * factor
109 } else {
110 Ok(x.clone())
111 }
112 }
113}
114
// Gate the module so it is compiled only for test builds.
#[cfg(test)]
mod tests {
    use candle_core::{DType, Device, Tensor};
    use image::{ColorType, DynamicImage};

    use super::{Normalize, ToTensor};
    use crate::ImageTransform;

    /// A 4-wide, 5-tall RGB8 image must come out with dims `[3, 5, 4]`.
    #[test]
    fn test_to_tensor() {
        let image = DynamicImage::new(4, 5, ColorType::Rgb8);
        let res = ToTensor.map(&image, &Device::Cpu).unwrap();
        assert_eq!(res.dims(), &[3, 5, 4]);
    }

    /// Normalizing with matching per-channel parameters must keep the dims
    /// unchanged. Only the shape is checked, not the resulting values.
    #[test]
    fn test_normalize() {
        let image = Tensor::zeros((3, 5, 4), DType::U8, &Device::Cpu).unwrap();
        let res = Normalize {
            mean: vec![0.5, 0.5, 0.5],
            std: vec![0.5, 0.5, 0.5],
        }
        .map(&image, &Device::Cpu)
        .unwrap();
        assert_eq!(res.dims(), &[3, 5, 4]);
    }
}