1use super::loaders::{DiffusionModelPaths, DiffusionModelPathsInner};
2use super::{
3 AnyMoePipelineMixin, Cache, CacheManagerMixin, DiffusionLoaderType, DiffusionModel,
4 DiffusionModelLoader, EitherCache, FluxLoader, ForwardInputsResult, GeneralMetadata,
5 IsqPipelineMixin, Loader, MetadataMixin, ModelCategory, ModelKind, ModelPaths,
6 PreProcessingMixin, Processor, TokenSource,
7};
8use crate::device_map::{self, DeviceMapper};
9use crate::diffusion_models::processor::{DiffusionProcessor, ModelInputs};
10use crate::distributed::{self, WorkerTransferData};
11use crate::paged_attention::AttentionImplementation;
12use crate::pipeline::{ChatTemplate, Modalities, SupportedModality};
13use crate::prefix_cacher::PrefixCacheManagerV2;
14use crate::sequence::Sequence;
15use crate::utils::varbuilder_utils::DeviceForLoadTensor;
16use crate::utils::{
17 progress::{new_multi_progress, ProgressScopeGuard},
18 tokens::get_token,
19 varbuilder_utils::from_mmaped_safetensors,
20};
21use crate::{DeviceMapSetting, PagedAttentionConfig, Pipeline, TryIntoDType};
22use anyhow::Result;
23use candle_core::{DType, Device, Tensor};
24use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
25use image::{DynamicImage, RgbImage};
26use mistralrs_quant::log::once_log_info;
27use mistralrs_quant::IsqType;
28use rand_isaac::Isaac64Rng;
29use std::any::Any;
30use std::sync::Arc;
31use std::{env, io};
32use tokenizers::Tokenizer;
33use tokio::sync::Mutex;
34use tracing::warn;
35
/// A loaded diffusion (image-generation) model pipeline.
///
/// Produced by [`DiffusionLoader`]; takes text prompts as input and yields
/// rendered images (see the `Modalities` set up at load time).
pub struct DiffusionPipeline {
    // The underlying diffusion model (e.g. FLUX), boxed for dynamic dispatch.
    model: Box<dyn DiffusionModel + Send + Sync>,
    // Hugging Face model id or local path; used for naming/reporting.
    model_id: String,
    // Shared pipeline metadata (dtype, modalities, cache flags, ...).
    metadata: Arc<GeneralMetadata>,
    // Diffusion models keep no KV cache; this zero-layer placeholder only
    // satisfies the `CacheManagerMixin::cache` accessor.
    dummy_cache: EitherCache,
}
42
/// A [`Loader`] for diffusion models; construct via [`DiffusionLoaderBuilder`].
pub struct DiffusionLoader {
    // Architecture-specific loader (e.g. `FluxLoader`) that knows which files
    // make up the model and how to build it from varbuilders.
    inner: Box<dyn DiffusionModelLoader>,
    // Hugging Face model id or local path.
    model_id: String,
    // Model kind; the builder always sets `ModelKind::Normal`.
    kind: ModelKind,
}
49
/// Builder for [`DiffusionLoader`].
#[derive(Default)]
pub struct DiffusionLoaderBuilder {
    // Required by `build`; a `None` here panics there.
    model_id: Option<String>,
    // Always `ModelKind::Normal` when constructed via `new`.
    kind: ModelKind,
}
56
57impl DiffusionLoaderBuilder {
58 pub fn new(model_id: Option<String>) -> Self {
59 Self {
60 model_id,
61 kind: ModelKind::Normal,
62 }
63 }
64
65 pub fn build(self, loader: DiffusionLoaderType) -> Box<dyn Loader> {
66 let loader: Box<dyn DiffusionModelLoader> = match loader {
67 DiffusionLoaderType::Flux => Box::new(FluxLoader { offload: false }),
68 DiffusionLoaderType::FluxOffloaded => Box::new(FluxLoader { offload: true }),
69 };
70 Box::new(DiffusionLoader {
71 inner: loader,
72 model_id: self.model_id.unwrap(),
73 kind: self.kind,
74 })
75 }
76}
77
impl Loader for DiffusionLoader {
    /// Resolve the model weight and config filenames from the Hugging Face Hub
    /// (respecting `revision` and `token_source`), then delegate the actual
    /// loading to [`Loader::load_model_from_path`].
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        // Scopes progress reporting for the duration of the load; `silent`
        // presumably suppresses the bars — see `ProgressScopeGuard`.
        let _progress_guard = ProgressScopeGuard::new(silent);
        let paths: anyhow::Result<Box<dyn ModelPaths>> = {
            let api = ApiBuilder::new()
                .with_progress(!silent)
                .with_token(get_token(&token_source)?)
                .build()?;
            // Default to the `main` revision when none is requested.
            let revision = revision.unwrap_or("main".to_string());
            let api = api.repo(Repo::with_revision(
                self.model_id.clone(),
                RepoType::Model,
                revision.clone(),
            ));
            let model_id = std::path::Path::new(&self.model_id);
            // The concrete loader (e.g. FLUX) decides which weight/config
            // files make up this architecture.
            let filenames = self.inner.get_model_paths(&api, model_id)?;
            let config_filenames = self.inner.get_config_filenames(&api, model_id)?;
            Ok(Box::new(DiffusionModelPaths(DiffusionModelPathsInner {
                config_filenames,
                filenames,
            })))
        };
        self.load_model_from_path(
            &paths?,
            dtype,
            device,
            silent,
            mapper,
            in_situ_quant,
            paged_attn_config,
        )
    }

    /// Load the diffusion model from already-resolved file paths.
    ///
    /// Rejects configurations diffusion models do not support: device mapping
    /// and ISQ are hard errors, while PagedAttention is silently disabled with
    /// a warning.
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        mut paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let _progress_guard = ProgressScopeGuard::new(silent);
        // `paths` must be the `DiffusionModelPaths` produced by
        // `load_model_from_hf`; anything else is a programming error.
        let paths = &paths
            .as_ref()
            .as_any()
            .downcast_ref::<DiffusionModelPaths>()
            .expect("Path downcast failed.")
            .0;

        if matches!(mapper, DeviceMapSetting::Map(_)) {
            anyhow::bail!("Device mapping is not supported for diffusion models.")
        }

        if in_situ_quant.is_some() {
            anyhow::bail!("ISQ is not supported for Diffusion models.");
        }

        // Best-effort: downgrade an unsupported request to a warning.
        if paged_attn_config.is_some() {
            warn!("PagedAttention is not supported for Diffusion models, disabling it.");

            paged_attn_config = None;
        }

        if crate::using_flash_attn() {
            once_log_info("FlashAttention is enabled.");
        }

        // Read every config file up front; any I/O error aborts the load.
        let configs = paths
            .config_filenames
            .iter()
            .map(std::fs::read_to_string)
            .collect::<io::Result<Vec<_>>>()?;

        #[cfg(feature = "cuda")]
        if let Device::Cuda(dev) = &device {
            // NOTE(review): no SAFETY justification given upstream; assumed
            // sound as a candle CUDA-device setting — confirm against candle.
            unsafe { dev.disable_event_tracking() };
        }
        let use_nccl = mistralrs_quant::distributed::use_nccl();
        // Device selection: a distributed daemon worker uses the CUDA ordinal
        // derived from its rank (+1), NCCL uses CUDA device 0, and otherwise
        // all devices similar to `device` are candidates.
        let available_devices = if let Ok(payload) = env::var(distributed::IS_DAEMON_FLAG) {
            let payload: WorkerTransferData = serde_json::from_str(&payload)?;
            let WorkerTransferData::Init { id: _, worker_rank } = payload;
            vec![candle_core::Device::new_cuda(worker_rank + 1)?]
        } else if use_nccl {
            vec![candle_core::Device::new_cuda(0)?]
        } else {
            device_map::get_all_similar_devices(device)?
        };

        // Real device mapping is rejected above; a dummy mapper is built only
        // to derive the minimum dtype supported across the devices.
        let mapper =
            DeviceMapSetting::dummy().into_mapper(usize::MAX, device, None, &available_devices)?;
        let dtype = mapper.get_min_dtype(dtype)?;

        // `paged_attn_config` was forced to `None` above, so this always
        // resolves to `Eager`; the branch mirrors other pipelines.
        let attention_mechanism = if paged_attn_config.is_some() {
            AttentionImplementation::PagedAttention
        } else {
            AttentionImplementation::Eager
        };

        let model = match self.kind {
            ModelKind::Normal => {
                // One varbuilder per weight file; the concrete loader may
                // force individual components onto the CPU (offloading).
                let vbs = paths
                    .filenames
                    .iter()
                    .zip(self.inner.force_cpu_vb())
                    .map(|(path, force_cpu)| {
                        let dev = if force_cpu { &Device::Cpu } else { device };
                        from_mmaped_safetensors(
                            vec![path.clone()],
                            Vec::new(),
                            Some(dtype),
                            dev,
                            vec![None],
                            silent,
                            None,
                            |_| true,
                            Arc::new(|_| DeviceForLoadTensor::Base),
                        )
                    })
                    .collect::<candle_core::Result<Vec<_>>>()?;

                self.inner.load(
                    configs,
                    vbs,
                    crate::pipeline::NormalLoadingMetadata {
                        mapper,
                        loading_isq: false,
                        real_device: device.clone(),
                        multi_progress: Arc::new(new_multi_progress()),
                        matformer_slicing_config: None,
                    },
                    attention_mechanism,
                    silent,
                )?
            }
            // `DiffusionLoaderBuilder::new` only ever sets `ModelKind::Normal`.
            _ => unreachable!(),
        };

        let max_seq_len = model.max_seq_len();
        Ok(Arc::new(Mutex::new(DiffusionPipeline {
            model,
            model_id: self.model_id.clone(),
            metadata: Arc::new(GeneralMetadata {
                max_seq_len,
                llg_factory: None,
                is_xlora: false,
                no_prefix_cache: false,
                // No KV-cached layers exist; 1 appears to be a benign
                // placeholder — confirm against GeneralMetadata consumers.
                num_hidden_layers: 1,
                eos_tok: vec![],
                kind: self.kind.clone(),
                // No autoregressive decoding, hence no KV cache.
                no_kv_cache: true,
                activation_dtype: dtype,
                sliding_window: None,
                cache_config: None,
                cache_engine: None,
                model_metadata: None,
                // Text prompts in, rendered images out.
                modalities: Modalities {
                    input: vec![SupportedModality::Text],
                    output: vec![SupportedModality::Vision],
                },
            }),
            dummy_cache: EitherCache::Full(Cache::new(0, false)),
        })))
    }

    /// The model id this loader was built with.
    fn get_id(&self) -> String {
        self.model_id.to_string()
    }

    /// The kind of model this loader produces.
    fn get_kind(&self) -> ModelKind {
        self.kind.clone()
    }
}
264
265impl PreProcessingMixin for DiffusionPipeline {
266 fn get_processor(&self) -> Arc<dyn Processor> {
267 Arc::new(DiffusionProcessor)
268 }
269 fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
270 None
271 }
272 fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
273 None
274 }
275}
276
277impl IsqPipelineMixin for DiffusionPipeline {
278 fn re_isq_model(&mut self, _dtype: IsqType) -> Result<()> {
279 anyhow::bail!("Diffusion models do not support ISQ for now.")
280 }
281}
282
// Diffusion pipelines carry no KV cache (`no_kv_cache: true` in their
// metadata): every cache operation is a no-op and `cache` only exposes the
// zero-layer placeholder created at load time.
impl CacheManagerMixin for DiffusionPipeline {
    fn clone_in_cache(&self, _seqs: &mut [&mut Sequence]) {}
    fn clone_out_cache(&self, _seqs: &mut [&mut Sequence]) {}
    fn set_none_cache(
        &self,
        _seqs: &mut [&mut Sequence],
        _reset_non_granular: bool,
        _modify_draft_cache: bool,
        _load_preallocated_cache: bool,
    ) {
    }
    fn cache(&self) -> &EitherCache {
        &self.dummy_cache
    }
}
298
299impl MetadataMixin for DiffusionPipeline {
300 fn device(&self) -> Device {
301 self.model.device().clone()
302 }
303 fn get_metadata(&self) -> Arc<GeneralMetadata> {
304 self.metadata.clone()
305 }
306 fn name(&self) -> String {
307 self.model_id.clone()
308 }
309 fn reset_non_granular_state(&self) {}
310 fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
311 None
312 }
313 fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
314 None
315 }
316}
317
318#[async_trait::async_trait]
319impl Pipeline for DiffusionPipeline {
320 fn forward_inputs(
321 &mut self,
322 inputs: Box<dyn Any>,
323 return_raw_logits: bool,
324 ) -> candle_core::Result<ForwardInputsResult> {
325 assert!(!return_raw_logits);
326
327 let ModelInputs { prompts, params } = *inputs.downcast().expect("Downcast failed.");
328 let img = self.model.forward(prompts, params)?.to_dtype(DType::U8)?;
329 let (_b, c, h, w) = img.dims4()?;
330 let mut images = Vec::new();
331 for b_img in img.chunk(img.dim(0)?, 0)? {
332 let flattened = b_img.squeeze(0)?.permute((1, 2, 0))?.flatten_all()?;
333 if c != 3 {
334 candle_core::bail!("Expected 3 channels in image output");
335 }
336 #[allow(clippy::cast_possible_truncation)]
337 images.push(DynamicImage::ImageRgb8(
338 RgbImage::from_raw(w as u32, h as u32, flattened.to_vec1::<u8>()?).ok_or(
339 candle_core::Error::Msg("RgbImage has invalid capacity.".to_string()),
340 )?,
341 ));
342 }
343 Ok(ForwardInputsResult::Image { images })
344 }
345 async fn sample_causal_gen(
346 &self,
347 _seqs: &mut [&mut Sequence],
348 _logits: Vec<Tensor>,
349 _prefix_cacher: &mut PrefixCacheManagerV2,
350 _disable_eos_stop: bool,
351 _srng: Arc<std::sync::Mutex<Isaac64Rng>>,
352 ) -> Result<(), candle_core::Error> {
353 candle_core::bail!("`sample_causal_gen` is incompatible with `DiffusionPipeline`");
354 }
355 fn category(&self) -> ModelCategory {
356 ModelCategory::Diffusion
357 }
358}
359
// AnyMoE is not applicable to diffusion models; rely entirely on the trait's
// default method implementations.
impl AnyMoePipelineMixin for DiffusionPipeline {}