mistralrs_core/diffusion_models/flux/
stepper.rs

1use std::{cmp::Ordering, fs::File, sync::Arc};
2
3use candle_core::{DType, Device, Result, Tensor, D};
4use candle_nn::Module;
5use hf_hub::api::sync::{Api, ApiError};
6use mistralrs_quant::ShardedVarBuilder;
7use tokenizers::Tokenizer;
8use tracing::info;
9
10use crate::{
11    diffusion_models::{
12        clip::text::{ClipConfig, ClipTextTransformer},
13        flux,
14        t5::{self, T5EncoderModel},
15        DiffusionGenerationParams,
16    },
17    pipeline::DiffusionModel,
18    utils::varbuilder_utils::{from_mmaped_safetensors, DeviceForLoadTensor},
19};
20
21use super::{autoencoder::AutoEncoder, model::Flux};
22
/// Weight shards of the encoder-only T5 XXL model hotloaded in `get_t5_model`.
const T5_XXL_SAFETENSOR_FILES: &[&str] =
    &["t5_xxl-shard-0.safetensors", "t5_xxl-shard-1.safetensors"];
25
/// Schedule-shift and guidance parameters used when sampling a
/// guidance-distilled FLUX model.
#[derive(Clone, Copy, Debug)]
pub struct FluxStepperShift {
    /// Base shift fed into the timestep schedule (`flux::sampling::get_schedule`).
    pub base_shift: f64,
    /// Maximum shift fed into the timestep schedule.
    pub max_shift: f64,
    /// Guidance scale forwarded to `flux::sampling::denoise`.
    pub guidance_scale: f64,
}
32
/// Sampling configuration for [`FluxStepper`].
#[derive(Clone, Copy, Debug)]
pub struct FluxStepperConfig {
    /// Number of denoising steps to run.
    pub num_steps: usize,
    /// When `Some`, sampling uses a shifted schedule and guided denoising.
    pub guidance_config: Option<FluxStepperShift>,
    /// Whether the model variant is guidance-distilled; controls the fixed
    /// 256-token T5 prompt handling and the reported `max_seq_len`.
    pub is_guidance: bool,
}
39
40impl FluxStepperConfig {
41    pub fn default_for_guidance(has_guidance: bool) -> Self {
42        if has_guidance {
43            Self {
44                num_steps: 50,
45                guidance_config: Some(FluxStepperShift {
46                    base_shift: 0.5,
47                    max_shift: 1.15,
48                    guidance_scale: 4.0,
49                }),
50                is_guidance: true,
51            }
52        } else {
53            Self {
54                num_steps: 4,
55                guidance_config: None,
56                is_guidance: false,
57            }
58        }
59    }
60}
61
/// Orchestrates FLUX image generation: prompt tokenization, T5 XXL and CLIP
/// text encoding, FLUX denoising, and VAE decoding.
pub struct FluxStepper {
    // Sampling configuration (step count, optional guidance shift).
    cfg: FluxStepperConfig,
    // Tokenizer for the T5 XXL encoder (the encoder itself is hotloaded per call).
    t5_tok: Tokenizer,
    // Tokenizer for the CLIP text encoder.
    clip_tok: Tokenizer,
    // CLIP text encoder, kept resident across forward calls.
    clip_text: ClipTextTransformer,
    // FLUX denoising transformer.
    flux_model: Flux,
    // Autoencoder used to decode latents into images.
    flux_vae: AutoEncoder,
    // Mirrors `cfg.is_guidance`; gates the fixed-length T5 prompt handling.
    is_guidance: bool,
    // Device all tensors are created on.
    device: Device,
    // Working dtype for embeddings, noise, and latents.
    dtype: DType,
    // Hub API handle used to hotload the T5 encoder in `forward`.
    api: Api,
    // Passed through to weight loading (presumably suppresses progress output).
    silent: bool,
    // Whether the hotloaded T5 encoder is constructed in offloaded mode.
    offloaded: bool,
}
76
77fn get_t5_tokenizer(api: &Api) -> anyhow::Result<Tokenizer> {
78    let tokenizer_filename = api
79        .model("EricB/t5_tokenizer".to_string())
80        .get("t5-v1_1-xxl.tokenizer.json")?;
81    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(anyhow::Error::msg)?;
82
83    Ok(tokenizer)
84}
85
86fn get_t5_model(
87    api: &Api,
88    dtype: DType,
89    device: &Device,
90    silent: bool,
91    offloaded: bool,
92) -> candle_core::Result<T5EncoderModel> {
93    let repo = api.repo(hf_hub::Repo::with_revision(
94        "EricB/t5-v1_1-xxl-enc-only".to_string(),
95        hf_hub::RepoType::Model,
96        "main".to_string(),
97    ));
98
99    let vb = from_mmaped_safetensors(
100        T5_XXL_SAFETENSOR_FILES
101            .iter()
102            .map(|f| repo.get(f))
103            .collect::<std::result::Result<Vec<_>, ApiError>>()
104            .map_err(candle_core::Error::msg)?,
105        vec![],
106        Some(dtype),
107        device,
108        vec![None],
109        silent,
110        None,
111        |_| true,
112        Arc::new(|_| DeviceForLoadTensor::Base),
113    )?;
114    let config_filename = repo.get("config.json").map_err(candle_core::Error::msg)?;
115    let config = std::fs::read_to_string(config_filename)?;
116    let config: t5::Config = serde_json::from_str(&config).map_err(candle_core::Error::msg)?;
117
118    t5::T5EncoderModel::load(vb, &config, device, offloaded)
119}
120
121fn get_clip_model_and_tokenizer(
122    api: &Api,
123    device: &Device,
124    silent: bool,
125) -> anyhow::Result<(ClipTextTransformer, Tokenizer)> {
126    let repo = api.repo(hf_hub::Repo::model(
127        "openai/clip-vit-large-patch14".to_string(),
128    ));
129
130    let model_file = repo.get("model.safetensors")?;
131    let vb = from_mmaped_safetensors(
132        vec![model_file],
133        vec![],
134        None,
135        device,
136        vec![None],
137        silent,
138        None,
139        |_| true,
140        Arc::new(|_| DeviceForLoadTensor::Base),
141    )?;
142    let config_file = repo.get("config.json")?;
143    let config: ClipConfig = serde_json::from_reader(File::open(config_file)?)?;
144    let config = config.text_config;
145    let model = ClipTextTransformer::new(vb.pp("text_model"), &config)?;
146
147    let tokenizer_filename = repo.get("tokenizer.json")?;
148    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(anyhow::Error::msg)?;
149
150    Ok((model, tokenizer))
151}
152
153fn get_tokenization(tok: &Tokenizer, prompts: Vec<String>, device: &Device) -> Result<Tensor> {
154    Tensor::new(
155        tok.encode_batch(prompts, true)
156            .map_err(|e| candle_core::Error::Msg(e.to_string()))?
157            .into_iter()
158            .map(|e| e.get_ids().to_vec())
159            .collect::<Vec<_>>(),
160        device,
161    )
162}
163
164impl FluxStepper {
165    pub fn new(
166        cfg: FluxStepperConfig,
167        (flux_vb, flux_cfg): (ShardedVarBuilder, &flux::model::Config),
168        (flux_ae_vb, flux_ae_cfg): (ShardedVarBuilder, &flux::autoencoder::Config),
169        dtype: DType,
170        device: &Device,
171        silent: bool,
172        offloaded: bool,
173    ) -> anyhow::Result<Self> {
174        let api = Api::new()?;
175
176        info!("Loading T5 XXL tokenizer.");
177        let t5_tokenizer = get_t5_tokenizer(&api)?;
178        info!("Loading CLIP model and tokenizer.");
179        let (clip_encoder, clip_tokenizer) = get_clip_model_and_tokenizer(&api, device, silent)?;
180
181        Ok(Self {
182            cfg,
183            t5_tok: t5_tokenizer,
184            clip_tok: clip_tokenizer,
185            clip_text: clip_encoder,
186            flux_model: Flux::new(flux_cfg, flux_vb, device.clone(), offloaded)?,
187            flux_vae: AutoEncoder::new(flux_ae_cfg, flux_ae_vb)?,
188            is_guidance: cfg.is_guidance,
189            device: device.clone(),
190            dtype,
191            api,
192            silent,
193            offloaded,
194        })
195    }
196}
197
impl DiffusionModel for FluxStepper {
    /// Generates images for `prompts`: encodes them with T5 (hotloaded per
    /// call) and CLIP, samples latent noise, denoises it with the FLUX
    /// transformer, and decodes the result with the VAE into a `U8` tensor
    /// with values in `[0, 255]`.
    fn forward(
        &mut self,
        prompts: Vec<String>,
        params: DiffusionGenerationParams,
    ) -> Result<Tensor> {
        let mut t5_input_ids = get_tokenization(&self.t5_tok, prompts.clone(), &self.device)?;
        if !self.is_guidance {
            // Non-guidance variants require a fixed T5 length of 256 tokens:
            // reject longer prompts, zero-pad shorter ones up to 256.
            match t5_input_ids.dim(1)?.cmp(&256) {
                Ordering::Greater => {
                    candle_core::bail!("T5 embedding length greater than 256, please shrink the prompt or use the -dev (with guidance distillation) version.")
                }
                Ordering::Less | Ordering::Equal => {
                    t5_input_ids =
                        t5_input_ids.pad_with_zeros(D::Minus1, 0, 256 - t5_input_ids.dim(1)?)?;
                }
            }
        }

        // Hotload the T5 XXL encoder for this call only; the block scope drops
        // it immediately after producing the embedding.
        let t5_embed = {
            info!("Hotloading T5 XXL model.");
            let mut t5_encoder = get_t5_model(
                &self.api,
                self.dtype,
                &self.device,
                self.silent,
                self.offloaded,
            )?;
            t5_encoder.forward(&t5_input_ids)?
        };

        // CLIP text embedding, cast to the working dtype.
        let clip_input_ids = get_tokenization(&self.clip_tok, prompts, &self.device)?;
        let clip_embed = self
            .clip_text
            .forward(&clip_input_ids)?
            .to_dtype(self.dtype)?;

        // Initial noise sized by the requested output dimensions; the batch
        // size follows the T5 embedding's batch dimension.
        let img = flux::sampling::get_noise(
            t5_embed.dim(0)?,
            params.height,
            params.width,
            self.device(),
        )?
        .to_dtype(self.dtype)?;

        let state = flux::sampling::State::new(&t5_embed, &clip_embed, &img)?;
        // Timestep schedule; with a guidance config the schedule is shifted
        // using `state.img`'s second dimension and the configured shifts.
        let timesteps = flux::sampling::get_schedule(
            self.cfg.num_steps,
            self.cfg
                .guidance_config
                .map(|s| (state.img.dims()[1], s.base_shift, s.max_shift)),
        );

        // Guided vs. unguided denoising, selected by the presence of a
        // guidance config.
        let img = if let Some(guidance_cfg) = &self.cfg.guidance_config {
            flux::sampling::denoise(
                &mut self.flux_model,
                &state.img,
                &state.img_ids,
                &state.txt,
                &state.txt_ids,
                &state.vec,
                &timesteps,
                guidance_cfg.guidance_scale,
            )?
        } else {
            flux::sampling::denoise_no_guidance(
                &mut self.flux_model,
                &state.img,
                &state.img_ids,
                &state.txt,
                &state.txt_ids,
                &state.vec,
                &timesteps,
            )?
        };

        // Unpack the denoised sequence back into latent layout, then decode.
        let latent_img = flux::sampling::unpack(&img, params.height, params.width)?;

        let img = self.flux_vae.decode(&latent_img)?;

        // Map decoder output from [-1, 1] to [0, 255] and cast to U8.
        let normalized_img = ((img.clamp(-1f32, 1f32)? + 1.0)? * 127.5)?.to_dtype(DType::U8)?;

        Ok(normalized_img)
    }

    fn device(&self) -> &Device {
        &self.device
    }

    /// Maximum prompt length: unbounded for guidance models, otherwise the
    /// fixed 256-token T5 length enforced in `forward`.
    fn max_seq_len(&self) -> usize {
        if self.is_guidance {
            usize::MAX
        } else {
            256
        }
    }
}