mistralrs_core/diffusion_models/flux/autoencoder.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

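//! Variational autoencoder (VAE) used by the FLUX diffusion pipeline: an
//! [`Encoder`] maps images to the parameters of a diagonal Gaussian over
//! latents, [`DiagonalGaussian`] performs the reparameterized sampling, and a
//! [`Decoder`] maps latents back to image space.
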
use candle_core::{Result, Tensor, D};
use candle_nn::{Conv2d, GroupNorm};
use mistralrs_quant::{Convolution, ShardedVarBuilder};
use serde::Deserialize;

use crate::layers::{conv2d, group_norm, MatMul};

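/// Autoencoder hyperparameters, deserialized from the model's VAE
/// configuration (the field names follow the `diffusers`-style VAE
/// `config.json`). `scaling_factor` and `shift_factor` normalize latents
/// before they are handed to the diffusion model.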
#[derive(Debug, Clone, Deserialize)]
pub struct Config {
    pub in_channels: usize,
    pub out_channels: usize,
    pub block_out_channels: Vec<usize>,
    pub layers_per_block: usize,
    pub latent_channels: usize,
    pub scaling_factor: f64,
    pub shift_factor: f64,
    pub norm_num_groups: usize,
}

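/// Plain softmax(q k^T / sqrt(d)) v attention over the last two dimensions.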
fn scaled_dot_product_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
    let dim = q.dim(D::Minus1)?;
    let scale_factor = 1.0 / (dim as f64).sqrt();
    let attn_weights = (MatMul.matmul(q, &k.t()?)? * scale_factor)?;
    MatMul.matmul(&candle_nn::ops::softmax_last_dim(&attn_weights)?, v)
}

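/// Single-head self-attention over spatial positions. The q/k/v and output
/// projections are 1x1 convolutions, so every (h, w) location attends to
/// every other as one token.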
#[derive(Debug, Clone)]
struct AttnBlock {
    q: Conv2d,
    k: Conv2d,
    v: Conv2d,
    proj_out: Conv2d,
    norm: GroupNorm,
}

impl AttnBlock {
    fn new(in_c: usize, vb: ShardedVarBuilder, cfg: &Config) -> Result<Self> {
        let q = conv2d(in_c, in_c, 1, Default::default(), vb.pp("q"))?;
        let k = conv2d(in_c, in_c, 1, Default::default(), vb.pp("k"))?;
        let v = conv2d(in_c, in_c, 1, Default::default(), vb.pp("v"))?;
        let proj_out = conv2d(in_c, in_c, 1, Default::default(), vb.pp("proj_out"))?;
        let norm = group_norm(cfg.norm_num_groups, in_c, 1e-6, vb.pp("norm"))?;
        Ok(Self {
            q,
            k,
            v,
            proj_out,
            norm,
        })
    }
}

impl candle_core::Module for AttnBlock {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let init_xs = xs;
        let normed = self.norm.forward(xs)?;
        let q = Convolution.forward_2d(&self.q, &normed)?;
        let k = Convolution.forward_2d(&self.k, &normed)?;
        let v = Convolution.forward_2d(&self.v, &normed)?;
        let (b, c, h, w) = q.dims4()?;
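        // (b, c, h, w) -> (b, 1, h*w, c): flatten the spatial dims, move
        // channels last, and add a singleton head dim so each pixel is a token.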
        let q = q.flatten_from(2)?.t()?.unsqueeze(1)?;
        let k = k.flatten_from(2)?.t()?.unsqueeze(1)?;
        let v = v.flatten_from(2)?.t()?.unsqueeze(1)?;
        let attended = scaled_dot_product_attention(&q, &k, &v)?;
        let attended = attended.squeeze(1)?.t()?.reshape((b, c, h, w))?;
        let projected = Convolution.forward_2d(&self.proj_out, &attended)?;
        projected + init_xs
    }
}

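/// Pre-activation residual block: two GroupNorm -> swish -> 3x3 conv stages,
/// with a learned 1x1 `nin_shortcut` on the skip path when the channel count
/// changes.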
#[derive(Debug, Clone)]
struct ResnetBlock {
    norm1: GroupNorm,
    conv1: Conv2d,
    norm2: GroupNorm,
    conv2: Conv2d,
    nin_shortcut: Option<Conv2d>,
}

impl ResnetBlock {
    fn new(in_c: usize, out_c: usize, vb: ShardedVarBuilder, cfg: &Config) -> Result<Self> {
        let conv_cfg = candle_nn::Conv2dConfig {
            padding: 1,
            ..Default::default()
        };
        let norm1 = group_norm(cfg.norm_num_groups, in_c, 1e-6, vb.pp("norm1"))?;
        let conv1 = conv2d(in_c, out_c, 3, conv_cfg, vb.pp("conv1"))?;
        let norm2 = group_norm(cfg.norm_num_groups, out_c, 1e-6, vb.pp("norm2"))?;
        let conv2 = conv2d(out_c, out_c, 3, conv_cfg, vb.pp("conv2"))?;
        let nin_shortcut = if in_c == out_c {
            None
        } else {
            Some(conv2d(
                in_c,
                out_c,
                1,
                Default::default(),
                vb.pp("nin_shortcut"),
            )?)
        };
        Ok(Self {
            norm1,
            conv1,
            norm2,
            conv2,
            nin_shortcut,
        })
    }
}

impl candle_core::Module for ResnetBlock {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let mut h = self.norm1.forward(xs)?;
        h = candle_nn::Activation::Swish.forward(&h)?;
        h = Convolution.forward_2d(&self.conv1, &h)?;
        h = self.norm2.forward(&h)?;
        h = candle_nn::Activation::Swish.forward(&h)?;
        h = Convolution.forward_2d(&self.conv2, &h)?;
        match self.nin_shortcut.as_ref() {
            None => xs + h,
            Some(c) => Convolution.forward_2d(c, xs)? + h,
        }
    }
}

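/// Halves the spatial resolution with a stride-2 3x3 convolution.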
#[derive(Debug, Clone)]
struct Downsample {
    conv: Conv2d,
}

impl Downsample {
    fn new(in_c: usize, vb: ShardedVarBuilder) -> Result<Self> {
        let conv_cfg = candle_nn::Conv2dConfig {
            stride: 2,
            ..Default::default()
        };
        let conv = conv2d(in_c, in_c, 3, conv_cfg, vb.pp("conv"))?;
        Ok(Self { conv })
    }
}

impl candle_core::Module for Downsample {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
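        // Pad a single zero row/column on the bottom/right only; with the
        // stride-2, pad-0 conv this maps even spatial dims to exactly half.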
        let xs = xs.pad_with_zeros(D::Minus1, 0, 1)?;
        let xs = xs.pad_with_zeros(D::Minus2, 0, 1)?;
        Convolution.forward_2d(&self.conv, &xs)
    }
}

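/// Doubles the spatial resolution: nearest-neighbor 2x upsampling followed by
/// a padded 3x3 convolution.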
#[derive(Debug, Clone)]
struct Upsample {
    conv: Conv2d,
}

impl Upsample {
    fn new(in_c: usize, vb: ShardedVarBuilder) -> Result<Self> {
        let conv_cfg = candle_nn::Conv2dConfig {
            padding: 1,
            ..Default::default()
        };
        let conv = conv2d(in_c, in_c, 3, conv_cfg, vb.pp("conv"))?;
        Ok(Self { conv })
    }
}

impl candle_core::Module for Upsample {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let (_, _, h, w) = xs.dims4()?;
        let upsampled = xs.upsample_nearest2d(h * 2, w * 2)?;
        Convolution.forward_2d(&self.conv, &upsampled)
    }
}

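/// One encoder resolution level: `layers_per_block` resnet blocks plus a
/// downsample on every level except the last.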
#[derive(Debug, Clone)]
struct DownBlock {
    block: Vec<ResnetBlock>,
    downsample: Option<Downsample>,
}

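/// Maps an image to the parameters of the latent distribution. The output
/// carries `2 * latent_channels` channels (mean and log-variance concatenated
/// along the channel dim) at a resolution reduced by a factor of
/// `2^(block_out_channels.len() - 1)`.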
#[derive(Debug, Clone)]
pub struct Encoder {
    conv_in: Conv2d,
    mid_block_1: ResnetBlock,
    mid_attn_1: AttnBlock,
    mid_block_2: ResnetBlock,
    norm_out: GroupNorm,
    conv_out: Conv2d,
    down: Vec<DownBlock>,
}

impl Encoder {
    pub fn new(cfg: &Config, vb: ShardedVarBuilder) -> Result<Self> {
        let conv_cfg = candle_nn::Conv2dConfig {
            padding: 1,
            ..Default::default()
        };
        let base_ch = cfg.block_out_channels[0];
        let mut block_in = base_ch;
        let conv_in = conv2d(cfg.in_channels, block_in, 3, conv_cfg, vb.pp("conv_in"))?;

        let mut down = Vec::with_capacity(cfg.block_out_channels.len());
        let vb_d = vb.pp("down");
        for (i_level, out_channels) in cfg.block_out_channels.iter().enumerate() {
            let mut block = Vec::with_capacity(cfg.layers_per_block);
            let vb_d = vb_d.pp(i_level);
            let vb_b = vb_d.pp("block");
            block_in = if i_level == 0 {
                base_ch
            } else {
                cfg.block_out_channels[i_level - 1]
            };
            let block_out = *out_channels;
            for i_block in 0..cfg.layers_per_block {
                let b = ResnetBlock::new(block_in, block_out, vb_b.pp(i_block), cfg)?;
                block.push(b);
                block_in = block_out;
            }
            let downsample = if i_level != cfg.block_out_channels.len() - 1 {
                Some(Downsample::new(block_in, vb_d.pp("downsample"))?)
            } else {
                None
            };
            let block = DownBlock { block, downsample };
            down.push(block)
        }

        let mid_block_1 = ResnetBlock::new(block_in, block_in, vb.pp("mid.block_1"), cfg)?;
        let mid_attn_1 = AttnBlock::new(block_in, vb.pp("mid.attn_1"), cfg)?;
        let mid_block_2 = ResnetBlock::new(block_in, block_in, vb.pp("mid.block_2"), cfg)?;
        let conv_out = conv2d(
            block_in,
            2 * cfg.latent_channels,
            3,
            conv_cfg,
            vb.pp("conv_out"),
        )?;
        let norm_out = group_norm(cfg.norm_num_groups, block_in, 1e-6, vb.pp("norm_out"))?;
        Ok(Self {
            conv_in,
            mid_block_1,
            mid_attn_1,
            mid_block_2,
            norm_out,
            conv_out,
            down,
        })
    }
}

impl candle_nn::Module for Encoder {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let mut h = Convolution.forward_2d(&self.conv_in, xs)?;
        for block in self.down.iter() {
            for b in block.block.iter() {
                h = b.forward(&h)?
            }
            if let Some(ds) = block.downsample.as_ref() {
                h = ds.forward(&h)?
            }
        }
        h = self.mid_block_1.forward(&h)?;
        h = self.mid_attn_1.forward(&h)?;
        h = self.mid_block_2.forward(&h)?;
        h = self.norm_out.forward(&h)?;
        h = candle_nn::Activation::Swish.forward(&h)?;
        Convolution.forward_2d(&self.conv_out, &h)
    }
}

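/// One decoder resolution level: `layers_per_block + 1` resnet blocks plus an
/// upsample on every level except level 0.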
#[derive(Debug, Clone)]
struct UpBlock {
    block: Vec<ResnetBlock>,
    upsample: Option<Upsample>,
}

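/// Mirror image of [`Encoder`]: maps a `latent_channels`-channel latent back
/// to an `out_channels`-channel image at the original resolution.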
#[derive(Debug, Clone)]
pub struct Decoder {
    conv_in: Conv2d,
    mid_block_1: ResnetBlock,
    mid_attn_1: AttnBlock,
    mid_block_2: ResnetBlock,
    norm_out: GroupNorm,
    conv_out: Conv2d,
    up: Vec<UpBlock>,
}

impl Decoder {
    pub fn new(cfg: &Config, vb: ShardedVarBuilder) -> Result<Self> {
        let conv_cfg = candle_nn::Conv2dConfig {
            padding: 1,
            ..Default::default()
        };
        let base_ch = cfg.block_out_channels[0];
        let mut block_in = cfg.block_out_channels.last().copied().unwrap_or(base_ch);
        let conv_in = conv2d(cfg.latent_channels, block_in, 3, conv_cfg, vb.pp("conv_in"))?;
        let mid_block_1 = ResnetBlock::new(block_in, block_in, vb.pp("mid.block_1"), cfg)?;
        let mid_attn_1 = AttnBlock::new(block_in, vb.pp("mid.attn_1"), cfg)?;
        let mid_block_2 = ResnetBlock::new(block_in, block_in, vb.pp("mid.block_2"), cfg)?;

        let mut up = Vec::with_capacity(cfg.block_out_channels.len());
        let vb_u = vb.pp("up");
        for (i_level, out_channels) in cfg.block_out_channels.iter().enumerate().rev() {
            let block_out = *out_channels;
            let vb_u = vb_u.pp(i_level);
            let vb_b = vb_u.pp("block");
            let mut block = Vec::with_capacity(cfg.layers_per_block + 1);
            for i_block in 0..=cfg.layers_per_block {
                let b = ResnetBlock::new(block_in, block_out, vb_b.pp(i_block), cfg)?;
                block.push(b);
                block_in = block_out;
            }
            let upsample = if i_level != 0 {
                Some(Upsample::new(block_in, vb_u.pp("upsample"))?)
            } else {
                None
            };
            let block = UpBlock { block, upsample };
            up.push(block)
        }
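        // The loop above runs from the top level down, so restore level order
        // here; `forward` then walks the levels with `.iter().rev()`.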
        up.reverse();

        let norm_out = group_norm(cfg.norm_num_groups, block_in, 1e-6, vb.pp("norm_out"))?;
        let conv_out = conv2d(block_in, cfg.out_channels, 3, conv_cfg, vb.pp("conv_out"))?;
        Ok(Self {
            conv_in,
            mid_block_1,
            mid_attn_1,
            mid_block_2,
            norm_out,
            conv_out,
            up,
        })
    }
}

impl candle_nn::Module for Decoder {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let mut h = Convolution.forward_2d(&self.conv_in, xs)?;
        h = self.mid_block_1.forward(&h)?;
        h = self.mid_attn_1.forward(&h)?;
        h = self.mid_block_2.forward(&h)?;
        for block in self.up.iter().rev() {
            for b in block.block.iter() {
                h = b.forward(&h)?
            }
            if let Some(us) = block.upsample.as_ref() {
                h = us.forward(&h)?
            }
        }
        h = self.norm_out.forward(&h)?;
        h = candle_nn::Activation::Swish.forward(&h)?;
        Convolution.forward_2d(&self.conv_out, &h)
    }
}

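/// Splits the input into `(mean, logvar)` halves along `chunk_dim`. With
/// `sample` set it draws `z = mean + exp(logvar / 2) * eps`, `eps ~ N(0, 1)`
/// (the reparameterization trick); otherwise it returns the mean.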
#[derive(Debug, Clone)]
pub struct DiagonalGaussian {
    sample: bool,
    chunk_dim: usize,
}

impl DiagonalGaussian {
    pub fn new(sample: bool, chunk_dim: usize) -> Result<Self> {
        Ok(Self { sample, chunk_dim })
    }
}

impl candle_nn::Module for DiagonalGaussian {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let chunks = xs.chunk(2, self.chunk_dim)?;
        if self.sample {
            let std = (&chunks[1] * 0.5)?.exp()?;
            &chunks[0] + (std * chunks[0].randn_like(0., 1.))?
        } else {
            Ok(chunks[0].clone())
        }
    }
}

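/// The full FLUX VAE: encoder, reparameterized sampler, and decoder, with the
/// latent shift/scale normalization applied around them.
///
/// A minimal usage sketch. It assumes `config_json`, `vb`, and `image` were
/// obtained elsewhere (e.g. from the model repo's VAE config, its safetensors
/// weights, and a preprocessed input); those names are illustrative, not part
/// of this crate's API:
///
/// ```ignore
/// let cfg: Config = serde_json::from_str(&config_json)?;
/// let ae = AutoEncoder::new(&cfg, vb)?;
/// // (b, in_channels, h, w) -> (b, latent_channels, h >> (levels - 1), ...)
/// let latent = ae.encode(&image)?;
/// // ... run the diffusion model in latent space, then:
/// let recon = ae.decode(&latent)?; // (b, out_channels, h, w)
/// ```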
#[derive(Debug, Clone)]
pub struct AutoEncoder {
    encoder: Encoder,
    decoder: Decoder,
    reg: DiagonalGaussian,
    shift_factor: f64,
    scale_factor: f64,
}

impl AutoEncoder {
    pub fn new(cfg: &Config, vb: ShardedVarBuilder) -> Result<Self> {
        let encoder = Encoder::new(cfg, vb.pp("encoder"))?;
        let decoder = Decoder::new(cfg, vb.pp("decoder"))?;
        let reg = DiagonalGaussian::new(true, 1)?;
        Ok(Self {
            encoder,
            decoder,
            reg,
            scale_factor: cfg.scaling_factor,
            shift_factor: cfg.shift_factor,
        })
    }

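    /// Encode an image to a sampled latent, normalized as
    /// `(z - shift_factor) * scale_factor`.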
    pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
        let z = xs.apply(&self.encoder)?.apply(&self.reg)?;
        (z - self.shift_factor)? * self.scale_factor
    }
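
    /// Decode a latent back to image space, first undoing the normalization:
    /// `xs / scale_factor + shift_factor`.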
    pub fn decode(&self, xs: &Tensor) -> Result<Tensor> {
        let xs = ((xs / self.scale_factor)? + self.shift_factor)?;
        xs.apply(&self.decoder)
    }
}

impl candle_core::Module for AutoEncoder {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        self.decode(&self.encode(xs)?)
    }
}