diffusion_rs_core/models/vaes/
mod.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
use std::sync::Arc;

use autoencoder_kl::{AutencoderKlConfig, AutoEncoderKl};
use diffusion_rs_common::{
    core::{DType, Device, Result, Tensor},
    ModelSource,
};
use serde::Deserialize;

use diffusion_rs_common::{from_mmaped_safetensors, FileData, VarBuilder};

mod autoencoder_kl;
mod vae;

/// Common interface for the VAE implementations returned by `dispatch_load_vae_model`.
///
/// The `Send + Sync` bound lets a loaded model be shared across threads behind the
/// `Arc<dyn VAEModel>` that the loader returns.
pub(crate) trait VAEModel: Send + Sync {
    /// Runs the VAE encoder on `xs`.
    ///
    /// This function *does not* handle scaling the tensor! If you want to do this, apply the
    /// following to the output:
    /// `(x - vae.shift_factor())? * vae.scale_factor()`
    // NOTE(review): `encode` is presumably not yet called anywhere in this crate, hence the
    // allow; remove it once an encoding path uses this method.
    #[allow(dead_code)]
    fn encode(&self, xs: &Tensor) -> Result<Tensor>;

    /// Runs the VAE decoder on `xs`.
    ///
    /// This function *does not* handle scaling the tensor! If you want to do this, apply the
    /// following to the input (the inverse of the transform documented on `encode`):
    /// `(x / vae.scale_factor())? + vae.shift_factor()`
    fn decode(&self, xs: &Tensor) -> Result<Tensor>;

    /// Offset subtracted from raw encoder output before it is multiplied by `scale_factor`.
    fn shift_factor(&self) -> f64;

    /// Multiplier applied to shifted encoder output (and divided out before decoding).
    fn scale_factor(&self) -> f64;
}

/// Minimal view of a VAE `config.json`, deserialized only to learn which VAE class to build.
///
/// Every key other than `_class_name` is ignored here (no `deny_unknown_fields`); the full,
/// class-specific config is parsed separately once the class is known.
#[derive(Deserialize, Debug, Clone)]
struct VaeConfigShim {
    /// Value of the config's `_class_name` key, e.g. `"AutoencoderKL"`.
    #[serde(rename = "_class_name")]
    name: String,
}

fn load_autoencoder_kl(
    cfg_json: &FileData,
    vb: VarBuilder,
    source: Arc<ModelSource>,
) -> anyhow::Result<Arc<dyn VAEModel>> {
    let cfg: AutencoderKlConfig = serde_json::from_str(&cfg_json.read_to_string(&source)?)?;
    Ok(Arc::new(AutoEncoderKl::new(&cfg, vb)?))
}

pub(crate) fn dispatch_load_vae_model(
    cfg_json: &FileData,
    safetensor_files: Vec<FileData>,
    device: &Device,
    dtype: DType,
    silent: bool,
    source: Arc<ModelSource>,
) -> anyhow::Result<Arc<dyn VAEModel>> {
    let vb = from_mmaped_safetensors(
        safetensor_files,
        Some(dtype),
        device,
        silent,
        source.clone(),
    )?;

    let VaeConfigShim { name } = serde_json::from_str(&cfg_json.read_to_string(&source)?)?;
    match name.as_str() {
        "AutoencoderKL" => load_autoencoder_kl(cfg_json, vb, source),
        other => anyhow::bail!("Unexpected VAE type `{other:?}`."),
    }
}