// diffusion_rs_core/models/vaes/autoencoder_kl.rs
use diffusion_rs_common::core::{Result, Tensor};
use diffusion_rs_common::nn::{Activation, Conv2d, Conv2dConfig};
use diffusion_rs_common::VarBuilder;
use serde::Deserialize;

use super::{
    vae::{Decoder, DiagonalGaussian, Encoder, VAEConfig},
    VAEModel,
};

/// Serde fallback for [`AutencoderKlConfig::act_fn`]: SiLU is used whenever the
/// configuration file omits the activation.
fn default_act() -> Activation {
    Activation::Silu
}

#[derive(Debug, Clone, Deserialize)]
pub struct AutencoderKlConfig {
    pub in_channels: usize,
    pub out_channels: usize,
    pub block_out_channels: Vec<usize>,
    pub layers_per_block: usize,
    #[serde(default = "default_act")]
    pub act_fn: Activation,
    pub latent_channels: usize,
    pub norm_num_groups: usize,
    pub scaling_factor: f64,
    pub shift_factor: f64,
    pub mid_block_add_attention: bool,
    pub use_quant_conv: bool,
    pub use_post_quant_conv: bool,
    pub down_block_types: Vec<String>,
    pub up_block_types: Vec<String>,
}

impl From<AutencoderKlConfig> for VAEConfig {
    fn from(value: AutencoderKlConfig) -> Self {
        Self {
            in_channels: value.in_channels,
            out_channels: value.out_channels,
            block_out_channels: value.block_out_channels,
            layers_per_block: value.layers_per_block,
            act_fn: value.act_fn,
            latent_channels: value.latent_channels,
            norm_num_groups: value.norm_num_groups,
            mid_block_add_attention: value.mid_block_add_attention,
            down_block_types: value.down_block_types,
            up_block_types: value.up_block_types,
        }
    }
}

/// KL-regularized autoencoder: encoder, decoder, a `DiagonalGaussian` regularizer, and
/// optional 1x1 quant / post-quant convolutions. Built by [`AutoEncoderKl::new`].
#[derive(Debug, Clone)]
pub struct AutoEncoderKl {
    /// Encoder network, loaded from the `encoder` weight prefix.
    encoder: Encoder,
    /// Decoder network, loaded from the `decoder` weight prefix.
    decoder: Decoder,
    /// Regularizer applied as the last step of `encode`; built as
    /// `DiagonalGaussian::new(true, 1)`.
    /// NOTE(review): argument meaning lives in `vae.rs` — presumably a sample flag and
    /// the split dimension; confirm.
    reg: DiagonalGaussian,
    /// 1x1 conv over `2 * latent_channels` channels; `Some` only when `use_quant_conv` is set.
    quant_conv: Option<Conv2d>,
    /// 1x1 conv over `latent_channels` channels; `Some` only when `use_post_quant_conv` is set.
    post_quant_conv: Option<Conv2d>,
    /// Copied from the config's `shift_factor`; exposed through `VAEModel::shift_factor`.
    shift_factor: f64,
    /// Copied from the config's `scaling_factor`; exposed through `VAEModel::scale_factor`.
    scale_factor: f64,
}

impl AutoEncoderKl {
    pub fn new(cfg: &AutencoderKlConfig, vb: VarBuilder) -> Result<Self> {
        let encoder = Encoder::new(&cfg.clone().into(), vb.pp("encoder"))?;
        let decoder = Decoder::new(&cfg.clone().into(), vb.pp("decoder"))?;
        let reg = DiagonalGaussian::new(true, 1)?;
        let quant_conv = if cfg.use_quant_conv {
            Some(diffusion_rs_common::conv2d(
                2 * cfg.latent_channels,
                2 * cfg.latent_channels,
                1,
                Conv2dConfig::default(),
                vb.pp("quant_conv"),
            )?)
        } else {
            None
        };
        let post_quant_conv = if cfg.use_post_quant_conv {
            Some(diffusion_rs_common::conv2d(
                cfg.latent_channels,
                cfg.latent_channels,
                1,
                Conv2dConfig::default(),
                vb.pp("post_quant_conv"),
            )?)
        } else {
            None
        };
        Ok(Self {
            encoder,
            decoder,
            reg,
            scale_factor: cfg.scaling_factor,
            shift_factor: cfg.shift_factor,
            quant_conv,
            post_quant_conv,
        })
    }
}

impl VAEModel for AutoEncoderKl {
    /// Encodes `xs` into latents: encoder → optional `quant_conv` → `DiagonalGaussian`.
    ///
    /// No shift or scale is applied here. The constants are exposed via
    /// [`VAEModel::shift_factor`] / [`VAEModel::scale_factor`] (presumably applied by the
    /// caller — confirm at call sites).
    fn encode(&self, xs: &Tensor) -> Result<Tensor> {
        let mut z = xs.apply(&self.encoder)?;
        if let Some(conv) = &self.quant_conv {
            z = z.apply(conv)?;
        }
        z.apply(&self.reg)
    }

    /// Decodes latents `xs`: optional `post_quant_conv` → decoder.
    ///
    /// `post_quant_conv` maps `latent_channels → latent_channels`, so it must run on the
    /// latents *before* the decoder, matching diffusers' `AutoencoderKL._decode`. Applying
    /// it to the decoder output, which has `out_channels` channels, fails with a channel
    /// mismatch whenever `out_channels != latent_channels`.
    ///
    /// As with `encode`, no unscale/unshift is applied here.
    fn decode(&self, xs: &Tensor) -> Result<Tensor> {
        match &self.post_quant_conv {
            Some(conv) => xs.apply(conv)?.apply(&self.decoder),
            None => xs.apply(&self.decoder),
        }
    }

    fn shift_factor(&self) -> f64 {
        self.shift_factor
    }

    fn scale_factor(&self) -> f64 {
        self.scale_factor
    }
}