// mistralrs_core/pipeline/loaders/diffusion_loaders.rs

1use std::{
2    fmt::Debug,
3    path::{Path, PathBuf},
4    str::FromStr,
5};
6
7use anyhow::{Context, Result};
8use candle_core::{Device, Tensor};
9
10use hf_hub::api::sync::ApiRepo;
11use mistralrs_quant::ShardedVarBuilder;
12#[cfg(feature = "pyo3_macros")]
13use pyo3::pyclass;
14
15use regex::Regex;
16use serde::Deserialize;
17
18use tracing::info;
19
20use super::{ModelPaths, NormalLoadingMetadata};
21use crate::{
22    api_dir_list, api_get_file,
23    diffusion_models::{
24        flux::{
25            self,
26            stepper::{FluxStepper, FluxStepperConfig},
27        },
28        DiffusionGenerationParams,
29    },
30    paged_attention::AttentionImplementation,
31    pipeline::paths::AdapterPaths,
32};
33
34pub trait DiffusionModel {
35    /// This returns a tensor of shape (bs, c, h, w), with values in [0, 255].
36    fn forward(
37        &mut self,
38        prompts: Vec<String>,
39        params: DiffusionGenerationParams,
40    ) -> candle_core::Result<Tensor>;
41    fn device(&self) -> &Device;
42    fn max_seq_len(&self) -> usize;
43}
44
45pub trait DiffusionModelLoader: Send + Sync {
46    /// If the model is being loaded with `load_model_from_hf` (so manual paths not provided), this will be called.
47    fn get_model_paths(&self, api: &ApiRepo, model_id: &Path) -> Result<Vec<PathBuf>>;
48    /// If the model is being loaded with `load_model_from_hf` (so manual paths not provided), this will be called.
49    fn get_config_filenames(&self, api: &ApiRepo, model_id: &Path) -> Result<Vec<PathBuf>>;
50    fn force_cpu_vb(&self) -> Vec<bool>;
51    // `configs` and `vbs` should be corresponding. It is up to the implementer to maintain this invaraint.
52    fn load(
53        &self,
54        configs: Vec<String>,
55        use_flash_attn: bool,
56        vbs: Vec<ShardedVarBuilder>,
57        normal_loading_metadata: NormalLoadingMetadata,
58        attention_mechanism: AttentionImplementation,
59        silent: bool,
60    ) -> Result<Box<dyn DiffusionModel + Send + Sync>>;
61}
62
63#[cfg_attr(feature = "pyo3_macros", pyclass(eq, eq_int))]
64#[derive(Clone, Debug, Deserialize, PartialEq)]
65/// The architecture to load the vision model as.
66pub enum DiffusionLoaderType {
67    #[serde(rename = "flux")]
68    Flux,
69    #[serde(rename = "flux-offloaded")]
70    FluxOffloaded,
71}
72
73impl FromStr for DiffusionLoaderType {
74    type Err = String;
75    fn from_str(s: &str) -> Result<Self, Self::Err> {
76        match s {
77            "flux" => Ok(Self::Flux),
78            "flux-offloaded" => Ok(Self::FluxOffloaded),
79            a => Err(format!(
80                "Unknown architecture `{a}`. Possible architectures: `flux`."
81            )),
82        }
83    }
84}
85
/// File locations needed to build a diffusion model.
#[derive(Debug, Clone)]
pub struct DiffusionModelPathsInner {
    /// Configuration files, presumably as returned by
    /// `DiffusionModelLoader::get_config_filenames` (NOTE(review): confirm at construction site).
    pub config_filenames: Vec<PathBuf>,
    /// Weight files, presumably as returned by `DiffusionModelLoader::get_model_paths`.
    pub filenames: Vec<PathBuf>,
}
91
92#[derive(Clone, Debug)]
93pub struct DiffusionModelPaths(pub DiffusionModelPathsInner);
94
95impl ModelPaths for DiffusionModelPaths {
96    fn get_config_filename(&self) -> &PathBuf {
97        unreachable!("Use `std::any::Any`.")
98    }
99    fn get_tokenizer_filename(&self) -> &PathBuf {
100        unreachable!("Use `std::any::Any`.")
101    }
102    fn get_weight_filenames(&self) -> &[PathBuf] {
103        unreachable!("Use `std::any::Any`.")
104    }
105    fn get_template_filename(&self) -> &Option<PathBuf> {
106        unreachable!("Use `std::any::Any`.")
107    }
108    fn get_gen_conf_filename(&self) -> Option<&PathBuf> {
109        unreachable!("Use `std::any::Any`.")
110    }
111    fn get_preprocessor_config(&self) -> &Option<PathBuf> {
112        unreachable!("Use `std::any::Any`.")
113    }
114    fn get_processor_config(&self) -> &Option<PathBuf> {
115        unreachable!("Use `std::any::Any`.")
116    }
117    fn get_chat_template_explicit(&self) -> &Option<PathBuf> {
118        unreachable!("Use `std::any::Any`.")
119    }
120    fn get_adapter_paths(&self) -> &AdapterPaths {
121        unreachable!("Use `std::any::Any`.")
122    }
123}
124
125// ======================== Flux loader
126
/// [`DiffusionLoader`] for a Flux Diffusion model.
///
/// When `offload` is `true`, the FLUX transformer's var builder (index 0) is reported
/// as CPU-forced by `force_cpu_vb`, and the flag is forwarded to `FluxStepper::new`.
///
/// [`DiffusionLoader`]: https://ericlbuehler.github.io/mistral.rs/mistralrs/struct.DiffusionLoader.html
pub struct FluxLoader {
    /// Whether to load the FLUX model in offloaded mode.
    pub(crate) offload: bool,
}
133
134impl DiffusionModelLoader for FluxLoader {
135    fn get_model_paths(&self, api: &ApiRepo, model_id: &Path) -> Result<Vec<PathBuf>> {
136        let regex = Regex::new(r"^flux\d+-(schnell|dev)\.safetensors$")?;
137        let flux_name = api_dir_list!(api, model_id)
138            .filter(|x| regex.is_match(x))
139            .nth(0)
140            .with_context(|| "Expected at least 1 .safetensors file matching the FLUX regex, please raise an issue.")?;
141        let flux_file = api_get_file!(api, &flux_name, model_id);
142        let ae_file = api_get_file!(api, "ae.safetensors", model_id);
143
144        // NOTE(EricLBuehler): disgusting way of doing this but the 0th path is the flux, 1 is ae
145        Ok(vec![flux_file, ae_file])
146    }
147    fn get_config_filenames(&self, api: &ApiRepo, model_id: &Path) -> Result<Vec<PathBuf>> {
148        let flux_file = api_get_file!(api, "transformer/config.json", model_id);
149        let ae_file = api_get_file!(api, "vae/config.json", model_id);
150
151        // NOTE(EricLBuehler): disgusting way of doing this but the 0th path is the flux, 1 is ae
152        Ok(vec![flux_file, ae_file])
153    }
154    fn force_cpu_vb(&self) -> Vec<bool> {
155        vec![self.offload, false]
156    }
157    fn load(
158        &self,
159        mut configs: Vec<String>,
160        _use_flash_attn: bool,
161        mut vbs: Vec<ShardedVarBuilder>,
162        normal_loading_metadata: NormalLoadingMetadata,
163        _attention_mechanism: AttentionImplementation,
164        silent: bool,
165    ) -> Result<Box<dyn DiffusionModel + Send + Sync>> {
166        let (vae_cfg, vae_vb) = (configs.remove(1), vbs.remove(1));
167        let (flux_cfg, flux_vb) = (configs.remove(0), vbs.remove(0));
168
169        let vae_cfg: flux::autoencoder::Config = serde_json::from_str(&vae_cfg)?;
170        let flux_cfg: flux::model::Config = serde_json::from_str(&flux_cfg)?;
171
172        let flux_dtype = flux_vb.dtype();
173        if flux_dtype != vae_vb.dtype() {
174            anyhow::bail!(
175                "Expected VAE and FLUX model VBs to be the same dtype, got {:?} and {flux_dtype:?}",
176                vae_vb.dtype()
177            );
178        }
179
180        Ok(Box::new(FluxStepper::new(
181            FluxStepperConfig::default_for_guidance(flux_cfg.guidance_embeds),
182            (flux_vb, &flux_cfg),
183            (vae_vb, &vae_cfg),
184            flux_dtype,
185            &normal_loading_metadata.real_device,
186            silent,
187            self.offload,
188        )?))
189    }
190}