diffusion_rs_core/models/vaes/
mod.rsuse std::sync::Arc;
use autoencoder_kl::{AutencoderKlConfig, AutoEncoderKl};
use diffusion_rs_common::{
core::{DType, Device, Result, Tensor},
ModelSource,
};
use serde::Deserialize;
use diffusion_rs_common::{from_mmaped_safetensors, FileData, VarBuilder};
mod autoencoder_kl;
mod vae;
/// Common interface for VAE implementations loaded by this module.
///
/// The `Send + Sync` supertraits allow the `Arc<dyn VAEModel>` handles returned
/// by the loaders in this module to be shared across threads.
pub(crate) trait VAEModel: Send + Sync {
    /// Run the decoder on `xs`, returning the decoded tensor.
    fn decode(&self, xs: &Tensor) -> Result<Tensor>;

    /// Run the encoder on `xs`, returning the encoded tensor.
    // The allow silences the unused-method lint: the method is part of the
    // interface even where nothing in the crate calls it yet.
    #[allow(dead_code)]
    fn encode(&self, xs: &Tensor) -> Result<Tensor>;

    /// Shift factor reported by the model (presumably the latent shift applied
    /// by callers around encode/decode — confirm against call sites).
    fn shift_factor(&self) -> f64;

    /// Scale factor reported by the model (presumably the latent scaling applied
    /// by callers around encode/decode — confirm against call sites).
    fn scale_factor(&self) -> f64;
}
/// Minimal view of a VAE JSON config, used only to discover which model class
/// to instantiate.
///
/// Keys other than `_class_name` are ignored during deserialization; the full,
/// class-specific config is parsed separately once the class is known.
#[derive(Debug, Clone, Deserialize)]
struct VaeConfigShim {
    /// Value of the `_class_name` key, e.g. `"AutoencoderKL"`.
    #[serde(rename = "_class_name")]
    name: String,
}
/// Build an [`AutoEncoderKl`] from its JSON config and the weights in `vb`.
///
/// The config text is read from `cfg_json` through `source` and deserialized
/// into an [`AutencoderKlConfig`]. Read, parse, and model-construction errors
/// are all propagated to the caller.
fn load_autoencoder_kl(
    cfg_json: &FileData,
    vb: VarBuilder,
    source: Arc<ModelSource>,
) -> anyhow::Result<Arc<dyn VAEModel>> {
    let raw_config = cfg_json.read_to_string(&source)?;
    let config = serde_json::from_str::<AutencoderKlConfig>(&raw_config)?;
    let model = AutoEncoderKl::new(&config, vb)?;
    Ok(Arc::new(model))
}
/// Load the VAE described by `cfg_json`, dispatching on its `_class_name`.
///
/// The config is parsed and its class name validated *before* the weight files
/// are opened. An unsupported VAE type is therefore rejected without mapping
/// any safetensors.
///
/// * `cfg_json` - the VAE's JSON config, read through `source`.
/// * `safetensor_files` - weight files forwarded to `from_mmaped_safetensors`.
/// * `device` / `dtype` - forwarded to `from_mmaped_safetensors` (dtype as `Some(dtype)`).
/// * `silent` - forwarded to `from_mmaped_safetensors`.
/// * `source` - model source used to read both the config and the weights.
///
/// # Errors
///
/// Fails in any of these cases:
/// * the config cannot be read or deserialized;
/// * `_class_name` is not a supported VAE type (currently only `AutoencoderKL`);
/// * loading the weights or constructing the model fails.
pub(crate) fn dispatch_load_vae_model(
    cfg_json: &FileData,
    safetensor_files: Vec<FileData>,
    device: &Device,
    dtype: DType,
    silent: bool,
    source: Arc<ModelSource>,
) -> anyhow::Result<Arc<dyn VAEModel>> {
    // Supported VAE architectures. The second `match` below is exhaustive, so
    // adding a variant here forces a matching loader arm.
    enum VaeKind {
        AutoencoderKl,
    }

    // Validate the model class first so unsupported configs fail fast,
    // before any weight files are touched.
    let VaeConfigShim { name } = serde_json::from_str(&cfg_json.read_to_string(&source)?)?;
    let kind = match name.as_str() {
        "AutoencoderKL" => VaeKind::AutoencoderKl,
        // Display formatting: Debug would wrap the name in quotes inside the backticks.
        other => anyhow::bail!("Unexpected VAE type `{other}`."),
    };

    let vb = from_mmaped_safetensors(
        safetensor_files,
        Some(dtype),
        device,
        silent,
        Arc::clone(&source),
    )?;

    match kind {
        VaeKind::AutoencoderKl => load_autoencoder_kl(cfg_json, vb, source),
    }
}