use std::{cmp::Ordering, fs::File, sync::Arc};

use candle_core::{DType, Device, Result, Tensor, D};
use candle_nn::Module;
use hf_hub::api::sync::{Api, ApiError};
use mistralrs_quant::ShardedVarBuilder;
use tokenizers::Tokenizer;
use tracing::info;

use crate::{
    diffusion_models::{
        clip::text::{ClipConfig, ClipTextTransformer},
        flux,
        t5::{self, T5EncoderModel},
        DiffusionGenerationParams,
    },
    pipeline::DiffusionModel,
    utils::varbuilder_utils::{from_mmaped_safetensors, DeviceForLoadTensor},
};

use super::{autoencoder::AutoEncoder, model::Flux};

const T5_XXL_SAFETENSOR_FILES: &[&str] =
    &["t5_xxl-shard-0.safetensors", "t5_xxl-shard-1.safetensors"];

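/// Schedule-shift and guidance parameters applied when sampling with a
/// guidance-distilled FLUX checkpoint.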
#[derive(Clone, Copy, Debug)]
pub struct FluxStepperShift {
    pub base_shift: f64,
    pub max_shift: f64,
    pub guidance_scale: f64,
}

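/// Sampling configuration for [`FluxStepper`]: the number of denoising steps
/// plus optional shift/guidance parameters for guidance-distilled checkpoints.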
#[derive(Clone, Copy, Debug)]
pub struct FluxStepperConfig {
    pub num_steps: usize,
    pub guidance_config: Option<FluxStepperShift>,
    pub is_guidance: bool,
}

impl FluxStepperConfig {
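    /// Default configuration for a checkpoint: 50 steps with schedule shifting
    /// and a guidance scale of 4.0 when the model is guidance-distilled, and a
    /// fast 4-step schedule without guidance otherwise.
    ///
    /// A minimal sketch of the two defaults (the import path is omitted and
    /// depends on how the crate re-exports this type):
    /// ```ignore
    /// let with_guidance = FluxStepperConfig::default_for_guidance(true);
    /// assert_eq!(with_guidance.num_steps, 50);
    /// assert!(with_guidance.guidance_config.is_some());
    ///
    /// let without_guidance = FluxStepperConfig::default_for_guidance(false);
    /// assert_eq!(without_guidance.num_steps, 4);
    /// assert!(without_guidance.guidance_config.is_none());
    /// ```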
    pub fn default_for_guidance(has_guidance: bool) -> Self {
        if has_guidance {
            Self {
                num_steps: 50,
                guidance_config: Some(FluxStepperShift {
                    base_shift: 0.5,
                    max_shift: 1.15,
                    guidance_scale: 4.0,
                }),
                is_guidance: true,
            }
        } else {
            Self {
                num_steps: 4,
                guidance_config: None,
                is_guidance: false,
            }
        }
    }
}

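/// Diffusion pipeline for FLUX models: encodes prompts with T5 XXL and CLIP,
/// runs the denoising sampler, and decodes latents with the FLUX autoencoder.
/// The T5 encoder itself is hotloaded on each `forward` call.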
pub struct FluxStepper {
    cfg: FluxStepperConfig,
    t5_tok: Tokenizer,
    clip_tok: Tokenizer,
    clip_text: ClipTextTransformer,
    flux_model: Flux,
    flux_vae: AutoEncoder,
    is_guidance: bool,
    device: Device,
    dtype: DType,
    api: Api,
    silent: bool,
    offloaded: bool,
}

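/// Fetch the T5 XXL tokenizer from the Hugging Face Hub.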
fn get_t5_tokenizer(api: &Api) -> anyhow::Result<Tokenizer> {
    let tokenizer_filename = api
        .model("EricB/t5_tokenizer".to_string())
        .get("t5-v1_1-xxl.tokenizer.json")?;
    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(anyhow::Error::msg)?;

    Ok(tokenizer)
}

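/// Load the encoder-only T5 XXL model, memory-mapping its safetensors shards
/// onto `device` in the requested `dtype`.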
fn get_t5_model(
    api: &Api,
    dtype: DType,
    device: &Device,
    silent: bool,
    offloaded: bool,
) -> candle_core::Result<T5EncoderModel> {
    let repo = api.repo(hf_hub::Repo::with_revision(
        "EricB/t5-v1_1-xxl-enc-only".to_string(),
        hf_hub::RepoType::Model,
        "main".to_string(),
    ));

    let vb = from_mmaped_safetensors(
        T5_XXL_SAFETENSOR_FILES
            .iter()
            .map(|f| repo.get(f))
            .collect::<std::result::Result<Vec<_>, ApiError>>()
            .map_err(candle_core::Error::msg)?,
        vec![],
        Some(dtype),
        device,
        vec![None],
        silent,
        None,
        |_| true,
        Arc::new(|_| DeviceForLoadTensor::Base),
    )?;
    let config_filename = repo.get("config.json").map_err(candle_core::Error::msg)?;
    let config = std::fs::read_to_string(config_filename)?;
    let config: t5::Config = serde_json::from_str(&config).map_err(candle_core::Error::msg)?;

    t5::T5EncoderModel::load(vb, &config, device, offloaded)
}

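/// Load the CLIP text encoder and its tokenizer from the
/// `openai/clip-vit-large-patch14` Hugging Face repository.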
fn get_clip_model_and_tokenizer(
    api: &Api,
    device: &Device,
    silent: bool,
) -> anyhow::Result<(ClipTextTransformer, Tokenizer)> {
    let repo = api.repo(hf_hub::Repo::model(
        "openai/clip-vit-large-patch14".to_string(),
    ));

    let model_file = repo.get("model.safetensors")?;
    let vb = from_mmaped_safetensors(
        vec![model_file],
        vec![],
        None,
        device,
        vec![None],
        silent,
        None,
        |_| true,
        Arc::new(|_| DeviceForLoadTensor::Base),
    )?;
    let config_file = repo.get("config.json")?;
    let config: ClipConfig = serde_json::from_reader(File::open(config_file)?)?;
    let config = config.text_config;
    let model = ClipTextTransformer::new(vb.pp("text_model"), &config)?;

    let tokenizer_filename = repo.get("tokenizer.json")?;
    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(anyhow::Error::msg)?;

    Ok((model, tokenizer))
}

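/// Batch-tokenize `prompts` and collect the token IDs into a tensor on `device`.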
fn get_tokenization(tok: &Tokenizer, prompts: Vec<String>, device: &Device) -> Result<Tensor> {
    Tensor::new(
        tok.encode_batch(prompts, true)
            .map_err(|e| candle_core::Error::Msg(e.to_string()))?
            .into_iter()
            .map(|e| e.get_ids().to_vec())
            .collect::<Vec<_>>(),
        device,
    )
}

impl FluxStepper {
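    /// Build a `FluxStepper` from already-loaded FLUX transformer and
    /// autoencoder weights, fetching the T5 and CLIP tokenizers and the CLIP
    /// text encoder from the Hugging Face Hub. The T5 encoder is not loaded
    /// here; it is hotloaded on each `forward` call.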
    pub fn new(
        cfg: FluxStepperConfig,
        (flux_vb, flux_cfg): (ShardedVarBuilder, &flux::model::Config),
        (flux_ae_vb, flux_ae_cfg): (ShardedVarBuilder, &flux::autoencoder::Config),
        dtype: DType,
        device: &Device,
        silent: bool,
        offloaded: bool,
    ) -> anyhow::Result<Self> {
        let api = Api::new()?;

        info!("Loading T5 XXL tokenizer.");
        let t5_tokenizer = get_t5_tokenizer(&api)?;
        info!("Loading CLIP model and tokenizer.");
        let (clip_encoder, clip_tokenizer) = get_clip_model_and_tokenizer(&api, device, silent)?;

        Ok(Self {
            cfg,
            t5_tok: t5_tokenizer,
            clip_tok: clip_tokenizer,
            clip_text: clip_encoder,
            flux_model: Flux::new(flux_cfg, flux_vb, device.clone(), offloaded)?,
            flux_vae: AutoEncoder::new(flux_ae_cfg, flux_ae_vb)?,
            is_guidance: cfg.is_guidance,
            device: device.clone(),
            dtype,
            api,
            silent,
            offloaded,
        })
    }
}

impl DiffusionModel for FluxStepper {
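    /// Encode `prompts` with T5 and CLIP, sample Gaussian noise at the target
    /// resolution, run the denoising schedule, then unpack and decode the
    /// latents into a `u8` image tensor in `[0, 255]`.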
    fn forward(
        &mut self,
        prompts: Vec<String>,
        params: DiffusionGenerationParams,
    ) -> Result<Tensor> {
        let mut t5_input_ids = get_tokenization(&self.t5_tok, prompts.clone(), &self.device)?;
        if !self.is_guidance {
            match t5_input_ids.dim(1)?.cmp(&256) {
                Ordering::Greater => {
                    candle_core::bail!("T5 input length is greater than 256 tokens; shrink the prompt or use the -dev (guidance-distilled) version.")
                }
                Ordering::Less | Ordering::Equal => {
                    t5_input_ids =
                        t5_input_ids.pad_with_zeros(D::Minus1, 0, 256 - t5_input_ids.dim(1)?)?;
                }
            }
        }

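        // Hotload the T5 encoder in its own scope so the encoder weights are
        // dropped as soon as the prompt embeddings have been computed.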
        let t5_embed = {
            info!("Hotloading T5 XXL model.");
            let mut t5_encoder = get_t5_model(
                &self.api,
                self.dtype,
                &self.device,
                self.silent,
                self.offloaded,
            )?;
            t5_encoder.forward(&t5_input_ids)?
        };

        let clip_input_ids = get_tokenization(&self.clip_tok, prompts, &self.device)?;
        let clip_embed = self
            .clip_text
            .forward(&clip_input_ids)?
            .to_dtype(self.dtype)?;

        let img = flux::sampling::get_noise(
            t5_embed.dim(0)?,
            params.height,
            params.width,
            self.device(),
        )?
        .to_dtype(self.dtype)?;

        let state = flux::sampling::State::new(&t5_embed, &clip_embed, &img)?;
        let timesteps = flux::sampling::get_schedule(
            self.cfg.num_steps,
            self.cfg
                .guidance_config
                .map(|s| (state.img.dims()[1], s.base_shift, s.max_shift)),
        );

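        // Guidance-distilled checkpoints take an explicit guidance scale;
        // otherwise the schedule is run without guidance.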
        let img = if let Some(guidance_cfg) = &self.cfg.guidance_config {
            flux::sampling::denoise(
                &mut self.flux_model,
                &state.img,
                &state.img_ids,
                &state.txt,
                &state.txt_ids,
                &state.vec,
                &timesteps,
                guidance_cfg.guidance_scale,
            )?
        } else {
            flux::sampling::denoise_no_guidance(
                &mut self.flux_model,
                &state.img,
                &state.img_ids,
                &state.txt,
                &state.txt_ids,
                &state.vec,
                &timesteps,
            )?
        };

        let latent_img = flux::sampling::unpack(&img, params.height, params.width)?;

        let img = self.flux_vae.decode(&latent_img)?;

        let normalized_img = ((img.clamp(-1f32, 1f32)? + 1.0)? * 127.5)?.to_dtype(DType::U8)?;

        Ok(normalized_img)
    }

    fn device(&self) -> &Device {
        &self.device
    }

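    /// Maximum prompt length in tokens: effectively unlimited when guidance is
    /// available, otherwise capped at the 256 tokens the T5 inputs are padded to.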
    fn max_seq_len(&self) -> usize {
        if self.is_guidance {
            usize::MAX
        } else {
            256
        }
    }
}