mistralrs_core/pipeline/loaders/
diffusion_loaders.rs1use std::{
2 fmt::Debug,
3 path::{Path, PathBuf},
4 str::FromStr,
5};
6
7use anyhow::{Context, Result};
8use candle_core::{Device, Tensor};
9
10use hf_hub::api::sync::ApiRepo;
11use mistralrs_quant::ShardedVarBuilder;
12#[cfg(feature = "pyo3_macros")]
13use pyo3::pyclass;
14
15use regex::Regex;
16use serde::Deserialize;
17
18use tracing::info;
19
20use super::{ModelPaths, NormalLoadingMetadata};
21use crate::{
22 api_dir_list, api_get_file,
23 diffusion_models::{
24 flux::{
25 self,
26 stepper::{FluxStepper, FluxStepperConfig},
27 },
28 DiffusionGenerationParams,
29 },
30 paged_attention::AttentionImplementation,
31 pipeline::{paths::AdapterPaths, EmbeddingModulePaths},
32};
33
34pub trait DiffusionModel {
35 fn forward(
37 &mut self,
38 prompts: Vec<String>,
39 params: DiffusionGenerationParams,
40 ) -> candle_core::Result<Tensor>;
41 fn device(&self) -> &Device;
42 fn max_seq_len(&self) -> usize;
43}
44
45pub trait DiffusionModelLoader: Send + Sync {
46 fn get_model_paths(&self, api: &ApiRepo, model_id: &Path) -> Result<Vec<PathBuf>>;
48 fn get_config_filenames(&self, api: &ApiRepo, model_id: &Path) -> Result<Vec<PathBuf>>;
50 fn force_cpu_vb(&self) -> Vec<bool>;
51 fn load(
53 &self,
54 configs: Vec<String>,
55 vbs: Vec<ShardedVarBuilder>,
56 normal_loading_metadata: NormalLoadingMetadata,
57 attention_mechanism: AttentionImplementation,
58 silent: bool,
59 ) -> Result<Box<dyn DiffusionModel + Send + Sync>>;
60}
61
62#[cfg_attr(feature = "pyo3_macros", pyclass(eq, eq_int))]
63#[derive(Clone, Debug, Deserialize, PartialEq)]
64pub enum DiffusionLoaderType {
66 #[serde(rename = "flux")]
67 Flux,
68 #[serde(rename = "flux-offloaded")]
69 FluxOffloaded,
70}
71
72impl FromStr for DiffusionLoaderType {
73 type Err = String;
74 fn from_str(s: &str) -> Result<Self, Self::Err> {
75 match s {
76 "flux" => Ok(Self::Flux),
77 "flux-offloaded" => Ok(Self::FluxOffloaded),
78 a => Err(format!(
79 "Unknown architecture `{a}`. Possible architectures: `flux`."
80 )),
81 }
82 }
83}
84
/// Resolved file locations for a diffusion model.
#[derive(Debug, Clone)]
pub struct DiffusionModelPathsInner {
    /// Paths of the configuration files.
    pub config_filenames: Vec<PathBuf>,
    /// Paths of the weight files.
    pub filenames: Vec<PathBuf>,
}
90
91#[derive(Clone, Debug)]
92pub struct DiffusionModelPaths(pub DiffusionModelPathsInner);
93
94impl ModelPaths for DiffusionModelPaths {
95 fn get_config_filename(&self) -> &PathBuf {
96 unreachable!("Use `std::any::Any`.")
97 }
98 fn get_tokenizer_filename(&self) -> &PathBuf {
99 unreachable!("Use `std::any::Any`.")
100 }
101 fn get_weight_filenames(&self) -> &[PathBuf] {
102 unreachable!("Use `std::any::Any`.")
103 }
104 fn get_template_filename(&self) -> &Option<PathBuf> {
105 unreachable!("Use `std::any::Any`.")
106 }
107 fn get_gen_conf_filename(&self) -> Option<&PathBuf> {
108 unreachable!("Use `std::any::Any`.")
109 }
110 fn get_preprocessor_config(&self) -> &Option<PathBuf> {
111 unreachable!("Use `std::any::Any`.")
112 }
113 fn get_processor_config(&self) -> &Option<PathBuf> {
114 unreachable!("Use `std::any::Any`.")
115 }
116 fn get_chat_template_explicit(&self) -> &Option<PathBuf> {
117 unreachable!("Use `std::any::Any`.")
118 }
119 fn get_adapter_paths(&self) -> &AdapterPaths {
120 unreachable!("Use `std::any::Any`.")
121 }
122 fn get_modules(&self) -> Option<&[EmbeddingModulePaths]> {
123 unreachable!("Use `std::any::Any`.")
124 }
125}
126
/// Loader for FLUX diffusion models; see its `DiffusionModelLoader`
/// implementation for the behavior.
pub struct FluxLoader {
    /// Reported as the first entry of `force_cpu_vb` and passed on to the
    /// stepper when the model is built.
    pub(crate) offload: bool,
}
135
136impl DiffusionModelLoader for FluxLoader {
137 fn get_model_paths(&self, api: &ApiRepo, model_id: &Path) -> Result<Vec<PathBuf>> {
138 let regex = Regex::new(r"^flux\d+-(schnell|dev)\.safetensors$")?;
139 let flux_name = api_dir_list!(api, model_id, true)
140 .filter(|x| regex.is_match(x))
141 .nth(0)
142 .with_context(|| "Expected at least 1 .safetensors file matching the FLUX regex, please raise an issue.")?;
143 let flux_file = api_get_file!(api, &flux_name, model_id);
144 let ae_file = api_get_file!(api, "ae.safetensors", model_id);
145
146 Ok(vec![flux_file, ae_file])
148 }
149 fn get_config_filenames(&self, api: &ApiRepo, model_id: &Path) -> Result<Vec<PathBuf>> {
150 let flux_file = api_get_file!(api, "transformer/config.json", model_id);
151 let ae_file = api_get_file!(api, "vae/config.json", model_id);
152
153 Ok(vec![flux_file, ae_file])
155 }
156 fn force_cpu_vb(&self) -> Vec<bool> {
157 vec![self.offload, false]
158 }
159 fn load(
160 &self,
161 mut configs: Vec<String>,
162 mut vbs: Vec<ShardedVarBuilder>,
163 normal_loading_metadata: NormalLoadingMetadata,
164 _attention_mechanism: AttentionImplementation,
165 silent: bool,
166 ) -> Result<Box<dyn DiffusionModel + Send + Sync>> {
167 let (vae_cfg, vae_vb) = (configs.remove(1), vbs.remove(1));
168 let (flux_cfg, flux_vb) = (configs.remove(0), vbs.remove(0));
169
170 let vae_cfg: flux::autoencoder::Config = serde_json::from_str(&vae_cfg)?;
171 let flux_cfg: flux::model::Config = serde_json::from_str(&flux_cfg)?;
172
173 let flux_dtype = flux_vb.dtype();
174 if flux_dtype != vae_vb.dtype() {
175 anyhow::bail!(
176 "Expected VAE and FLUX model VBs to be the same dtype, got {:?} and {flux_dtype:?}",
177 vae_vb.dtype()
178 );
179 }
180
181 Ok(Box::new(FluxStepper::new(
182 FluxStepperConfig::default_for_guidance(flux_cfg.guidance_embeds),
183 (flux_vb, &flux_cfg),
184 (vae_vb, &vae_cfg),
185 flux_dtype,
186 &normal_loading_metadata.real_device,
187 silent,
188 self.offload,
189 )?))
190 }
191}