#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use candle_core::{Result, Tensor, D};
use candle_nn::{Conv2d, GroupNorm};
use mistralrs_quant::{Convolution, ShardedVarBuilder};
use serde::Deserialize;

use crate::layers::{conv2d, group_norm, MatMul};

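/// Hyperparameters for the autoencoder, deserialized with serde (typically
/// from the model's `config.json`). `block_out_channels` lists the channel
/// width at each resolution level; `scaling_factor` and `shift_factor`
/// normalize the latent space after encoding.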
#[derive(Debug, Clone, Deserialize)]
pub struct Config {
    pub in_channels: usize,
    pub out_channels: usize,
    pub block_out_channels: Vec<usize>,
    pub layers_per_block: usize,
    pub latent_channels: usize,
    pub scaling_factor: f64,
    pub shift_factor: f64,
    pub norm_num_groups: usize,
}

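/// Scaled dot-product attention over the last two dimensions:
/// `softmax(q k^T / sqrt(d)) v`, where `d` is the last dimension of `q`.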
fn scaled_dot_product_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
    let dim = q.dim(D::Minus1)?;
    let scale_factor = 1.0 / (dim as f64).sqrt();
    let attn_weights = (MatMul.matmul(q, &k.t()?)? * scale_factor)?;
    MatMul.matmul(&candle_nn::ops::softmax_last_dim(&attn_weights)?, v)
}

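/// Single-head spatial self-attention with a residual connection. The q/k/v
/// and output projections are 1x1 convolutions, so every pixel of the feature
/// map attends to every other pixel.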
#[derive(Debug, Clone)]
struct AttnBlock {
    q: Conv2d,
    k: Conv2d,
    v: Conv2d,
    proj_out: Conv2d,
    norm: GroupNorm,
}

impl AttnBlock {
    fn new(in_c: usize, vb: ShardedVarBuilder, cfg: &Config) -> Result<Self> {
        let q = conv2d(in_c, in_c, 1, Default::default(), vb.pp("q"))?;
        let k = conv2d(in_c, in_c, 1, Default::default(), vb.pp("k"))?;
        let v = conv2d(in_c, in_c, 1, Default::default(), vb.pp("v"))?;
        let proj_out = conv2d(in_c, in_c, 1, Default::default(), vb.pp("proj_out"))?;
        let norm = group_norm(cfg.norm_num_groups, in_c, 1e-6, vb.pp("norm"))?;
        Ok(Self {
            q,
            k,
            v,
            proj_out,
            norm,
        })
    }
}

impl candle_core::Module for AttnBlock {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let init_xs = xs;
        let normed = self.norm.forward(xs)?;
        let q = Convolution.forward_2d(&self.q, &normed)?;
        let k = Convolution.forward_2d(&self.k, &normed)?;
        let v = Convolution.forward_2d(&self.v, &normed)?;
        let (b, c, h, w) = q.dims4()?;
        // Reshape (b, c, h, w) -> (b, 1, h*w, c) so attention runs over the
        // flattened spatial dimension.
        let q = q.flatten_from(2)?.t()?.unsqueeze(1)?;
        let k = k.flatten_from(2)?.t()?.unsqueeze(1)?;
        let v = v.flatten_from(2)?.t()?.unsqueeze(1)?;
        let attended = scaled_dot_product_attention(&q, &k, &v)?;
        let attended = attended.squeeze(1)?.t()?.reshape((b, c, h, w))?;
        let projected = Convolution.forward_2d(&self.proj_out, &attended)?;
        projected + init_xs
    }
}

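/// Pre-activation residual block: two GroupNorm -> swish -> 3x3 conv stages,
/// with a 1x1 shortcut projection when the channel count changes.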
#[derive(Debug, Clone)]
struct ResnetBlock {
    norm1: GroupNorm,
    conv1: Conv2d,
    norm2: GroupNorm,
    conv2: Conv2d,
    nin_shortcut: Option<Conv2d>,
}

impl ResnetBlock {
    fn new(in_c: usize, out_c: usize, vb: ShardedVarBuilder, cfg: &Config) -> Result<Self> {
        let conv_cfg = candle_nn::Conv2dConfig {
            padding: 1,
            ..Default::default()
        };
        let norm1 = group_norm(cfg.norm_num_groups, in_c, 1e-6, vb.pp("norm1"))?;
        let conv1 = conv2d(in_c, out_c, 3, conv_cfg, vb.pp("conv1"))?;
        let norm2 = group_norm(cfg.norm_num_groups, out_c, 1e-6, vb.pp("norm2"))?;
        let conv2 = conv2d(out_c, out_c, 3, conv_cfg, vb.pp("conv2"))?;
        // Only project the residual when the input/output widths differ.
        let nin_shortcut = if in_c == out_c {
            None
        } else {
            Some(conv2d(
                in_c,
                out_c,
                1,
                Default::default(),
                vb.pp("nin_shortcut"),
            )?)
        };
        Ok(Self {
            norm1,
            conv1,
            norm2,
            conv2,
            nin_shortcut,
        })
    }
}

impl candle_core::Module for ResnetBlock {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let mut h = self.norm1.forward(xs)?;
        h = candle_nn::Activation::Swish.forward(&h)?;
        h = Convolution.forward_2d(&self.conv1, &h)?;
        h = self.norm2.forward(&h)?;
        h = candle_nn::Activation::Swish.forward(&h)?;
        h = Convolution.forward_2d(&self.conv2, &h)?;
        match self.nin_shortcut.as_ref() {
            None => xs + h,
            Some(c) => Convolution.forward_2d(c, xs)? + h,
        }
    }
}

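/// Halves the spatial resolution with a stride-2 3x3 convolution.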
#[derive(Debug, Clone)]
struct Downsample {
    conv: Conv2d,
}

impl Downsample {
    fn new(in_c: usize, vb: ShardedVarBuilder) -> Result<Self> {
        let conv_cfg = candle_nn::Conv2dConfig {
            stride: 2,
            ..Default::default()
        };
        let conv = conv2d(in_c, in_c, 3, conv_cfg, vb.pp("conv"))?;
        Ok(Self { conv })
    }
}

impl candle_core::Module for Downsample {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        // Asymmetric padding: pad by one on the right and bottom only, then
        // apply the stride-2 convolution (the conv itself has no padding).
        let xs = xs.pad_with_zeros(D::Minus1, 0, 1)?;
        let xs = xs.pad_with_zeros(D::Minus2, 0, 1)?;
        Convolution.forward_2d(&self.conv, &xs)
    }
}

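/// Doubles the spatial resolution with nearest-neighbor interpolation
/// followed by a 3x3 convolution.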
#[derive(Debug, Clone)]
struct Upsample {
    conv: Conv2d,
}

impl Upsample {
    fn new(in_c: usize, vb: ShardedVarBuilder) -> Result<Self> {
        let conv_cfg = candle_nn::Conv2dConfig {
            padding: 1,
            ..Default::default()
        };
        let conv = conv2d(in_c, in_c, 3, conv_cfg, vb.pp("conv"))?;
        Ok(Self { conv })
    }
}

impl candle_core::Module for Upsample {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let (_, _, h, w) = xs.dims4()?;
        let upsampled = xs.upsample_nearest2d(h * 2, w * 2)?;
        Convolution.forward_2d(&self.conv, &upsampled)
    }
}

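/// One encoder resolution level: `layers_per_block` resnet blocks, plus a
/// downsample on every level except the last.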
#[derive(Debug, Clone)]
struct DownBlock {
    block: Vec<ResnetBlock>,
    downsample: Option<Downsample>,
}

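/// Maps images to latent statistics. `conv_out` emits `2 * latent_channels`
/// channels: the mean and log-variance later split by `DiagonalGaussian`.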
#[derive(Debug, Clone)]
pub struct Encoder {
    conv_in: Conv2d,
    mid_block_1: ResnetBlock,
    mid_attn_1: AttnBlock,
    mid_block_2: ResnetBlock,
    norm_out: GroupNorm,
    conv_out: Conv2d,
    down: Vec<DownBlock>,
}

impl Encoder {
    pub fn new(cfg: &Config, vb: ShardedVarBuilder) -> Result<Self> {
        let conv_cfg = candle_nn::Conv2dConfig {
            padding: 1,
            ..Default::default()
        };
        let base_ch = cfg.block_out_channels[0];
        let mut block_in = base_ch;
        let conv_in = conv2d(cfg.in_channels, block_in, 3, conv_cfg, vb.pp("conv_in"))?;

        let mut down = Vec::with_capacity(cfg.block_out_channels.len());
        let vb_d = vb.pp("down");
        for (i_level, out_channels) in cfg.block_out_channels.iter().enumerate() {
            let mut block = Vec::with_capacity(cfg.layers_per_block);
            let vb_d = vb_d.pp(i_level);
            let vb_b = vb_d.pp("block");
            // Each level starts from the previous level's output width.
            block_in = if i_level == 0 {
                base_ch
            } else {
                cfg.block_out_channels[i_level - 1]
            };
            let block_out = *out_channels;
            for i_block in 0..cfg.layers_per_block {
                let b = ResnetBlock::new(block_in, block_out, vb_b.pp(i_block), cfg)?;
                block.push(b);
                block_in = block_out;
            }
            // Every level except the last ends with a stride-2 downsample.
            let downsample = if i_level != cfg.block_out_channels.len() - 1 {
                Some(Downsample::new(block_in, vb_d.pp("downsample"))?)
            } else {
                None
            };
            let block = DownBlock { block, downsample };
            down.push(block)
        }

        let mid_block_1 = ResnetBlock::new(block_in, block_in, vb.pp("mid.block_1"), cfg)?;
        let mid_attn_1 = AttnBlock::new(block_in, vb.pp("mid.attn_1"), cfg)?;
        let mid_block_2 = ResnetBlock::new(block_in, block_in, vb.pp("mid.block_2"), cfg)?;
        // Two output channels per latent channel: mean and log-variance.
        let conv_out = conv2d(
            block_in,
            2 * cfg.latent_channels,
            3,
            conv_cfg,
            vb.pp("conv_out"),
        )?;
        let norm_out = group_norm(cfg.norm_num_groups, block_in, 1e-6, vb.pp("norm_out"))?;
        Ok(Self {
            conv_in,
            mid_block_1,
            mid_attn_1,
            mid_block_2,
            norm_out,
            conv_out,
            down,
        })
    }
}

impl candle_nn::Module for Encoder {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let mut h = Convolution.forward_2d(&self.conv_in, xs)?;
        for block in self.down.iter() {
            for b in block.block.iter() {
                h = b.forward(&h)?
            }
            if let Some(ds) = block.downsample.as_ref() {
                h = ds.forward(&h)?
            }
        }
        h = self.mid_block_1.forward(&h)?;
        h = self.mid_attn_1.forward(&h)?;
        h = self.mid_block_2.forward(&h)?;
        h = self.norm_out.forward(&h)?;
        h = candle_nn::Activation::Swish.forward(&h)?;
        Convolution.forward_2d(&self.conv_out, &h)
    }
}

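/// One decoder resolution level: `layers_per_block + 1` resnet blocks, plus
/// an upsample on every level except level 0.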
#[derive(Debug, Clone)]
struct UpBlock {
    block: Vec<ResnetBlock>,
    upsample: Option<Upsample>,
}

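/// Maps latents back to images, mirroring the encoder's structure in reverse.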
#[derive(Debug, Clone)]
pub struct Decoder {
    conv_in: Conv2d,
    mid_block_1: ResnetBlock,
    mid_attn_1: AttnBlock,
    mid_block_2: ResnetBlock,
    norm_out: GroupNorm,
    conv_out: Conv2d,
    up: Vec<UpBlock>,
}

impl Decoder {
    pub fn new(cfg: &Config, vb: ShardedVarBuilder) -> Result<Self> {
        let conv_cfg = candle_nn::Conv2dConfig {
            padding: 1,
            ..Default::default()
        };
        let base_ch = cfg.block_out_channels[0];
        let mut block_in = cfg.block_out_channels.last().copied().unwrap_or(base_ch);
        let conv_in = conv2d(cfg.latent_channels, block_in, 3, conv_cfg, vb.pp("conv_in"))?;
        let mid_block_1 = ResnetBlock::new(block_in, block_in, vb.pp("mid.block_1"), cfg)?;
        let mid_attn_1 = AttnBlock::new(block_in, vb.pp("mid.attn_1"), cfg)?;
        let mid_block_2 = ResnetBlock::new(block_in, block_in, vb.pp("mid.block_2"), cfg)?;

        // Build levels from the deepest (lowest resolution) outwards, then
        // reverse so `up` is indexed by level like `cfg.block_out_channels`.
        let mut up = Vec::with_capacity(cfg.block_out_channels.len());
        let vb_u = vb.pp("up");
        for (i_level, out_channels) in cfg.block_out_channels.iter().enumerate().rev() {
            let block_out = *out_channels;
            let vb_u = vb_u.pp(i_level);
            let vb_b = vb_u.pp("block");
            // The decoder has one extra resnet block per level.
            let mut block = Vec::with_capacity(cfg.layers_per_block + 1);
            for i_block in 0..=cfg.layers_per_block {
                let b = ResnetBlock::new(block_in, block_out, vb_b.pp(i_block), cfg)?;
                block.push(b);
                block_in = block_out;
            }
            let upsample = if i_level != 0 {
                Some(Upsample::new(block_in, vb_u.pp("upsample"))?)
            } else {
                None
            };
            let block = UpBlock { block, upsample };
            up.push(block)
        }
        up.reverse();

        let norm_out = group_norm(cfg.norm_num_groups, block_in, 1e-6, vb.pp("norm_out"))?;
        let conv_out = conv2d(block_in, cfg.out_channels, 3, conv_cfg, vb.pp("conv_out"))?;
        Ok(Self {
            conv_in,
            mid_block_1,
            mid_attn_1,
            mid_block_2,
            norm_out,
            conv_out,
            up,
        })
    }
}

impl candle_nn::Module for Decoder {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let mut h = Convolution.forward_2d(&self.conv_in, xs)?;
        h = self.mid_block_1.forward(&h)?;
        h = self.mid_attn_1.forward(&h)?;
        h = self.mid_block_2.forward(&h)?;
        // Run levels deepest-first, i.e. in reverse index order.
        for block in self.up.iter().rev() {
            for b in block.block.iter() {
                h = b.forward(&h)?
            }
            if let Some(us) = block.upsample.as_ref() {
                h = us.forward(&h)?
            }
        }
        h = self.norm_out.forward(&h)?;
        h = candle_nn::Activation::Swish.forward(&h)?;
        Convolution.forward_2d(&self.conv_out, &h)
    }
}

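/// Splits its input into mean and log-variance along `chunk_dim` and, when
/// `sample` is set, draws from the resulting diagonal Gaussian via the
/// reparameterization trick; otherwise it returns the mean.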
#[derive(Debug, Clone)]
pub struct DiagonalGaussian {
    sample: bool,
    chunk_dim: usize,
}

impl DiagonalGaussian {
    pub fn new(sample: bool, chunk_dim: usize) -> Result<Self> {
        Ok(Self { sample, chunk_dim })
    }
}

impl candle_nn::Module for DiagonalGaussian {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        // chunks[0] is the mean, chunks[1] the log-variance.
        let chunks = xs.chunk(2, self.chunk_dim)?;
        if self.sample {
            // Reparameterization: mean + exp(logvar / 2) * eps, eps ~ N(0, 1).
            let std = (&chunks[1] * 0.5)?.exp()?;
            &chunks[0] + (std * chunks[0].randn_like(0., 1.))?
        } else {
            Ok(chunks[0].clone())
        }
    }
}

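/// The full VAE: `encode` produces normalized latents, `decode` inverts the
/// normalization and reconstructs the image.
///
/// A minimal usage sketch (assuming `cfg` and a `ShardedVarBuilder` with the
/// VAE weights are already available; `image` is an NCHW tensor):
///
/// ```ignore
/// let vae = AutoEncoder::new(&cfg, vb)?;
/// let latent = vae.encode(&image)?; // (z - shift_factor) * scaling_factor
/// let recon = vae.decode(&latent)?; // inverse scaling, then the decoder
/// ```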
#[derive(Debug, Clone)]
pub struct AutoEncoder {
    encoder: Encoder,
    decoder: Decoder,
    reg: DiagonalGaussian,
    shift_factor: f64,
    scale_factor: f64,
}

impl AutoEncoder {
    pub fn new(cfg: &Config, vb: ShardedVarBuilder) -> Result<Self> {
        let encoder = Encoder::new(cfg, vb.pp("encoder"))?;
        let decoder = Decoder::new(cfg, vb.pp("decoder"))?;
        // Sample latents; the encoder output is chunked along the channel dim.
        let reg = DiagonalGaussian::new(true, 1)?;
        Ok(Self {
            encoder,
            decoder,
            reg,
            scale_factor: cfg.scaling_factor,
            shift_factor: cfg.shift_factor,
        })
    }

    pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
        let z = xs.apply(&self.encoder)?.apply(&self.reg)?;
        (z - self.shift_factor)? * self.scale_factor
    }

    pub fn decode(&self, xs: &Tensor) -> Result<Tensor> {
        let xs = ((xs / self.scale_factor)? + self.shift_factor)?;
        xs.apply(&self.decoder)
    }
}

impl candle_core::Module for AutoEncoder {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        self.decode(&self.encode(xs)?)
    }
}