// mistralrs_core/diffusion_models/flux/stepper.rs

use std::{cmp::Ordering, fs::File};
use candle_core::{DType, Device, Result, Tensor, D};
use candle_nn::{Module, VarBuilder};
use hf_hub::api::sync::{Api, ApiError};
use tokenizers::Tokenizer;
use tracing::info;
use crate::{
diffusion_models::{
clip::text::{ClipConfig, ClipTextTransformer},
flux,
t5::{self, T5EncoderModel},
DiffusionGenerationParams,
},
pipeline::DiffusionModel,
utils::varbuilder_utils::from_mmaped_safetensors,
};
use super::{autoencoder::AutoEncoder, model::Flux};
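/// Safetensors shards of the encoder-only T5-XXL weights downloaded in
/// `get_t5_model` below.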
const T5_XXL_SAFETENSOR_FILES: &[&str] =
&["t5_xxl-shard-0.safetensors", "t5_xxl-shard-1.safetensors"];
#[derive(Clone, Copy, Debug)]
pub struct FluxStepperShift {
pub base_shift: f64,
pub max_shift: f64,
pub guidance_scale: f64,
}
#[derive(Clone, Copy, Debug)]
pub struct FluxStepperConfig {
pub num_steps: usize,
pub guidance_config: Option<FluxStepperShift>,
pub is_guidance: bool,
}
impl FluxStepperConfig {
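    /// Default stepping parameters: 50 guided steps for guidance-distilled
    /// checkpoints (FLUX.1-dev style) and 4 unguided steps for
    /// timestep-distilled ones (FLUX.1-schnell style).
    ///
    /// ```ignore
    /// // Illustrative only.
    /// let cfg = FluxStepperConfig::default_for_guidance(true);
    /// assert_eq!(cfg.num_steps, 50);
    /// assert!(cfg.guidance_config.is_some());
    /// ```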
pub fn default_for_guidance(has_guidance: bool) -> Self {
if has_guidance {
Self {
num_steps: 50,
guidance_config: Some(FluxStepperShift {
base_shift: 0.5,
max_shift: 1.15,
guidance_scale: 4.0,
}),
is_guidance: true,
}
} else {
Self {
num_steps: 4,
guidance_config: None,
is_guidance: false,
}
}
}
}
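/// End-to-end FLUX image-generation pipeline implementing [`DiffusionModel`]:
/// T5-XXL and CLIP text encoders, the FLUX transformer, and the VAE decoder.
/// The T5-XXL encoder is hot-loaded on each `forward` call rather than held
/// resident.
///
/// A minimal usage sketch (illustrative only; obtaining the `VarBuilder`s and
/// configs is elided, and `params` is assumed to carry the target
/// height/width):
///
/// ```ignore
/// let mut stepper = FluxStepper::new(
///     FluxStepperConfig::default_for_guidance(true),
///     (flux_vb, &flux_cfg),
///     (flux_ae_vb, &flux_ae_cfg),
///     DType::BF16,
///     &Device::cuda_if_available(0)?,
///     /* silent */ false,
///     /* offloaded */ false,
/// )?;
/// let img = stepper.forward(vec!["a photo of a cat".into()], params)?;
/// ```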
pub struct FluxStepper {
cfg: FluxStepperConfig,
t5_tok: Tokenizer,
clip_tok: Tokenizer,
clip_text: ClipTextTransformer,
flux_model: Flux,
flux_vae: AutoEncoder,
is_guidance: bool,
device: Device,
dtype: DType,
api: Api,
silent: bool,
offloaded: bool,
}
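// Fetch the T5-XXL tokenizer JSON from the Hub.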
fn get_t5_tokenizer(api: &Api) -> anyhow::Result<Tokenizer> {
let tokenizer_filename = api
.model("EricB/t5_tokenizer".to_string())
.get("t5-v1_1-xxl.tokenizer.json")?;
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(anyhow::Error::msg)?;
Ok(tokenizer)
}
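// Load the encoder-only T5-XXL model from memory-mapped safetensors shards.
// `offloaded` is forwarded to `T5EncoderModel::load`; it is assumed to enable
// weight offloading to reduce resident device memory.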
fn get_t5_model(
api: &Api,
dtype: DType,
device: &Device,
silent: bool,
offloaded: bool,
) -> candle_core::Result<T5EncoderModel> {
let repo = api.repo(hf_hub::Repo::with_revision(
"EricB/t5-v1_1-xxl-enc-only".to_string(),
hf_hub::RepoType::Model,
"main".to_string(),
));
let vb = from_mmaped_safetensors(
T5_XXL_SAFETENSOR_FILES
.iter()
.map(|f| repo.get(f))
.collect::<std::result::Result<Vec<_>, ApiError>>()
.map_err(candle_core::Error::msg)?,
vec![],
Some(dtype),
device,
silent,
None,
|_| true,
)?;
let config_filename = repo.get("config.json").map_err(candle_core::Error::msg)?;
let config = std::fs::read_to_string(config_filename)?;
let config: t5::Config = serde_json::from_str(&config).map_err(candle_core::Error::msg)?;
t5::T5EncoderModel::load(vb, &config, device, offloaded)
}
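// Load the text tower of openai/clip-vit-large-patch14 along with its
// tokenizer.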
fn get_clip_model_and_tokenizer(
api: &Api,
device: &Device,
silent: bool,
) -> anyhow::Result<(ClipTextTransformer, Tokenizer)> {
let repo = api.repo(hf_hub::Repo::model(
"openai/clip-vit-large-patch14".to_string(),
));
let model_file = repo.get("model.safetensors")?;
let vb = from_mmaped_safetensors(vec![model_file], vec![], None, device, silent, None, |_| {
true
})?;
let config_file = repo.get("config.json")?;
let config: ClipConfig = serde_json::from_reader(File::open(config_file)?)?;
let config = config.text_config;
let model = ClipTextTransformer::new(vb.pp("text_model"), &config)?;
let tokenizer_filename = repo.get("tokenizer.json")?;
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(anyhow::Error::msg)?;
Ok((model, tokenizer))
}
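// Tokenize a batch of prompts into a single `[batch, seq_len]` id tensor.
// `Tensor::new` over `Vec<Vec<u32>>` requires uniform row lengths, so this
// relies on the tokenizer padding every encoding in the batch to the same
// length.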
fn get_tokenization(tok: &Tokenizer, prompts: Vec<String>, device: &Device) -> Result<Tensor> {
Tensor::new(
tok.encode_batch(prompts, true)
.map_err(|e| candle_core::Error::Msg(e.to_string()))?
.into_iter()
.map(|e| e.get_ids().to_vec())
.collect::<Vec<_>>(),
device,
)
}
impl FluxStepper {
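    /// Builds the stepper. The tokenizers and the CLIP text encoder are
    /// downloaded from the Hugging Face Hub up front; the T5-XXL encoder is
    /// hot-loaded on each `forward` call instead. The FLUX transformer and
    /// VAE weights are supplied by the caller as `VarBuilder`s.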
pub fn new(
cfg: FluxStepperConfig,
(flux_vb, flux_cfg): (VarBuilder, &flux::model::Config),
(flux_ae_vb, flux_ae_cfg): (VarBuilder, &flux::autoencoder::Config),
dtype: DType,
device: &Device,
silent: bool,
offloaded: bool,
) -> anyhow::Result<Self> {
let api = Api::new()?;
info!("Loading T5 XXL tokenizer.");
let t5_tokenizer = get_t5_tokenizer(&api)?;
info!("Loading CLIP model and tokenizer.");
let (clip_encoder, clip_tokenizer) = get_clip_model_and_tokenizer(&api, device, silent)?;
Ok(Self {
cfg,
t5_tok: t5_tokenizer,
clip_tok: clip_tokenizer,
clip_text: clip_encoder,
flux_model: Flux::new(flux_cfg, flux_vb, device.clone(), offloaded)?,
flux_vae: AutoEncoder::new(flux_ae_cfg, flux_ae_vb)?,
is_guidance: cfg.is_guidance,
device: device.clone(),
dtype,
api,
silent,
offloaded,
})
}
}
impl DiffusionModel for FluxStepper {
fn forward(
&mut self,
prompts: Vec<String>,
params: DiffusionGenerationParams,
) -> Result<Tensor> {
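        // Tokenize with T5. Timestep-distilled (non-guidance) models cap the
        // prompt at 256 T5 tokens and pad shorter prompts up to exactly 256.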
let mut t5_input_ids = get_tokenization(&self.t5_tok, prompts.clone(), &self.device)?;
if !self.is_guidance {
match t5_input_ids.dim(1)?.cmp(&256) {
Ordering::Greater => {
candle_core::bail!("T5 embedding length greater than 256, please shrink the prompt or use the -dev (with guidance distillation) version.")
}
Ordering::Less | Ordering::Equal => {
t5_input_ids =
t5_input_ids.pad_with_zeros(D::Minus1, 0, 256 - t5_input_ids.dim(1)?)?;
}
}
}
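        // The T5-XXL encoder lives only for the duration of this block: it is
        // dropped as soon as the prompt embeddings are computed, keeping it
        // out of memory during sampling.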
let t5_embed = {
info!("Hotloading T5 XXL model.");
let mut t5_encoder = get_t5_model(
&self.api,
self.dtype,
&self.device,
self.silent,
self.offloaded,
)?;
t5_encoder.forward(&t5_input_ids)?
};
let clip_input_ids = get_tokenization(&self.clip_tok, prompts, &self.device)?;
let clip_embed = self
.clip_text
.forward(&clip_input_ids)?
.to_dtype(self.dtype)?;
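        // Sample the initial latent noise for the requested resolution, then
        // pack it together with the text embeddings into the sampling state.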
let img = flux::sampling::get_noise(
t5_embed.dim(0)?,
params.height,
params.width,
self.device(),
)?
.to_dtype(self.dtype)?;
let state = flux::sampling::State::new(&t5_embed, &clip_embed, &img)?;
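        // Build the timestep schedule; with a guidance config, it is shifted
        // according to the packed image sequence length and base/max shift.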
let timesteps = flux::sampling::get_schedule(
self.cfg.num_steps,
self.cfg
.guidance_config
.map(|s| (state.img.dims()[1], s.base_shift, s.max_shift)),
);
let img = if let Some(guidance_cfg) = &self.cfg.guidance_config {
flux::sampling::denoise(
&mut self.flux_model,
&state.img,
&state.img_ids,
&state.txt,
&state.txt_ids,
&state.vec,
                &timesteps,
guidance_cfg.guidance_scale,
)?
} else {
flux::sampling::denoise_no_guidance(
&mut self.flux_model,
&state.img,
&state.img_ids,
&state.txt,
&state.txt_ids,
&state.vec,
                &timesteps,
)?
};
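        // Unpack the packed latent sequence back into a spatial latent image,
        // then decode it to pixels with the VAE.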
let latent_img = flux::sampling::unpack(&img, params.height, params.width)?;
let img = self.flux_vae.decode(&latent_img)?;
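        // Map decoded pixels from [-1, 1] to [0, 255] u8.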
let normalized_img = ((img.clamp(-1f32, 1f32)? + 1.0)? * 127.5)?.to_dtype(DType::U8)?;
Ok(normalized_img)
}
fn device(&self) -> &Device {
&self.device
}
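    /// Maximum prompt length in T5 tokens: effectively unbounded for
    /// guidance-distilled models, 256 otherwise (matching the padding logic
    /// in `forward`).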
fn max_seq_len(&self) -> usize {
if self.is_guidance {
usize::MAX
} else {
256
}
}
}