1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use candle_core::{Result, Tensor, D};
4use candle_nn::{Conv2d, GroupNorm};
5use mistralrs_quant::ShardedVarBuilder;
6use serde::Deserialize;
7
8use crate::layers::{conv2d, group_norm, MatMul};
9
/// Configuration for the variational autoencoder, deserialized from the
/// model's JSON config file.
#[derive(Debug, Clone, Deserialize)]
pub struct Config {
    // Number of channels in the input image (e.g. 3 for RGB).
    pub in_channels: usize,
    // Number of channels in the reconstructed output image.
    pub out_channels: usize,
    // Output channel count for each resolution level, ordered from the
    // highest resolution (index 0) to the lowest.
    pub block_out_channels: Vec<usize>,
    // Number of ResNet blocks per resolution level in the encoder
    // (the decoder uses one more per level — see `Decoder::new`).
    pub layers_per_block: usize,
    // Number of channels in the latent representation.
    pub latent_channels: usize,
    // Multiplicative factor applied when mapping to/from latent space.
    pub scaling_factor: f64,
    // Additive offset applied when mapping to/from latent space.
    pub shift_factor: f64,
    // Number of groups used by every GroupNorm layer in this file.
    pub norm_num_groups: usize,
}
21
22fn scaled_dot_product_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
23 let dim = q.dim(D::Minus1)?;
24 let scale_factor = 1.0 / (dim as f64).sqrt();
25 let attn_weights = (MatMul.matmul(q, &k.t()?)? * scale_factor)?;
26 MatMul.matmul(&candle_nn::ops::softmax_last_dim(&attn_weights)?, v)
27}
28
/// Single-head self-attention block over spatial feature maps, with the
/// query/key/value and output projections implemented as 1x1 convolutions
/// and a GroupNorm applied before the projections.
#[derive(Debug, Clone)]
struct AttnBlock {
    q: Conv2d,
    k: Conv2d,
    v: Conv2d,
    proj_out: Conv2d,
    norm: GroupNorm,
}
37
38impl AttnBlock {
39 fn new(in_c: usize, vb: ShardedVarBuilder, cfg: &Config) -> Result<Self> {
40 let q = conv2d(in_c, in_c, 1, Default::default(), vb.pp("q"))?;
41 let k = conv2d(in_c, in_c, 1, Default::default(), vb.pp("k"))?;
42 let v = conv2d(in_c, in_c, 1, Default::default(), vb.pp("v"))?;
43 let proj_out = conv2d(in_c, in_c, 1, Default::default(), vb.pp("proj_out"))?;
44 let norm = group_norm(cfg.norm_num_groups, in_c, 1e-6, vb.pp("norm"))?;
45 Ok(Self {
46 q,
47 k,
48 v,
49 proj_out,
50 norm,
51 })
52 }
53}
54
55impl candle_core::Module for AttnBlock {
56 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
57 let init_xs = xs;
58 let xs = xs.apply(&self.norm)?;
59 let q = xs.apply(&self.q)?;
60 let k = xs.apply(&self.k)?;
61 let v = xs.apply(&self.v)?;
62 let (b, c, h, w) = q.dims4()?;
63 let q = q.flatten_from(2)?.t()?.unsqueeze(1)?;
64 let k = k.flatten_from(2)?.t()?.unsqueeze(1)?;
65 let v = v.flatten_from(2)?.t()?.unsqueeze(1)?;
66 let xs = scaled_dot_product_attention(&q, &k, &v)?;
67 let xs = xs.squeeze(1)?.t()?.reshape((b, c, h, w))?;
68 xs.apply(&self.proj_out)? + init_xs
69 }
70}
71
/// Two-convolution residual block (norm -> swish -> conv, twice) with an
/// optional 1x1 shortcut conv used when the input and output channel counts
/// differ.
#[derive(Debug, Clone)]
struct ResnetBlock {
    norm1: GroupNorm,
    conv1: Conv2d,
    norm2: GroupNorm,
    conv2: Conv2d,
    // Present only when `in_c != out_c`; maps the residual to `out_c` channels.
    nin_shortcut: Option<Conv2d>,
}
80
81impl ResnetBlock {
82 fn new(in_c: usize, out_c: usize, vb: ShardedVarBuilder, cfg: &Config) -> Result<Self> {
83 let conv_cfg = candle_nn::Conv2dConfig {
84 padding: 1,
85 ..Default::default()
86 };
87 let norm1 = group_norm(cfg.norm_num_groups, in_c, 1e-6, vb.pp("norm1"))?;
88 let conv1 = conv2d(in_c, out_c, 3, conv_cfg, vb.pp("conv1"))?;
89 let norm2 = group_norm(cfg.norm_num_groups, out_c, 1e-6, vb.pp("norm2"))?;
90 let conv2 = conv2d(out_c, out_c, 3, conv_cfg, vb.pp("conv2"))?;
91 let nin_shortcut = if in_c == out_c {
92 None
93 } else {
94 Some(conv2d(
95 in_c,
96 out_c,
97 1,
98 Default::default(),
99 vb.pp("nin_shortcut"),
100 )?)
101 };
102 Ok(Self {
103 norm1,
104 conv1,
105 norm2,
106 conv2,
107 nin_shortcut,
108 })
109 }
110}
111
112impl candle_core::Module for ResnetBlock {
113 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
114 let h = xs
115 .apply(&self.norm1)?
116 .apply(&candle_nn::Activation::Swish)?
117 .apply(&self.conv1)?
118 .apply(&self.norm2)?
119 .apply(&candle_nn::Activation::Swish)?
120 .apply(&self.conv2)?;
121 match self.nin_shortcut.as_ref() {
122 None => xs + h,
123 Some(c) => xs.apply(c)? + h,
124 }
125 }
126}
127
/// 2x spatial downsampling implemented as a stride-2 3x3 convolution
/// (with asymmetric padding applied in `forward`).
#[derive(Debug, Clone)]
struct Downsample {
    conv: Conv2d,
}
132
133impl Downsample {
134 fn new(in_c: usize, vb: ShardedVarBuilder) -> Result<Self> {
135 let conv_cfg = candle_nn::Conv2dConfig {
136 stride: 2,
137 ..Default::default()
138 };
139 let conv = conv2d(in_c, in_c, 3, conv_cfg, vb.pp("conv"))?;
140 Ok(Self { conv })
141 }
142}
143
144impl candle_core::Module for Downsample {
145 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
146 let xs = xs.pad_with_zeros(D::Minus1, 0, 1)?;
147 let xs = xs.pad_with_zeros(D::Minus2, 0, 1)?;
148 xs.apply(&self.conv)
149 }
150}
151
/// 2x spatial upsampling: nearest-neighbor interpolation (in `forward`)
/// followed by a padded 3x3 convolution.
#[derive(Debug, Clone)]
struct Upsample {
    conv: Conv2d,
}
156
157impl Upsample {
158 fn new(in_c: usize, vb: ShardedVarBuilder) -> Result<Self> {
159 let conv_cfg = candle_nn::Conv2dConfig {
160 padding: 1,
161 ..Default::default()
162 };
163 let conv = conv2d(in_c, in_c, 3, conv_cfg, vb.pp("conv"))?;
164 Ok(Self { conv })
165 }
166}
167
168impl candle_core::Module for Upsample {
169 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
170 let (_, _, h, w) = xs.dims4()?;
171 xs.upsample_nearest2d(h * 2, w * 2)?.apply(&self.conv)
172 }
173}
174
/// One encoder resolution level: a run of resnet blocks followed by an
/// optional downsampling step (absent on the last, deepest level).
#[derive(Debug, Clone)]
struct DownBlock {
    block: Vec<ResnetBlock>,
    downsample: Option<Downsample>,
}
180
/// VAE encoder: input conv, down-sampling trunk, middle
/// resnet/attention/resnet section, and a final projection to
/// `2 * latent_channels` (the concatenated mean and log-variance).
#[derive(Debug, Clone)]
pub struct Encoder {
    conv_in: Conv2d,
    mid_block_1: ResnetBlock,
    mid_attn_1: AttnBlock,
    mid_block_2: ResnetBlock,
    norm_out: GroupNorm,
    conv_out: Conv2d,
    down: Vec<DownBlock>,
}
191
192impl Encoder {
193 pub fn new(cfg: &Config, vb: ShardedVarBuilder) -> Result<Self> {
194 let conv_cfg = candle_nn::Conv2dConfig {
195 padding: 1,
196 ..Default::default()
197 };
198 let base_ch = cfg.block_out_channels[0];
199 let mut block_in = base_ch;
200 let conv_in = conv2d(cfg.in_channels, block_in, 3, conv_cfg, vb.pp("conv_in"))?;
201
202 let mut down = Vec::with_capacity(cfg.block_out_channels.len());
203 let vb_d = vb.pp("down");
204 for (i_level, out_channels) in cfg.block_out_channels.iter().enumerate() {
205 let mut block = Vec::with_capacity(cfg.layers_per_block);
206 let vb_d = vb_d.pp(i_level);
207 let vb_b = vb_d.pp("block");
208 block_in = if i_level == 0 {
209 base_ch
210 } else {
211 cfg.block_out_channels[i_level - 1]
212 };
213 let block_out = *out_channels;
214 for i_block in 0..cfg.layers_per_block {
215 let b = ResnetBlock::new(block_in, block_out, vb_b.pp(i_block), cfg)?;
216 block.push(b);
217 block_in = block_out;
218 }
219 let downsample = if i_level != cfg.block_out_channels.len() - 1 {
220 Some(Downsample::new(block_in, vb_d.pp("downsample"))?)
221 } else {
222 None
223 };
224 let block = DownBlock { block, downsample };
225 down.push(block)
226 }
227
228 let mid_block_1 = ResnetBlock::new(block_in, block_in, vb.pp("mid.block_1"), cfg)?;
229 let mid_attn_1 = AttnBlock::new(block_in, vb.pp("mid.attn_1"), cfg)?;
230 let mid_block_2 = ResnetBlock::new(block_in, block_in, vb.pp("mid.block_2"), cfg)?;
231 let conv_out = conv2d(
232 block_in,
233 2 * cfg.latent_channels,
234 3,
235 conv_cfg,
236 vb.pp("conv_out"),
237 )?;
238 let norm_out = group_norm(cfg.norm_num_groups, block_in, 1e-6, vb.pp("norm_out"))?;
239 Ok(Self {
240 conv_in,
241 mid_block_1,
242 mid_attn_1,
243 mid_block_2,
244 norm_out,
245 conv_out,
246 down,
247 })
248 }
249}
250
251impl candle_nn::Module for Encoder {
252 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
253 let mut h = xs.apply(&self.conv_in)?;
254 for block in self.down.iter() {
255 for b in block.block.iter() {
256 h = h.apply(b)?
257 }
258 if let Some(ds) = block.downsample.as_ref() {
259 h = h.apply(ds)?
260 }
261 }
262 h.apply(&self.mid_block_1)?
263 .apply(&self.mid_attn_1)?
264 .apply(&self.mid_block_2)?
265 .apply(&self.norm_out)?
266 .apply(&candle_nn::Activation::Swish)?
267 .apply(&self.conv_out)
268 }
269}
270
/// One decoder resolution level: a run of resnet blocks followed by an
/// optional upsampling step (absent on level 0, the output resolution).
#[derive(Debug, Clone)]
struct UpBlock {
    block: Vec<ResnetBlock>,
    upsample: Option<Upsample>,
}
276
/// VAE decoder: input conv from latent channels, middle
/// resnet/attention/resnet section, up-sampling trunk, and a final
/// projection back to image channels.
#[derive(Debug, Clone)]
pub struct Decoder {
    conv_in: Conv2d,
    mid_block_1: ResnetBlock,
    mid_attn_1: AttnBlock,
    mid_block_2: ResnetBlock,
    norm_out: GroupNorm,
    conv_out: Conv2d,
    // Stored in level order (index 0 = highest resolution); `forward`
    // iterates it in reverse, deepest level first.
    up: Vec<UpBlock>,
}
287
impl Decoder {
    /// Loads the decoder from the weights under `vb`: an input conv from
    /// `latent_channels`, the middle (resnet / attention / resnet) section,
    /// one `UpBlock` per entry in `cfg.block_out_channels`, and the output
    /// projection back to `out_channels`.
    pub fn new(cfg: &Config, vb: ShardedVarBuilder) -> Result<Self> {
        // 3x3 convs with padding 1 keep spatial resolution unchanged.
        let conv_cfg = candle_nn::Conv2dConfig {
            padding: 1,
            ..Default::default()
        };
        let base_ch = cfg.block_out_channels[0];
        // The decoder starts at the deepest level, i.e. the last entry of
        // `block_out_channels`.
        let mut block_in = cfg.block_out_channels.last().copied().unwrap_or(base_ch);
        let conv_in = conv2d(cfg.latent_channels, block_in, 3, conv_cfg, vb.pp("conv_in"))?;
        let mid_block_1 = ResnetBlock::new(block_in, block_in, vb.pp("mid.block_1"), cfg)?;
        let mid_attn_1 = AttnBlock::new(block_in, vb.pp("mid.attn_1"), cfg)?;
        let mid_block_2 = ResnetBlock::new(block_in, block_in, vb.pp("mid.block_2"), cfg)?;

        let mut up = Vec::with_capacity(cfg.block_out_channels.len());
        let vb_u = vb.pp("up");
        // Build levels from the deepest (last index) to the shallowest so
        // that `block_in` threads each level's output channel count into the
        // next level built. Weight paths still use the original `i_level`.
        for (i_level, out_channels) in cfg.block_out_channels.iter().enumerate().rev() {
            let block_out = *out_channels;
            let vb_u = vb_u.pp(i_level);
            let vb_b = vb_u.pp("block");
            // Decoder levels hold one more resnet block than encoder levels
            // (`0..=layers_per_block`).
            let mut block = Vec::with_capacity(cfg.layers_per_block + 1);
            for i_block in 0..=cfg.layers_per_block {
                let b = ResnetBlock::new(block_in, block_out, vb_b.pp(i_block), cfg)?;
                block.push(b);
                block_in = block_out;
            }
            // Every level except level 0 doubles the spatial resolution.
            let upsample = if i_level != 0 {
                Some(Upsample::new(block_in, vb_u.pp("upsample"))?)
            } else {
                None
            };
            let block = UpBlock { block, upsample };
            up.push(block)
        }
        // `up` was filled deepest-first; store it in level order. The forward
        // pass iterates it in reverse, restoring deepest-first execution.
        up.reverse();

        // After the loop, `block_in` is level 0's channel count.
        let norm_out = group_norm(cfg.norm_num_groups, block_in, 1e-6, vb.pp("norm_out"))?;
        let conv_out = conv2d(block_in, cfg.out_channels, 3, conv_cfg, vb.pp("conv_out"))?;
        Ok(Self {
            conv_in,
            mid_block_1,
            mid_attn_1,
            mid_block_2,
            norm_out,
            conv_out,
            up,
        })
    }
}
336
337impl candle_nn::Module for Decoder {
338 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
339 let h = xs.apply(&self.conv_in)?;
340 let mut h = h
341 .apply(&self.mid_block_1)?
342 .apply(&self.mid_attn_1)?
343 .apply(&self.mid_block_2)?;
344 for block in self.up.iter().rev() {
345 for b in block.block.iter() {
346 h = h.apply(b)?
347 }
348 if let Some(us) = block.upsample.as_ref() {
349 h = h.apply(us)?
350 }
351 }
352 h.apply(&self.norm_out)?
353 .apply(&candle_nn::Activation::Swish)?
354 .apply(&self.conv_out)
355 }
356}
357
/// Diagonal Gaussian head: splits a tensor into (mean, log-variance) halves
/// along `chunk_dim` and either samples via the reparameterization trick
/// (`sample == true`) or returns the mean deterministically.
#[derive(Debug, Clone)]
pub struct DiagonalGaussian {
    sample: bool,
    chunk_dim: usize,
}
363
364impl DiagonalGaussian {
365 pub fn new(sample: bool, chunk_dim: usize) -> Result<Self> {
366 Ok(Self { sample, chunk_dim })
367 }
368}
369
370impl candle_nn::Module for DiagonalGaussian {
371 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
372 let chunks = xs.chunk(2, self.chunk_dim)?;
373 if self.sample {
374 let std = (&chunks[1] * 0.5)?.exp()?;
375 &chunks[0] + (std * chunks[0].randn_like(0., 1.))?
376 } else {
377 Ok(chunks[0].clone())
378 }
379 }
380}
381
/// Full VAE: encoder, decoder, and the diagonal-Gaussian sampling head,
/// plus the affine (shift/scale) transform between raw and usable latents.
#[derive(Debug, Clone)]
pub struct AutoEncoder {
    encoder: Encoder,
    decoder: Decoder,
    reg: DiagonalGaussian,
    shift_factor: f64,
    scale_factor: f64,
}
390
391impl AutoEncoder {
392 pub fn new(cfg: &Config, vb: ShardedVarBuilder) -> Result<Self> {
393 let encoder = Encoder::new(cfg, vb.pp("encoder"))?;
394 let decoder = Decoder::new(cfg, vb.pp("decoder"))?;
395 let reg = DiagonalGaussian::new(true, 1)?;
396 Ok(Self {
397 encoder,
398 decoder,
399 reg,
400 scale_factor: cfg.scaling_factor,
401 shift_factor: cfg.shift_factor,
402 })
403 }
404
405 pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
406 let z = xs.apply(&self.encoder)?.apply(&self.reg)?;
407 (z - self.shift_factor)? * self.scale_factor
408 }
409 pub fn decode(&self, xs: &Tensor) -> Result<Tensor> {
410 let xs = ((xs / self.scale_factor)? + self.shift_factor)?;
411 xs.apply(&self.decoder)
412 }
413}
414
415impl candle_core::Module for AutoEncoder {
416 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
417 self.decode(&self.encode(xs)?)
418 }
419}