use super::loaders::{DiffusionModelPaths, DiffusionModelPathsInner};
use super::{
    AnyMoePipelineMixin, Cache, CacheManagerMixin, DiffusionLoaderType, DiffusionModel,
    DiffusionModelLoader, EitherCache, FluxLoader, ForwardInputsResult, GeneralMetadata,
    IsqPipelineMixin, Loader, MetadataMixin, ModelCategory, ModelKind, ModelPaths,
    PreProcessingMixin, Processor, TokenSource,
};
use crate::device_map::DeviceMapper;
use crate::diffusion_models::processor::{DiffusionProcessor, ModelInputs};
use crate::paged_attention::AttentionImplementation;
use crate::pipeline::{ChatTemplate, Modalities, SupportedModality};
use crate::prefix_cacher::PrefixCacheManagerV2;
use crate::sequence::Sequence;
use crate::utils::varbuilder_utils::DeviceForLoadTensor;
use crate::utils::{
    progress::{new_multi_progress, ProgressScopeGuard},
    tokens::get_token,
    varbuilder_utils::from_mmaped_safetensors,
};
use crate::{DeviceMapSetting, PagedAttentionConfig, Pipeline, TryIntoDType};
use anyhow::Result;
use candle_core::{DType, Device, Tensor};
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
use image::{DynamicImage, RgbImage};
use mistralrs_quant::log::once_log_info;
use mistralrs_quant::IsqType;
use rand_isaac::Isaac64Rng;
use std::any::Any;
use std::io;
use std::sync::Arc;
use tokenizers::Tokenizer;
use tokio::sync::Mutex;
use tracing::warn;

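/// Pipeline for diffusion (image generation) models such as FLUX.
///
/// It wraps a loaded [`DiffusionModel`] together with the [`GeneralMetadata`]
/// the rest of the engine consumes. `dummy_cache` exists only to satisfy the
/// cache-related traits; diffusion models keep no KV cache.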
pub struct DiffusionPipeline {
    model: Box<dyn DiffusionModel + Send + Sync>,
    model_id: String,
    metadata: Arc<GeneralMetadata>,
    dummy_cache: EitherCache,
}

pub struct DiffusionLoader {
    inner: Box<dyn DiffusionModelLoader>,
    model_id: String,
    kind: ModelKind,
}

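/// Builder for a diffusion model [`Loader`].
///
/// A minimal usage sketch (the model id below is illustrative, and the exact
/// re-export path of these types depends on how the crate is consumed):
///
/// ```ignore
/// let loader = DiffusionLoaderBuilder::new(Some("black-forest-labs/FLUX.1-schnell".to_string()))
///     .build(DiffusionLoaderType::Flux);
/// ```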
#[derive(Default)]
pub struct DiffusionLoaderBuilder {
    model_id: Option<String>,
    kind: ModelKind,
}

impl DiffusionLoaderBuilder {
    pub fn new(model_id: Option<String>) -> Self {
        Self {
            model_id,
            kind: ModelKind::Normal,
        }
    }

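    /// Finish the builder, producing a [`Loader`] that uses either the standard
    /// or the offloaded FLUX loader depending on `loader`. Panics if no
    /// `model_id` was supplied to [`DiffusionLoaderBuilder::new`].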
    pub fn build(self, loader: DiffusionLoaderType) -> Box<dyn Loader> {
        let loader: Box<dyn DiffusionModelLoader> = match loader {
            DiffusionLoaderType::Flux => Box::new(FluxLoader { offload: false }),
            DiffusionLoaderType::FluxOffloaded => Box::new(FluxLoader { offload: true }),
        };
        Box::new(DiffusionLoader {
            inner: loader,
            model_id: self.model_id.unwrap(),
            kind: self.kind,
        })
    }
}

impl Loader for DiffusionLoader {
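    /// Resolve the model weight and config file paths from the Hugging Face Hub
    /// (honouring `revision` and `token_source`), then defer to
    /// [`Loader::load_model_from_path`].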
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let _progress_guard = ProgressScopeGuard::new(silent);
        let paths: anyhow::Result<Box<dyn ModelPaths>> = {
            let api = ApiBuilder::new()
                .with_progress(!silent)
                .with_token(get_token(&token_source)?)
                .build()?;
            let revision = revision.unwrap_or("main".to_string());
            let api = api.repo(Repo::with_revision(
                self.model_id.clone(),
                RepoType::Model,
                revision.clone(),
            ));
            let model_id = std::path::Path::new(&self.model_id);
            let filenames = self.inner.get_model_paths(&api, model_id)?;
            let config_filenames = self.inner.get_config_filenames(&api, model_id)?;
            Ok(Box::new(DiffusionModelPaths(DiffusionModelPathsInner {
                config_filenames,
                filenames,
            })))
        };
        self.load_model_from_path(
            &paths?,
            dtype,
            device,
            silent,
            mapper,
            in_situ_quant,
            paged_attn_config,
        )
    }

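    /// Load a diffusion model from already-resolved local `paths`. Device
    /// mapping and ISQ are rejected, and PagedAttention is disabled, since none
    /// of them apply to diffusion models.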
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        mut paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let _progress_guard = ProgressScopeGuard::new(silent);
        let paths = &paths
            .as_ref()
            .as_any()
            .downcast_ref::<DiffusionModelPaths>()
            .expect("Path downcast failed.")
            .0;

        if matches!(mapper, DeviceMapSetting::Map(_)) {
            anyhow::bail!("Device mapping is not supported for diffusion models.")
        }

        if in_situ_quant.is_some() {
            anyhow::bail!("ISQ is not supported for Diffusion models.");
        }

        if paged_attn_config.is_some() {
            warn!("PagedAttention is not supported for Diffusion models, disabling it.");

            paged_attn_config = None;
        }

        if crate::using_flash_attn() {
            once_log_info("FlashAttention is enabled.");
        }

        let configs = paths
            .config_filenames
            .iter()
            .map(std::fs::read_to_string)
            .collect::<io::Result<Vec<_>>>()?;

        #[cfg(feature = "cuda")]
        if let Device::Cuda(dev) = &device {
            unsafe { dev.disable_event_tracking() };
        }

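        // Diffusion models are never device-mapped (see the bail above); the
        // dummy mapper is used only to pick the activation dtype for `device`.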
        let mapper = DeviceMapSetting::dummy().into_mapper(usize::MAX, device, None)?;
        let dtype = mapper.get_min_dtype(dtype)?;

        let attention_mechanism = if paged_attn_config.is_some() {
            AttentionImplementation::PagedAttention
        } else {
            AttentionImplementation::Eager
        };

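        // Memory-map each safetensors shard into its own VarBuilder, optionally
        // pinning a shard to the CPU when the model-specific loader asks for it
        // (e.g. the offloaded FLUX variant), then hand everything to that loader.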
        let model = match self.kind {
            ModelKind::Normal => {
                let vbs = paths
                    .filenames
                    .iter()
                    .zip(self.inner.force_cpu_vb())
                    .map(|(path, force_cpu)| {
                        let dev = if force_cpu { &Device::Cpu } else { device };
                        from_mmaped_safetensors(
                            vec![path.clone()],
                            Vec::new(),
                            Some(dtype),
                            dev,
                            vec![None],
                            silent,
                            None,
                            |_| true,
                            Arc::new(|_| DeviceForLoadTensor::Base),
                        )
                    })
                    .collect::<candle_core::Result<Vec<_>>>()?;

                self.inner.load(
                    configs,
                    vbs,
                    crate::pipeline::NormalLoadingMetadata {
                        mapper,
                        loading_isq: false,
                        real_device: device.clone(),
                        multi_progress: Arc::new(new_multi_progress()),
                        matformer_slicing_config: None,
                    },
                    attention_mechanism,
                    silent,
                )?
            }
            _ => unreachable!(),
        };

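        // Most of the text-generation metadata does not apply to diffusion
        // models: there is no tokenizer, no EOS tokens, and no KV cache, so the
        // corresponding fields are left empty or disabled.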
        let max_seq_len = model.max_seq_len();
        Ok(Arc::new(Mutex::new(DiffusionPipeline {
            model,
            model_id: self.model_id.clone(),
            metadata: Arc::new(GeneralMetadata {
                max_seq_len,
                llg_factory: None,
                is_xlora: false,
                no_prefix_cache: false,
                num_hidden_layers: 1,
                eos_tok: vec![],
                kind: self.kind.clone(),
                no_kv_cache: true,
                activation_dtype: dtype,
                sliding_window: None,
                cache_config: None,
                cache_engine: None,
                model_metadata: None,
                modalities: Modalities {
                    input: vec![SupportedModality::Text],
                    output: vec![SupportedModality::Vision],
                },
            }),
            dummy_cache: EitherCache::Full(Cache::new(0, false)),
        })))
    }

    fn get_id(&self) -> String {
        self.model_id.to_string()
    }

    fn get_kind(&self) -> ModelKind {
        self.kind.clone()
    }
}

impl PreProcessingMixin for DiffusionPipeline {
    fn get_processor(&self) -> Arc<dyn Processor> {
        Arc::new(DiffusionProcessor)
    }
    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
        None
    }
    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
        None
    }
}

impl IsqPipelineMixin for DiffusionPipeline {
    fn re_isq_model(&mut self, _dtype: IsqType) -> Result<()> {
        anyhow::bail!("Diffusion models do not support ISQ for now.")
    }
}

impl CacheManagerMixin for DiffusionPipeline {
    fn clone_in_cache(&self, _seqs: &mut [&mut Sequence]) {}
    fn clone_out_cache(&self, _seqs: &mut [&mut Sequence]) {}
    fn set_none_cache(
        &self,
        _seqs: &mut [&mut Sequence],
        _reset_non_granular: bool,
        _modify_draft_cache: bool,
        _load_preallocated_cache: bool,
    ) {
    }
    fn cache(&self) -> &EitherCache {
        &self.dummy_cache
    }
}

impl MetadataMixin for DiffusionPipeline {
    fn device(&self) -> Device {
        self.model.device().clone()
    }
    fn get_metadata(&self) -> Arc<GeneralMetadata> {
        self.metadata.clone()
    }
    fn name(&self) -> String {
        self.model_id.clone()
    }
    fn reset_non_granular_state(&self) {}
    fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
        None
    }
    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
        None
    }
}

#[async_trait::async_trait]
impl Pipeline for DiffusionPipeline {
    fn forward_inputs(
        &mut self,
        inputs: Box<dyn Any>,
        return_raw_logits: bool,
    ) -> candle_core::Result<ForwardInputsResult> {
        assert!(!return_raw_logits);

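        // Run the diffusion model, convert its (batch, channel, height, width)
        // output to u8 pixel values, and turn each batch element into an `RgbImage`.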
        let ModelInputs { prompts, params } = *inputs.downcast().expect("Downcast failed.");
        let img = self.model.forward(prompts, params)?.to_dtype(DType::U8)?;
        let (_b, c, h, w) = img.dims4()?;
        let mut images = Vec::new();
        for b_img in img.chunk(img.dim(0)?, 0)? {
            let flattened = b_img.squeeze(0)?.permute((1, 2, 0))?.flatten_all()?;
            if c != 3 {
                candle_core::bail!("Expected 3 channels in image output");
            }
            #[allow(clippy::cast_possible_truncation)]
            images.push(DynamicImage::ImageRgb8(
                RgbImage::from_raw(w as u32, h as u32, flattened.to_vec1::<u8>()?).ok_or(
                    candle_core::Error::Msg("RgbImage has invalid capacity.".to_string()),
                )?,
            ));
        }
        Ok(ForwardInputsResult::Image { images })
    }
    async fn sample_causal_gen(
        &self,
        _seqs: &mut [&mut Sequence],
        _logits: Vec<Tensor>,
        _prefix_cacher: &mut PrefixCacheManagerV2,
        _disable_eos_stop: bool,
        _srng: Arc<std::sync::Mutex<Isaac64Rng>>,
    ) -> Result<(), candle_core::Error> {
        candle_core::bail!("`sample_causal_gen` is incompatible with `DiffusionPipeline`");
    }
    fn category(&self) -> ModelCategory {
        ModelCategory::Diffusion
    }
}

impl AnyMoePipelineMixin for DiffusionPipeline {}