1use super::loaders::{DiffusionModelPaths, DiffusionModelPathsInner};
2use super::{
3 AnyMoePipelineMixin, Cache, CacheManagerMixin, DiffusionLoaderType, DiffusionModel,
4 DiffusionModelLoader, EitherCache, FluxLoader, ForwardInputsResult, GeneralMetadata,
5 IsqPipelineMixin, Loader, MetadataMixin, ModelCategory, ModelKind, ModelPaths,
6 PreProcessingMixin, Processor, TokenSource,
7};
8use crate::device_map::DeviceMapper;
9use crate::diffusion_models::processor::{DiffusionProcessor, ModelInputs};
10use crate::paged_attention::AttentionImplementation;
11use crate::pipeline::ChatTemplate;
12use crate::prefix_cacher::PrefixCacheManagerV2;
13use crate::sequence::Sequence;
14use crate::utils::varbuilder_utils::DeviceForLoadTensor;
15use crate::utils::{tokens::get_token, varbuilder_utils::from_mmaped_safetensors};
16use crate::{DeviceMapSetting, PagedAttentionConfig, Pipeline, TryIntoDType};
17use anyhow::Result;
18use candle_core::{DType, Device, Tensor};
19use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
20use image::{DynamicImage, RgbImage};
21use indicatif::MultiProgress;
22use mistralrs_quant::IsqType;
23use rand_isaac::Isaac64Rng;
24use std::any::Any;
25use std::io;
26use std::sync::Arc;
27use tokenizers::Tokenizer;
28use tokio::sync::Mutex;
29use tracing::warn;
30
/// A loaded, ready-to-run diffusion (image generation) pipeline.
///
/// Wraps the underlying diffusion model together with the metadata the
/// scheduler expects; diffusion models have no KV cache, so `dummy_cache`
/// exists only to satisfy the `CacheManagerMixin` interface.
pub struct DiffusionPipeline {
    // The concrete diffusion model (e.g. FLUX), type-erased.
    model: Box<dyn DiffusionModel + Send + Sync>,
    // Hugging Face model ID or local path this pipeline was loaded from.
    model_id: String,
    // Shared scheduler-facing metadata (see `load_model_from_path`).
    metadata: Arc<GeneralMetadata>,
    // Placeholder cache; diffusion models do not use a KV cache.
    dummy_cache: EitherCache,
}
37
/// A `Loader` implementation for diffusion models.
///
/// Built via [`DiffusionLoaderBuilder`]; delegates architecture-specific
/// path discovery and model construction to `inner`.
pub struct DiffusionLoader {
    // Architecture-specific loader (e.g. FLUX), type-erased.
    inner: Box<dyn DiffusionModelLoader>,
    // Hugging Face model ID or local path to load from.
    model_id: String,
    config: DiffusionSpecificConfig,
    kind: ModelKind,
}
45
/// Builder for [`DiffusionLoader`].
///
/// `model_id` must be set (via [`DiffusionLoaderBuilder::new`]) before
/// `build` is called; `build` panics otherwise.
#[derive(Default)]
pub struct DiffusionLoaderBuilder {
    model_id: Option<String>,
    config: DiffusionSpecificConfig,
    kind: ModelKind,
}
53
/// Configuration options specific to diffusion model loading.
#[derive(Clone, Default)]
pub struct DiffusionSpecificConfig {
    // Whether to enable flash attention in the loaded model.
    pub use_flash_attn: bool,
}
59
60impl DiffusionLoaderBuilder {
61 pub fn new(config: DiffusionSpecificConfig, model_id: Option<String>) -> Self {
62 Self {
63 config,
64 model_id,
65 kind: ModelKind::Normal,
66 }
67 }
68
69 pub fn build(self, loader: DiffusionLoaderType) -> Box<dyn Loader> {
70 let loader: Box<dyn DiffusionModelLoader> = match loader {
71 DiffusionLoaderType::Flux => Box::new(FluxLoader { offload: false }),
72 DiffusionLoaderType::FluxOffloaded => Box::new(FluxLoader { offload: true }),
73 };
74 Box::new(DiffusionLoader {
75 inner: loader,
76 model_id: self.model_id.unwrap(),
77 config: self.config,
78 kind: self.kind,
79 })
80 }
81}
82
impl Loader for DiffusionLoader {
    /// Resolve model and config file paths from the Hugging Face Hub (or a
    /// local path) and then delegate to [`Self::load_model_from_path`].
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let paths: anyhow::Result<Box<dyn ModelPaths>> = {
            let api = ApiBuilder::new()
                .with_progress(!silent)
                .with_token(get_token(&token_source)?)
                .build()?;
            // Default to the `main` revision when none is given.
            let revision = revision.unwrap_or("main".to_string());
            let api = api.repo(Repo::with_revision(
                self.model_id.clone(),
                RepoType::Model,
                revision.clone(),
            ));
            let model_id = std::path::Path::new(&self.model_id);
            // The architecture-specific loader decides which weight and
            // config files this model needs.
            let filenames = self.inner.get_model_paths(&api, model_id)?;
            let config_filenames = self.inner.get_config_filenames(&api, model_id)?;
            Ok(Box::new(DiffusionModelPaths(DiffusionModelPathsInner {
                config_filenames,
                filenames,
            })))
        };
        self.load_model_from_path(
            &paths?,
            dtype,
            device,
            silent,
            mapper,
            in_situ_quant,
            paged_attn_config,
        )
    }

    /// Load the diffusion model from already-resolved file paths and wrap it
    /// in a [`DiffusionPipeline`].
    ///
    /// Device mapping and ISQ are rejected with an error; PagedAttention is
    /// silently disabled with a warning — none of these apply to diffusion
    /// models.
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        mut paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        // `paths` is type-erased; this loader only ever receives the paths it
        // produced itself, so the downcast failing is a bug.
        let paths = &paths
            .as_ref()
            .as_any()
            .downcast_ref::<DiffusionModelPaths>()
            .expect("Path downcast failed.")
            .0;

        if matches!(mapper, DeviceMapSetting::Map(_)) {
            anyhow::bail!("Device mapping is not supported for diffusion models.")
        }

        if in_situ_quant.is_some() {
            anyhow::bail!("ISQ is not supported for Diffusion models.");
        }

        if paged_attn_config.is_some() {
            warn!("PagedAttention is not supported for Diffusion models, disabling it.");

            paged_attn_config = None;
        }

        // Read every config file up front; any I/O error aborts the load.
        let configs = paths
            .config_filenames
            .iter()
            .map(std::fs::read_to_string)
            .collect::<io::Result<Vec<_>>>()?;

        // Dummy mapper: diffusion models run on a single device.
        let mapper = DeviceMapSetting::dummy().into_mapper(usize::MAX, device, None)?;
        let dtype = mapper.get_min_dtype(dtype)?;

        // `paged_attn_config` was forced to None above, so this is always
        // Eager today; kept for symmetry with other pipelines.
        let attention_mechanism = if paged_attn_config.is_some() {
            AttentionImplementation::PagedAttention
        } else {
            AttentionImplementation::Eager
        };

        let model = match self.kind {
            ModelKind::Normal => {
                // One VarBuilder per weight file; the inner loader may force
                // some components (e.g. offloaded ones) onto the CPU.
                let vbs = paths
                    .filenames
                    .iter()
                    .zip(self.inner.force_cpu_vb())
                    .map(|(path, force_cpu)| {
                        let dev = if force_cpu { &Device::Cpu } else { device };
                        from_mmaped_safetensors(
                            vec![path.clone()],
                            Vec::new(),
                            Some(dtype),
                            dev,
                            vec![None],
                            silent,
                            None,
                            |_| true,
                            Arc::new(|_| DeviceForLoadTensor::Base),
                        )
                    })
                    .collect::<candle_core::Result<Vec<_>>>()?;

                self.inner.load(
                    configs,
                    self.config.use_flash_attn,
                    vbs,
                    crate::pipeline::NormalLoadingMetadata {
                        mapper,
                        loading_isq: false,
                        real_device: device.clone(),
                        multi_progress: Arc::new(MultiProgress::new()),
                    },
                    attention_mechanism,
                    silent,
                )?
            }
            // Only `ModelKind::Normal` is ever constructed for diffusion
            // loaders (see `DiffusionLoaderBuilder::new`).
            _ => unreachable!(),
        };

        let max_seq_len = model.max_seq_len();
        Ok(Arc::new(Mutex::new(DiffusionPipeline {
            model,
            model_id: self.model_id.clone(),
            // Mostly placeholder metadata: diffusion models have no
            // tokenizer, no KV cache, and no EOS tokens.
            metadata: Arc::new(GeneralMetadata {
                max_seq_len,
                tok_env: None,
                is_xlora: false,
                no_prefix_cache: false,
                num_hidden_layers: 1,
                eos_tok: vec![],
                kind: self.kind.clone(),
                no_kv_cache: true,
                activation_dtype: dtype,
                sliding_window: None,
                cache_config: None,
                cache_engine: None,
                prompt_chunksize: None,
                model_metadata: None,
            }),
            dummy_cache: EitherCache::Full(Cache::new(0, false)),
        })))
    }

    fn get_id(&self) -> String {
        self.model_id.to_string()
    }

    fn get_kind(&self) -> ModelKind {
        self.kind.clone()
    }
}
244
impl PreProcessingMixin for DiffusionPipeline {
    /// Diffusion inputs are prompts/params, handled by `DiffusionProcessor`.
    fn get_processor(&self) -> Arc<dyn Processor> {
        Arc::new(DiffusionProcessor)
    }
    // No chat template: diffusion models do not do chat-style generation.
    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
        None
    }
    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
        None
    }
}
256
impl IsqPipelineMixin for DiffusionPipeline {
    /// In-situ quantization is rejected for diffusion models (see also the
    /// matching check in `load_model_from_path`).
    fn re_isq_model(&mut self, _dtype: IsqType) -> Result<()> {
        anyhow::bail!("Diffusion models do not support ISQ for now.")
    }
}
262
// Diffusion models have no KV cache, so every cache operation is a no-op and
// `cache` returns the placeholder created at load time.
impl CacheManagerMixin for DiffusionPipeline {
    fn clone_in_cache(&self, _seqs: &mut [&mut Sequence]) {}
    fn clone_out_cache(&self, _seqs: &mut [&mut Sequence]) {}
    fn set_none_cache(
        &self,
        _seqs: &mut [&mut Sequence],
        _reset_non_granular: bool,
        _modify_draft_cache: bool,
        _load_preallocated_cache: bool,
    ) {
    }
    fn cache(&self) -> &EitherCache {
        &self.dummy_cache
    }
}
278
279impl MetadataMixin for DiffusionPipeline {
280 fn device(&self) -> Device {
281 self.model.device().clone()
282 }
283 fn get_metadata(&self) -> Arc<GeneralMetadata> {
284 self.metadata.clone()
285 }
286 fn name(&self) -> String {
287 self.model_id.clone()
288 }
289 fn reset_non_granular_state(&self) {}
290 fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
291 None
292 }
293 fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
294 None
295 }
296}
297
298#[async_trait::async_trait]
299impl Pipeline for DiffusionPipeline {
300 fn forward_inputs(
301 &mut self,
302 inputs: Box<dyn Any>,
303 return_raw_logits: bool,
304 ) -> candle_core::Result<ForwardInputsResult> {
305 assert!(!return_raw_logits);
306
307 let ModelInputs { prompts, params } = *inputs.downcast().expect("Downcast failed.");
308 let img = self.model.forward(prompts, params)?.to_dtype(DType::U8)?;
309 let (_b, c, h, w) = img.dims4()?;
310 let mut images = Vec::new();
311 for b_img in img.chunk(img.dim(0)?, 0)? {
312 let flattened = b_img.squeeze(0)?.permute((1, 2, 0))?.flatten_all()?;
313 if c != 3 {
314 candle_core::bail!("Expected 3 channels in image output");
315 }
316 #[allow(clippy::cast_possible_truncation)]
317 images.push(DynamicImage::ImageRgb8(
318 RgbImage::from_raw(w as u32, h as u32, flattened.to_vec1::<u8>()?).ok_or(
319 candle_core::Error::Msg("RgbImage has invalid capacity.".to_string()),
320 )?,
321 ));
322 }
323 Ok(ForwardInputsResult::Image { images })
324 }
325 async fn sample_causal_gen(
326 &self,
327 _seqs: &mut [&mut Sequence],
328 _logits: Vec<Tensor>,
329 _prefix_cacher: &mut PrefixCacheManagerV2,
330 _disable_eos_stop: bool,
331 _srng: Arc<std::sync::Mutex<Isaac64Rng>>,
332 ) -> Result<(), candle_core::Error> {
333 candle_core::bail!("`sample_causal_gen` is incompatible with `DiffusionPipeline`");
334 }
335 fn category(&self) -> ModelCategory {
336 ModelCategory::Diffusion
337 }
338}
339
340impl AnyMoePipelineMixin for DiffusionPipeline {}