mistralrs_core/diffusion_models/flux/autoencoder.rs

1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use candle_core::{Result, Tensor, D};
4use candle_nn::{Conv2d, GroupNorm};
5use mistralrs_quant::ShardedVarBuilder;
6use serde::Deserialize;
7
8use crate::layers::{conv2d, group_norm, MatMul};
9
/// Configuration for the FLUX VAE, deserialized from the model's `config.json`.
#[derive(Debug, Clone, Deserialize)]
pub struct Config {
    // Channel count of the input image (e.g. 3 for RGB).
    pub in_channels: usize,
    // Channel count of the decoded output image.
    pub out_channels: usize,
    // Output channels per resolution level, index 0 = highest resolution.
    pub block_out_channels: Vec<usize>,
    // Number of ResNet blocks per encoder level (decoder uses one extra).
    pub layers_per_block: usize,
    // Channel count of the latent representation.
    pub latent_channels: usize,
    // Latents are multiplied by this after encoding (divided before decoding).
    pub scaling_factor: f64,
    // Latents are shifted by this amount (subtracted after encoding).
    pub shift_factor: f64,
    // Group count used by every GroupNorm layer.
    pub norm_num_groups: usize,
}
21
22fn scaled_dot_product_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
23    let dim = q.dim(D::Minus1)?;
24    let scale_factor = 1.0 / (dim as f64).sqrt();
25    let attn_weights = (MatMul.matmul(q, &k.t()?)? * scale_factor)?;
26    MatMul.matmul(&candle_nn::ops::softmax_last_dim(&attn_weights)?, v)
27}
28
/// Single-head self-attention over spatial positions, used in the VAE
/// bottleneck. All projections are 1x1 convolutions.
#[derive(Debug, Clone)]
struct AttnBlock {
    q: Conv2d,
    k: Conv2d,
    v: Conv2d,
    proj_out: Conv2d,
    // Pre-attention normalization.
    norm: GroupNorm,
}
37
38impl AttnBlock {
39    fn new(in_c: usize, vb: ShardedVarBuilder, cfg: &Config) -> Result<Self> {
40        let q = conv2d(in_c, in_c, 1, Default::default(), vb.pp("q"))?;
41        let k = conv2d(in_c, in_c, 1, Default::default(), vb.pp("k"))?;
42        let v = conv2d(in_c, in_c, 1, Default::default(), vb.pp("v"))?;
43        let proj_out = conv2d(in_c, in_c, 1, Default::default(), vb.pp("proj_out"))?;
44        let norm = group_norm(cfg.norm_num_groups, in_c, 1e-6, vb.pp("norm"))?;
45        Ok(Self {
46            q,
47            k,
48            v,
49            proj_out,
50            norm,
51        })
52    }
53}
54
impl candle_core::Module for AttnBlock {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        // Keep the unnormalized input for the residual connection.
        let init_xs = xs;
        let xs = xs.apply(&self.norm)?;
        let q = xs.apply(&self.q)?;
        let k = xs.apply(&self.k)?;
        let v = xs.apply(&self.v)?;
        let (b, c, h, w) = q.dims4()?;
        // (b, c, h, w) -> (b, 1, h*w, c): attention runs over the h*w spatial
        // positions with the channel dim as the feature dim (single head).
        let q = q.flatten_from(2)?.t()?.unsqueeze(1)?;
        let k = k.flatten_from(2)?.t()?.unsqueeze(1)?;
        let v = v.flatten_from(2)?.t()?.unsqueeze(1)?;
        let xs = scaled_dot_product_attention(&q, &k, &v)?;
        // Undo the reshape back to (b, c, h, w), project, add the residual.
        let xs = xs.squeeze(1)?.t()?.reshape((b, c, h, w))?;
        xs.apply(&self.proj_out)? + init_xs
    }
}
71
/// Pre-activation residual block: (norm → swish → conv) twice, with an
/// optional 1x1 shortcut conv when input/output channel counts differ.
#[derive(Debug, Clone)]
struct ResnetBlock {
    norm1: GroupNorm,
    conv1: Conv2d,
    norm2: GroupNorm,
    conv2: Conv2d,
    // Present only when in_channels != out_channels.
    nin_shortcut: Option<Conv2d>,
}
80
81impl ResnetBlock {
82    fn new(in_c: usize, out_c: usize, vb: ShardedVarBuilder, cfg: &Config) -> Result<Self> {
83        let conv_cfg = candle_nn::Conv2dConfig {
84            padding: 1,
85            ..Default::default()
86        };
87        let norm1 = group_norm(cfg.norm_num_groups, in_c, 1e-6, vb.pp("norm1"))?;
88        let conv1 = conv2d(in_c, out_c, 3, conv_cfg, vb.pp("conv1"))?;
89        let norm2 = group_norm(cfg.norm_num_groups, out_c, 1e-6, vb.pp("norm2"))?;
90        let conv2 = conv2d(out_c, out_c, 3, conv_cfg, vb.pp("conv2"))?;
91        let nin_shortcut = if in_c == out_c {
92            None
93        } else {
94            Some(conv2d(
95                in_c,
96                out_c,
97                1,
98                Default::default(),
99                vb.pp("nin_shortcut"),
100            )?)
101        };
102        Ok(Self {
103            norm1,
104            conv1,
105            norm2,
106            conv2,
107            nin_shortcut,
108        })
109    }
110}
111
112impl candle_core::Module for ResnetBlock {
113    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
114        let h = xs
115            .apply(&self.norm1)?
116            .apply(&candle_nn::Activation::Swish)?
117            .apply(&self.conv1)?
118            .apply(&self.norm2)?
119            .apply(&candle_nn::Activation::Swish)?
120            .apply(&self.conv2)?;
121        match self.nin_shortcut.as_ref() {
122            None => xs + h,
123            Some(c) => xs.apply(c)? + h,
124        }
125    }
126}
127
/// Halves the spatial resolution with a stride-2 3x3 conv (padding is applied
/// manually in `forward`).
#[derive(Debug, Clone)]
struct Downsample {
    conv: Conv2d,
}
132
133impl Downsample {
134    fn new(in_c: usize, vb: ShardedVarBuilder) -> Result<Self> {
135        let conv_cfg = candle_nn::Conv2dConfig {
136            stride: 2,
137            ..Default::default()
138        };
139        let conv = conv2d(in_c, in_c, 3, conv_cfg, vb.pp("conv"))?;
140        Ok(Self { conv })
141    }
142}
143
impl candle_core::Module for Downsample {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        // Asymmetric zero-padding: one extra column on the right and one extra
        // row at the bottom (instead of symmetric conv padding), so the
        // stride-2 conv matches the reference implementation's output grid.
        let xs = xs.pad_with_zeros(D::Minus1, 0, 1)?;
        let xs = xs.pad_with_zeros(D::Minus2, 0, 1)?;
        xs.apply(&self.conv)
    }
}
151
/// Doubles the spatial resolution: nearest-neighbor upsampling followed by a
/// 3x3 conv (applied in `forward`).
#[derive(Debug, Clone)]
struct Upsample {
    conv: Conv2d,
}
156
157impl Upsample {
158    fn new(in_c: usize, vb: ShardedVarBuilder) -> Result<Self> {
159        let conv_cfg = candle_nn::Conv2dConfig {
160            padding: 1,
161            ..Default::default()
162        };
163        let conv = conv2d(in_c, in_c, 3, conv_cfg, vb.pp("conv"))?;
164        Ok(Self { conv })
165    }
166}
167
168impl candle_core::Module for Upsample {
169    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
170        let (_, _, h, w) = xs.dims4()?;
171        xs.upsample_nearest2d(h * 2, w * 2)?.apply(&self.conv)
172    }
173}
174
/// One encoder resolution level: a stack of ResNet blocks, then an optional
/// downsample (absent on the last, lowest-resolution level).
#[derive(Debug, Clone)]
struct DownBlock {
    block: Vec<ResnetBlock>,
    downsample: Option<Downsample>,
}
180
/// VAE encoder: stem conv → downsampling levels → mid (resnet/attn/resnet)
/// bottleneck → norm/swish → head conv emitting 2 * latent_channels
/// (mean and logvar, later split by `DiagonalGaussian`).
#[derive(Debug, Clone)]
pub struct Encoder {
    conv_in: Conv2d,
    mid_block_1: ResnetBlock,
    mid_attn_1: AttnBlock,
    mid_block_2: ResnetBlock,
    norm_out: GroupNorm,
    conv_out: Conv2d,
    down: Vec<DownBlock>,
}
191
impl Encoder {
    /// Build the VAE encoder from the variable store rooted at `vb`.
    pub fn new(cfg: &Config, vb: ShardedVarBuilder) -> Result<Self> {
        // 3x3 convs with padding 1 keep the spatial resolution unchanged.
        let conv_cfg = candle_nn::Conv2dConfig {
            padding: 1,
            ..Default::default()
        };
        let base_ch = cfg.block_out_channels[0];
        // `block_in` threads the running channel count through all levels.
        let mut block_in = base_ch;
        let conv_in = conv2d(cfg.in_channels, block_in, 3, conv_cfg, vb.pp("conv_in"))?;

        let mut down = Vec::with_capacity(cfg.block_out_channels.len());
        let vb_d = vb.pp("down");
        for (i_level, out_channels) in cfg.block_out_channels.iter().enumerate() {
            let mut block = Vec::with_capacity(cfg.layers_per_block);
            let vb_d = vb_d.pp(i_level);
            let vb_b = vb_d.pp("block");
            // Each level's first resnet takes the previous level's output width.
            block_in = if i_level == 0 {
                base_ch
            } else {
                cfg.block_out_channels[i_level - 1]
            };
            let block_out = *out_channels;
            for i_block in 0..cfg.layers_per_block {
                let b = ResnetBlock::new(block_in, block_out, vb_b.pp(i_block), cfg)?;
                block.push(b);
                block_in = block_out;
            }
            // All levels except the last halve the spatial resolution.
            let downsample = if i_level != cfg.block_out_channels.len() - 1 {
                Some(Downsample::new(block_in, vb_d.pp("downsample"))?)
            } else {
                None
            };
            let block = DownBlock { block, downsample };
            down.push(block)
        }

        // Bottleneck at the lowest resolution: resnet → attention → resnet.
        let mid_block_1 = ResnetBlock::new(block_in, block_in, vb.pp("mid.block_1"), cfg)?;
        let mid_attn_1 = AttnBlock::new(block_in, vb.pp("mid.attn_1"), cfg)?;
        let mid_block_2 = ResnetBlock::new(block_in, block_in, vb.pp("mid.block_2"), cfg)?;
        // Head emits twice the latent channels: (mean, logvar), which
        // `DiagonalGaussian` later splits along the channel dim.
        let conv_out = conv2d(
            block_in,
            2 * cfg.latent_channels,
            3,
            conv_cfg,
            vb.pp("conv_out"),
        )?;
        let norm_out = group_norm(cfg.norm_num_groups, block_in, 1e-6, vb.pp("norm_out"))?;
        Ok(Self {
            conv_in,
            mid_block_1,
            mid_attn_1,
            mid_block_2,
            norm_out,
            conv_out,
            down,
        })
    }
}
250
251impl candle_nn::Module for Encoder {
252    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
253        let mut h = xs.apply(&self.conv_in)?;
254        for block in self.down.iter() {
255            for b in block.block.iter() {
256                h = h.apply(b)?
257            }
258            if let Some(ds) = block.downsample.as_ref() {
259                h = h.apply(ds)?
260            }
261        }
262        h.apply(&self.mid_block_1)?
263            .apply(&self.mid_attn_1)?
264            .apply(&self.mid_block_2)?
265            .apply(&self.norm_out)?
266            .apply(&candle_nn::Activation::Swish)?
267            .apply(&self.conv_out)
268    }
269}
270
/// One decoder resolution level: a stack of ResNet blocks, then an optional
/// upsample (absent on level 0, the highest-resolution level).
#[derive(Debug, Clone)]
struct UpBlock {
    block: Vec<ResnetBlock>,
    upsample: Option<Upsample>,
}
276
/// VAE decoder: stem conv on latents → mid (resnet/attn/resnet) bottleneck →
/// upsampling levels → norm/swish → head conv back to image channels.
#[derive(Debug, Clone)]
pub struct Decoder {
    conv_in: Conv2d,
    mid_block_1: ResnetBlock,
    mid_attn_1: AttnBlock,
    mid_block_2: ResnetBlock,
    norm_out: GroupNorm,
    conv_out: Conv2d,
    // Stored index-ascending (up[0] = highest-resolution level); `forward`
    // iterates in reverse.
    up: Vec<UpBlock>,
}
287
impl Decoder {
    /// Build the VAE decoder from the variable store rooted at `vb`.
    pub fn new(cfg: &Config, vb: ShardedVarBuilder) -> Result<Self> {
        // 3x3 convs with padding 1 keep the spatial resolution unchanged.
        let conv_cfg = candle_nn::Conv2dConfig {
            padding: 1,
            ..Default::default()
        };
        let base_ch = cfg.block_out_channels[0];
        // Decoding starts at the deepest level, i.e. the widest channel count.
        let mut block_in = cfg.block_out_channels.last().copied().unwrap_or(base_ch);
        let conv_in = conv2d(cfg.latent_channels, block_in, 3, conv_cfg, vb.pp("conv_in"))?;
        // Bottleneck mirrors the encoder: resnet → attention → resnet.
        let mid_block_1 = ResnetBlock::new(block_in, block_in, vb.pp("mid.block_1"), cfg)?;
        let mid_attn_1 = AttnBlock::new(block_in, vb.pp("mid.attn_1"), cfg)?;
        let mid_block_2 = ResnetBlock::new(block_in, block_in, vb.pp("mid.block_2"), cfg)?;

        let mut up = Vec::with_capacity(cfg.block_out_channels.len());
        let vb_u = vb.pp("up");
        // Walk levels deepest-first so `block_in` threads correctly from the
        // bottleneck outward; the list is re-reversed below.
        for (i_level, out_channels) in cfg.block_out_channels.iter().enumerate().rev() {
            let block_out = *out_channels;
            let vb_u = vb_u.pp(i_level);
            let vb_b = vb_u.pp("block");
            // The decoder uses one more resnet per level than the encoder.
            let mut block = Vec::with_capacity(cfg.layers_per_block + 1);
            for i_block in 0..=cfg.layers_per_block {
                let b = ResnetBlock::new(block_in, block_out, vb_b.pp(i_block), cfg)?;
                block.push(b);
                block_in = block_out;
            }
            // Every level except level 0 doubles the spatial resolution.
            let upsample = if i_level != 0 {
                Some(Upsample::new(block_in, vb_u.pp("upsample"))?)
            } else {
                None
            };
            let block = UpBlock { block, upsample };
            up.push(block)
        }
        // Restore index order so up[0] corresponds to level 0.
        up.reverse();

        let norm_out = group_norm(cfg.norm_num_groups, block_in, 1e-6, vb.pp("norm_out"))?;
        let conv_out = conv2d(block_in, cfg.out_channels, 3, conv_cfg, vb.pp("conv_out"))?;
        Ok(Self {
            conv_in,
            mid_block_1,
            mid_attn_1,
            mid_block_2,
            norm_out,
            conv_out,
            up,
        })
    }
}
336
impl candle_nn::Module for Decoder {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        // Stem conv on the latents, then the bottleneck stack.
        let h = xs.apply(&self.conv_in)?;
        let mut h = h
            .apply(&self.mid_block_1)?
            .apply(&self.mid_attn_1)?
            .apply(&self.mid_block_2)?;
        // `up` is stored index-ascending, so iterate in reverse to run from
        // the deepest (lowest-resolution) level up to full resolution.
        for block in self.up.iter().rev() {
            for b in block.block.iter() {
                h = h.apply(b)?
            }
            if let Some(us) = block.upsample.as_ref() {
                h = h.apply(us)?
            }
        }
        h.apply(&self.norm_out)?
            .apply(&candle_nn::Activation::Swish)?
            .apply(&self.conv_out)
    }
}
357
/// Splits an input into (mean, logvar) halves along `chunk_dim` and either
/// samples from the resulting diagonal Gaussian or returns the mean.
#[derive(Debug, Clone)]
pub struct DiagonalGaussian {
    // When true, draw a sample; when false, return the mean deterministically.
    sample: bool,
    // Dimension along which mean and logvar are concatenated.
    chunk_dim: usize,
}
363
364impl DiagonalGaussian {
365    pub fn new(sample: bool, chunk_dim: usize) -> Result<Self> {
366        Ok(Self { sample, chunk_dim })
367    }
368}
369
impl candle_nn::Module for DiagonalGaussian {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        // Split into (mean, logvar) along the configured dimension.
        let chunks = xs.chunk(2, self.chunk_dim)?;
        if self.sample {
            // Reparameterization trick: mean + exp(logvar / 2) * eps,
            // with eps ~ N(0, 1) drawn with the same shape as the mean.
            let std = (&chunks[1] * 0.5)?.exp()?;
            &chunks[0] + (std * chunks[0].randn_like(0., 1.))?
        } else {
            // Deterministic mode: just return the mean.
            Ok(chunks[0].clone())
        }
    }
}
381
/// Full FLUX VAE: encoder + decoder, with stochastic latent sampling and the
/// affine latent normalization (shift, then scale) from the config.
#[derive(Debug, Clone)]
pub struct AutoEncoder {
    encoder: Encoder,
    decoder: Decoder,
    // Samples latents from the encoder's (mean, logvar) output.
    reg: DiagonalGaussian,
    shift_factor: f64,
    scale_factor: f64,
}
390
impl AutoEncoder {
    /// Build encoder, decoder, and the latent sampler from the variable store.
    pub fn new(cfg: &Config, vb: ShardedVarBuilder) -> Result<Self> {
        let encoder = Encoder::new(cfg, vb.pp("encoder"))?;
        let decoder = Decoder::new(cfg, vb.pp("decoder"))?;
        // Sample along dim 1 — the channel dim of NCHW latents.
        let reg = DiagonalGaussian::new(true, 1)?;
        Ok(Self {
            encoder,
            decoder,
            reg,
            scale_factor: cfg.scaling_factor,
            shift_factor: cfg.shift_factor,
        })
    }

    /// Encode an image to normalized latents: (z - shift_factor) * scale_factor.
    pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
        let z = xs.apply(&self.encoder)?.apply(&self.reg)?;
        (z - self.shift_factor)? * self.scale_factor
    }

    /// Invert `encode`'s affine normalization, then decode latents to an image.
    pub fn decode(&self, xs: &Tensor) -> Result<Tensor> {
        let xs = ((xs / self.scale_factor)? + self.shift_factor)?;
        xs.apply(&self.decoder)
    }
}
414
415impl candle_core::Module for AutoEncoder {
416    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
417        self.decode(&self.encode(xs)?)
418    }
419}