1use super::loaders::{DiffusionModelPaths, DiffusionModelPathsInner};
2use super::{
3 AnyMoePipelineMixin, Cache, CacheManagerMixin, DiffusionLoaderType, DiffusionModel,
4 DiffusionModelLoader, EitherCache, FluxLoader, ForwardInputsResult, GeneralMetadata,
5 IsqPipelineMixin, Loader, MetadataMixin, ModelCategory, ModelKind, ModelPaths,
6 PreProcessingMixin, Processor, TokenSource,
7};
8use crate::device_map::DeviceMapper;
9use crate::diffusion_models::processor::{DiffusionProcessor, ModelInputs};
10use crate::paged_attention::AttentionImplementation;
11use crate::pipeline::ChatTemplate;
12use crate::prefix_cacher::PrefixCacheManagerV2;
13use crate::sequence::Sequence;
14use crate::utils::varbuilder_utils::DeviceForLoadTensor;
15use crate::utils::{tokens::get_token, varbuilder_utils::from_mmaped_safetensors};
16use crate::{DeviceMapSetting, PagedAttentionConfig, Pipeline, TryIntoDType};
17use anyhow::Result;
18use candle_core::{DType, Device, Tensor};
19use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
20use image::{DynamicImage, RgbImage};
21use indicatif::MultiProgress;
22use mistralrs_quant::log::once_log_info;
23use mistralrs_quant::IsqType;
24use rand_isaac::Isaac64Rng;
25use std::any::Any;
26use std::io;
27use std::sync::Arc;
28use tokenizers::Tokenizer;
29use tokio::sync::Mutex;
30use tracing::warn;
31
/// A fully loaded diffusion (image-generation) model pipeline.
pub struct DiffusionPipeline {
    /// The underlying diffusion model implementation (e.g. FLUX).
    model: Box<dyn DiffusionModel + Send + Sync>,
    /// Model ID (hub ID or local path) this pipeline was loaded from.
    model_id: String,
    /// Shared pipeline metadata; cache/tokenizer-related fields are placeholders
    /// since diffusion models have no KV cache or tokenizer.
    metadata: Arc<GeneralMetadata>,
    /// Placeholder cache returned by `CacheManagerMixin::cache`; diffusion
    /// models maintain no KV cache (`no_kv_cache` is set in the metadata).
    dummy_cache: EitherCache,
}
38
/// Loader for diffusion models; constructed via [`DiffusionLoaderBuilder`].
pub struct DiffusionLoader {
    /// Architecture-specific loader (e.g. FLUX) that knows which files to
    /// fetch and how to build the model from them.
    inner: Box<dyn DiffusionModelLoader>,
    /// Hugging Face model ID or local path.
    model_id: String,
    /// Kind of model being loaded (`Normal` when built via the builder).
    kind: ModelKind,
}
45
/// Builder producing a [`DiffusionLoader`].
#[derive(Default)]
pub struct DiffusionLoaderBuilder {
    /// Model ID; must be present by the time `build` is called.
    model_id: Option<String>,
    /// Model kind; `new` always sets `ModelKind::Normal`.
    kind: ModelKind,
}
52
53impl DiffusionLoaderBuilder {
54 pub fn new(model_id: Option<String>) -> Self {
55 Self {
56 model_id,
57 kind: ModelKind::Normal,
58 }
59 }
60
61 pub fn build(self, loader: DiffusionLoaderType) -> Box<dyn Loader> {
62 let loader: Box<dyn DiffusionModelLoader> = match loader {
63 DiffusionLoaderType::Flux => Box::new(FluxLoader { offload: false }),
64 DiffusionLoaderType::FluxOffloaded => Box::new(FluxLoader { offload: true }),
65 };
66 Box::new(DiffusionLoader {
67 inner: loader,
68 model_id: self.model_id.unwrap(),
69 kind: self.kind,
70 })
71 }
72}
73
impl Loader for DiffusionLoader {
    /// Resolves model and config file paths from the Hugging Face Hub, then
    /// delegates to [`Self::load_model_from_path`].
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let paths: anyhow::Result<Box<dyn ModelPaths>> = {
            // Hub API client; progress bars are suppressed in silent mode.
            let api = ApiBuilder::new()
                .with_progress(!silent)
                .with_token(get_token(&token_source)?)
                .build()?;
            // Default to the `main` revision when none is given.
            let revision = revision.unwrap_or("main".to_string());
            let api = api.repo(Repo::with_revision(
                self.model_id.clone(),
                RepoType::Model,
                revision.clone(),
            ));
            let model_id = std::path::Path::new(&self.model_id);
            // The architecture-specific loader decides which weight and config
            // files constitute the model.
            let filenames = self.inner.get_model_paths(&api, model_id)?;
            let config_filenames = self.inner.get_config_filenames(&api, model_id)?;
            Ok(Box::new(DiffusionModelPaths(DiffusionModelPathsInner {
                config_filenames,
                filenames,
            })))
        };
        self.load_model_from_path(
            &paths?,
            dtype,
            device,
            silent,
            mapper,
            in_situ_quant,
            paged_attn_config,
        )
    }

    /// Loads the diffusion model from already-resolved local file paths and
    /// wraps it in a [`DiffusionPipeline`].
    ///
    /// Device mapping and ISQ are rejected outright; PagedAttention is
    /// silently disabled with a warning.
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        mut paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        // Recover the concrete diffusion path bundle from the trait object.
        let paths = &paths
            .as_ref()
            .as_any()
            .downcast_ref::<DiffusionModelPaths>()
            .expect("Path downcast failed.")
            .0;

        // Hard errors: these features are unsupported for diffusion models.
        if matches!(mapper, DeviceMapSetting::Map(_)) {
            anyhow::bail!("Device mapping is not supported for diffusion models.")
        }

        if in_situ_quant.is_some() {
            anyhow::bail!("ISQ is not supported for Diffusion models.");
        }

        // Soft fallback: PagedAttention is disabled rather than erroring.
        if paged_attn_config.is_some() {
            warn!("PagedAttention is not supported for Diffusion models, disabling it.");

            paged_attn_config = None;
        }

        if crate::using_flash_attn() {
            once_log_info("FlashAttention is enabled.");
        }

        // Read every config file's contents up front; any I/O error aborts.
        let configs = paths
            .config_filenames
            .iter()
            .map(std::fs::read_to_string)
            .collect::<io::Result<Vec<_>>>()?;

        // Dummy mapper: no layer-wise device mapping; also determines the
        // minimum dtype actually usable on this device.
        let mapper = DeviceMapSetting::dummy().into_mapper(usize::MAX, device, None)?;
        let dtype = mapper.get_min_dtype(dtype)?;

        // NOTE: `paged_attn_config` was forced to `None` above, so this always
        // resolves to `Eager` in practice.
        let attention_mechanism = if paged_attn_config.is_some() {
            AttentionImplementation::PagedAttention
        } else {
            AttentionImplementation::Eager
        };

        let model = match self.kind {
            ModelKind::Normal => {
                // Build one VarBuilder per weight file; the inner loader may
                // force specific files onto the CPU (e.g. for offloading).
                let vbs = paths
                    .filenames
                    .iter()
                    .zip(self.inner.force_cpu_vb())
                    .map(|(path, force_cpu)| {
                        let dev = if force_cpu { &Device::Cpu } else { device };
                        from_mmaped_safetensors(
                            vec![path.clone()],
                            Vec::new(),
                            Some(dtype),
                            dev,
                            vec![None],
                            silent,
                            None,
                            |_| true,
                            Arc::new(|_| DeviceForLoadTensor::Base),
                        )
                    })
                    .collect::<candle_core::Result<Vec<_>>>()?;

                self.inner.load(
                    configs,
                    vbs,
                    crate::pipeline::NormalLoadingMetadata {
                        mapper,
                        loading_isq: false,
                        real_device: device.clone(),
                        multi_progress: Arc::new(MultiProgress::new()),
                    },
                    attention_mechanism,
                    silent,
                )?
            }
            // The builder only ever produces `ModelKind::Normal`.
            _ => unreachable!(),
        };

        let max_seq_len = model.max_seq_len();
        Ok(Arc::new(Mutex::new(DiffusionPipeline {
            model,
            model_id: self.model_id.clone(),
            // Most metadata fields are placeholders: diffusion pipelines have
            // no tokenizer, EOS tokens, prefix cache, or KV cache.
            metadata: Arc::new(GeneralMetadata {
                max_seq_len,
                llg_factory: None,
                is_xlora: false,
                no_prefix_cache: false,
                // NOTE(review): placeholder layer count — presumably only
                // cache bookkeeping reads this; confirm against consumers.
                num_hidden_layers: 1,
                eos_tok: vec![],
                kind: self.kind.clone(),
                no_kv_cache: true,
                activation_dtype: dtype,
                sliding_window: None,
                cache_config: None,
                cache_engine: None,
                prompt_chunksize: None,
                model_metadata: None,
            }),
            dummy_cache: EitherCache::Full(Cache::new(0, false)),
        })))
    }

    /// Returns the model ID this loader was configured with.
    fn get_id(&self) -> String {
        self.model_id.to_string()
    }

    /// Returns the kind of model this loader produces.
    fn get_kind(&self) -> ModelKind {
        self.kind.clone()
    }
}
238
impl PreProcessingMixin for DiffusionPipeline {
    /// Inputs are handled by the dedicated [`DiffusionProcessor`].
    fn get_processor(&self) -> Arc<dyn Processor> {
        Arc::new(DiffusionProcessor)
    }
    /// Diffusion pipelines have no chat template.
    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
        None
    }
    /// No additional input-processor configuration is used.
    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
        None
    }
}
250
251impl IsqPipelineMixin for DiffusionPipeline {
252 fn re_isq_model(&mut self, _dtype: IsqType) -> Result<()> {
253 anyhow::bail!("Diffusion models do not support ISQ for now.")
254 }
255}
256
impl CacheManagerMixin for DiffusionPipeline {
    /// No KV cache exists for diffusion models; intentionally a no-op.
    fn clone_in_cache(&self, _seqs: &mut [&mut Sequence]) {}
    /// No KV cache exists for diffusion models; intentionally a no-op.
    fn clone_out_cache(&self, _seqs: &mut [&mut Sequence]) {}
    /// Nothing to clear; intentionally a no-op.
    fn set_none_cache(
        &self,
        _seqs: &mut [&mut Sequence],
        _reset_non_granular: bool,
        _modify_draft_cache: bool,
        _load_preallocated_cache: bool,
    ) {
    }
    /// Returns the zero-layer placeholder cache created at load time.
    fn cache(&self) -> &EitherCache {
        &self.dummy_cache
    }
}
272
impl MetadataMixin for DiffusionPipeline {
    /// Device the underlying diffusion model lives on.
    fn device(&self) -> Device {
        self.model.device().clone()
    }
    fn get_metadata(&self) -> Arc<GeneralMetadata> {
        self.metadata.clone()
    }
    /// The model ID this pipeline was loaded from.
    fn name(&self) -> String {
        self.model_id.clone()
    }
    /// Diffusion pipelines keep no non-granular state; no-op.
    fn reset_non_granular_state(&self) {}
    /// No tokenizer is exposed for diffusion models.
    fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
        None
    }
    /// Device mapping is unsupported for diffusion models.
    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
        None
    }
}
291
292#[async_trait::async_trait]
293impl Pipeline for DiffusionPipeline {
294 fn forward_inputs(
295 &mut self,
296 inputs: Box<dyn Any>,
297 return_raw_logits: bool,
298 ) -> candle_core::Result<ForwardInputsResult> {
299 assert!(!return_raw_logits);
300
301 let ModelInputs { prompts, params } = *inputs.downcast().expect("Downcast failed.");
302 let img = self.model.forward(prompts, params)?.to_dtype(DType::U8)?;
303 let (_b, c, h, w) = img.dims4()?;
304 let mut images = Vec::new();
305 for b_img in img.chunk(img.dim(0)?, 0)? {
306 let flattened = b_img.squeeze(0)?.permute((1, 2, 0))?.flatten_all()?;
307 if c != 3 {
308 candle_core::bail!("Expected 3 channels in image output");
309 }
310 #[allow(clippy::cast_possible_truncation)]
311 images.push(DynamicImage::ImageRgb8(
312 RgbImage::from_raw(w as u32, h as u32, flattened.to_vec1::<u8>()?).ok_or(
313 candle_core::Error::Msg("RgbImage has invalid capacity.".to_string()),
314 )?,
315 ));
316 }
317 Ok(ForwardInputsResult::Image { images })
318 }
319 async fn sample_causal_gen(
320 &self,
321 _seqs: &mut [&mut Sequence],
322 _logits: Vec<Tensor>,
323 _prefix_cacher: &mut PrefixCacheManagerV2,
324 _disable_eos_stop: bool,
325 _srng: Arc<std::sync::Mutex<Isaac64Rng>>,
326 ) -> Result<(), candle_core::Error> {
327 candle_core::bail!("`sample_causal_gen` is incompatible with `DiffusionPipeline`");
328 }
329 fn category(&self) -> ModelCategory {
330 ModelCategory::Diffusion
331 }
332}
333
// Uses the trait's default method implementations; this pipeline adds no
// AnyMoE-specific behavior.
impl AnyMoePipelineMixin for DiffusionPipeline {}