//! mistralrs_core/pipeline/diffusion.rs

1use super::loaders::{DiffusionModelPaths, DiffusionModelPathsInner};
2use super::{
3    AnyMoePipelineMixin, Cache, CacheManagerMixin, DiffusionLoaderType, DiffusionModel,
4    DiffusionModelLoader, EitherCache, FluxLoader, ForwardInputsResult, GeneralMetadata,
5    IsqPipelineMixin, Loader, MetadataMixin, ModelCategory, ModelKind, ModelPaths,
6    PreProcessingMixin, Processor, TokenSource,
7};
8use crate::device_map::{self, DeviceMapper};
9use crate::diffusion_models::processor::{DiffusionProcessor, ModelInputs};
10use crate::distributed::{self, WorkerTransferData};
11use crate::paged_attention::AttentionImplementation;
12use crate::pipeline::{ChatTemplate, Modalities, SupportedModality};
13use crate::prefix_cacher::PrefixCacheManagerV2;
14use crate::sequence::Sequence;
15use crate::utils::varbuilder_utils::DeviceForLoadTensor;
16use crate::utils::{
17    progress::{new_multi_progress, ProgressScopeGuard},
18    tokens::get_token,
19    varbuilder_utils::from_mmaped_safetensors,
20};
21use crate::{DeviceMapSetting, PagedAttentionConfig, Pipeline, TryIntoDType};
22use anyhow::Result;
23use candle_core::{DType, Device, Tensor};
24use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
25use image::{DynamicImage, RgbImage};
26use mistralrs_quant::log::once_log_info;
27use mistralrs_quant::IsqType;
28use rand_isaac::Isaac64Rng;
29use std::any::Any;
30use std::sync::Arc;
31use std::{env, io};
32use tokenizers::Tokenizer;
33use tokio::sync::Mutex;
34use tracing::warn;
35
/// A loaded diffusion (image-generation) pipeline, ready to serve requests.
pub struct DiffusionPipeline {
    // The underlying diffusion model (e.g. FLUX); boxed for object-safe dispatch.
    model: Box<dyn DiffusionModel + Send + Sync>,
    // Hub model id or local path this pipeline was loaded from.
    model_id: String,
    // Shared metadata consumed by the engine/scheduler.
    metadata: Arc<GeneralMetadata>,
    // Diffusion models keep no KV cache; this exists only to satisfy
    // the `CacheManagerMixin::cache` accessor.
    dummy_cache: EitherCache,
}
42
/// A loader for a diffusion (non-quantized) model.
pub struct DiffusionLoader {
    // Architecture-specific loader (e.g. FLUX) that resolves paths and builds the model.
    inner: Box<dyn DiffusionModelLoader>,
    // Hub model id or local path to load from.
    model_id: String,
    kind: ModelKind,
}
49
#[derive(Default)]
/// A builder for a loader for a diffusion (non-quantized) model.
pub struct DiffusionLoaderBuilder {
    // Model id; `build` panics if this is still `None`.
    model_id: Option<String>,
    kind: ModelKind,
}
56
57impl DiffusionLoaderBuilder {
58    pub fn new(model_id: Option<String>) -> Self {
59        Self {
60            model_id,
61            kind: ModelKind::Normal,
62        }
63    }
64
65    pub fn build(self, loader: DiffusionLoaderType) -> Box<dyn Loader> {
66        let loader: Box<dyn DiffusionModelLoader> = match loader {
67            DiffusionLoaderType::Flux => Box::new(FluxLoader { offload: false }),
68            DiffusionLoaderType::FluxOffloaded => Box::new(FluxLoader { offload: true }),
69        };
70        Box::new(DiffusionLoader {
71            inner: loader,
72            model_id: self.model_id.unwrap(),
73            kind: self.kind,
74        })
75    }
76}
77
impl Loader for DiffusionLoader {
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        // Scope progress reporting to this load; suppressed when `silent`.
        let _progress_guard = ProgressScopeGuard::new(silent);
        // Resolve weight and config file paths from the HF Hub (or a local
        // path masquerading as a repo id), then delegate to
        // `load_model_from_path` for the actual load.
        let paths: anyhow::Result<Box<dyn ModelPaths>> = {
            let api = ApiBuilder::new()
                .with_progress(!silent)
                .with_token(get_token(&token_source)?)
                .build()?;
            let revision = revision.unwrap_or("main".to_string());
            let api = api.repo(Repo::with_revision(
                self.model_id.clone(),
                RepoType::Model,
                revision.clone(),
            ));
            let model_id = std::path::Path::new(&self.model_id);
            // The architecture-specific loader decides which files it needs.
            let filenames = self.inner.get_model_paths(&api, model_id)?;
            let config_filenames = self.inner.get_config_filenames(&api, model_id)?;
            Ok(Box::new(DiffusionModelPaths(DiffusionModelPathsInner {
                config_filenames,
                filenames,
            })))
        };
        self.load_model_from_path(
            &paths?,
            dtype,
            device,
            silent,
            mapper,
            in_situ_quant,
            paged_attn_config,
        )
    }

    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        mut paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let _progress_guard = ProgressScopeGuard::new(silent);
        // Recover the concrete diffusion path set from the type-erased trait object.
        let paths = &paths
            .as_ref()
            .as_any()
            .downcast_ref::<DiffusionModelPaths>()
            .expect("Path downcast failed.")
            .0;

        // Features that apply to text LLM pipelines are rejected/ignored here.
        if matches!(mapper, DeviceMapSetting::Map(_)) {
            anyhow::bail!("Device mapping is not supported for diffusion models.")
        }

        if in_situ_quant.is_some() {
            anyhow::bail!("ISQ is not supported for Diffusion models.");
        }

        if paged_attn_config.is_some() {
            warn!("PagedAttention is not supported for Diffusion models, disabling it.");

            paged_attn_config = None;
        }

        if crate::using_flash_attn() {
            once_log_info("FlashAttention is enabled.");
        }

        // Read every config file up front; the loader parses them later.
        let configs = paths
            .config_filenames
            .iter()
            .map(std::fs::read_to_string)
            .collect::<io::Result<Vec<_>>>()?;

        #[cfg(feature = "cuda")]
        if let Device::Cuda(dev) = &device {
            // SAFETY: per candle, event tracking may only be toggled before
            // kernels are launched on this device.
            unsafe { dev.disable_event_tracking() };
        }
        // Device selection: a distributed daemon worker gets the CUDA ordinal
        // encoded in its rank; NCCL pins to device 0; otherwise enumerate all
        // devices similar to the requested one.
        let use_nccl = mistralrs_quant::distributed::use_nccl();
        let available_devices = if let Ok(payload) = env::var(distributed::IS_DAEMON_FLAG) {
            let payload: WorkerTransferData = serde_json::from_str(&payload)?;
            let WorkerTransferData::Init { id: _, worker_rank } = payload;
            vec![candle_core::Device::new_cuda(worker_rank + 1)?]
        } else if use_nccl {
            vec![candle_core::Device::new_cuda(0)?]
        } else {
            device_map::get_all_similar_devices(device)?
        };

        // Device mapping is rejected above, so a dummy (single-device) mapper suffices;
        // it is still consulted for the minimum supported dtype.
        let mapper =
            DeviceMapSetting::dummy().into_mapper(usize::MAX, device, None, &available_devices)?;
        let dtype = mapper.get_min_dtype(dtype)?;

        // paged_attn_config was forced to None above, so this resolves to Eager;
        // kept in this form for symmetry with the other pipelines.
        let attention_mechanism = if paged_attn_config.is_some() {
            AttentionImplementation::PagedAttention
        } else {
            AttentionImplementation::Eager
        };

        let model = match self.kind {
            ModelKind::Normal => {
                // One VarBuilder per weight file; some components (per
                // `force_cpu_vb`) must be loaded on CPU (e.g. for offloading).
                let vbs = paths
                    .filenames
                    .iter()
                    .zip(self.inner.force_cpu_vb())
                    .map(|(path, force_cpu)| {
                        let dev = if force_cpu { &Device::Cpu } else { device };
                        from_mmaped_safetensors(
                            vec![path.clone()],
                            Vec::new(),
                            Some(dtype),
                            dev,
                            vec![None],
                            silent,
                            None,
                            |_| true,
                            Arc::new(|_| DeviceForLoadTensor::Base),
                        )
                    })
                    .collect::<candle_core::Result<Vec<_>>>()?;

                self.inner.load(
                    configs,
                    vbs,
                    crate::pipeline::NormalLoadingMetadata {
                        mapper,
                        loading_isq: false,
                        real_device: device.clone(),
                        multi_progress: Arc::new(new_multi_progress()),
                        matformer_slicing_config: None,
                    },
                    attention_mechanism,
                    silent,
                )?
            }
            // The builder only ever constructs `ModelKind::Normal`.
            _ => unreachable!(),
        };

        let max_seq_len = model.max_seq_len();
        Ok(Arc::new(Mutex::new(DiffusionPipeline {
            model,
            model_id: self.model_id.clone(),
            metadata: Arc::new(GeneralMetadata {
                max_seq_len,
                llg_factory: None,
                is_xlora: false,
                no_prefix_cache: false,
                num_hidden_layers: 1, // FIXME(EricLBuehler): we know this is only for caching, so its OK.
                eos_tok: vec![],
                kind: self.kind.clone(),
                no_kv_cache: true, // NOTE(EricLBuehler): no cache for these.
                activation_dtype: dtype,
                sliding_window: None,
                cache_config: None,
                cache_engine: None,
                model_metadata: None,
                // Text prompt in, generated image out.
                modalities: Modalities {
                    input: vec![SupportedModality::Text],
                    output: vec![SupportedModality::Vision],
                },
            }),
            dummy_cache: EitherCache::Full(Cache::new(0, false)),
        })))
    }

    fn get_id(&self) -> String {
        self.model_id.to_string()
    }

    fn get_kind(&self) -> ModelKind {
        self.kind.clone()
    }
}
264
impl PreProcessingMixin for DiffusionPipeline {
    fn get_processor(&self) -> Arc<dyn Processor> {
        // Stateless processor that turns requests into diffusion `ModelInputs`.
        Arc::new(DiffusionProcessor)
    }
    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
        // Diffusion models consume raw prompts; there is no chat template.
        None
    }
    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
        None
    }
}
276
impl IsqPipelineMixin for DiffusionPipeline {
    // In-situ quantization is rejected at load time too (see `load_model_from_path`).
    fn re_isq_model(&mut self, _dtype: IsqType) -> Result<()> {
        anyhow::bail!("Diffusion models do not support ISQ for now.")
    }
}
282
// Diffusion pipelines keep no KV cache, so every cache operation is a no-op
// and `cache` returns the placeholder built at load time.
impl CacheManagerMixin for DiffusionPipeline {
    fn clone_in_cache(&self, _seqs: &mut [&mut Sequence]) {}
    fn clone_out_cache(&self, _seqs: &mut [&mut Sequence]) {}
    fn set_none_cache(
        &self,
        _seqs: &mut [&mut Sequence],
        _reset_non_granular: bool,
        _modify_draft_cache: bool,
        _load_preallocated_cache: bool,
    ) {
    }
    fn cache(&self) -> &EitherCache {
        &self.dummy_cache
    }
}
298
impl MetadataMixin for DiffusionPipeline {
    fn device(&self) -> Device {
        self.model.device().clone()
    }
    fn get_metadata(&self) -> Arc<GeneralMetadata> {
        self.metadata.clone()
    }
    fn name(&self) -> String {
        self.model_id.clone()
    }
    fn reset_non_granular_state(&self) {}
    fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
        // No tokenizer is exposed; prompt encoding happens inside the model.
        None
    }
    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
        // Device mapping is unsupported for diffusion models.
        None
    }
}
317
318#[async_trait::async_trait]
319impl Pipeline for DiffusionPipeline {
320    fn forward_inputs(
321        &mut self,
322        inputs: Box<dyn Any>,
323        return_raw_logits: bool,
324    ) -> candle_core::Result<ForwardInputsResult> {
325        assert!(!return_raw_logits);
326
327        let ModelInputs { prompts, params } = *inputs.downcast().expect("Downcast failed.");
328        let img = self.model.forward(prompts, params)?.to_dtype(DType::U8)?;
329        let (_b, c, h, w) = img.dims4()?;
330        let mut images = Vec::new();
331        for b_img in img.chunk(img.dim(0)?, 0)? {
332            let flattened = b_img.squeeze(0)?.permute((1, 2, 0))?.flatten_all()?;
333            if c != 3 {
334                candle_core::bail!("Expected 3 channels in image output");
335            }
336            #[allow(clippy::cast_possible_truncation)]
337            images.push(DynamicImage::ImageRgb8(
338                RgbImage::from_raw(w as u32, h as u32, flattened.to_vec1::<u8>()?).ok_or(
339                    candle_core::Error::Msg("RgbImage has invalid capacity.".to_string()),
340                )?,
341            ));
342        }
343        Ok(ForwardInputsResult::Image { images })
344    }
345    async fn sample_causal_gen(
346        &self,
347        _seqs: &mut [&mut Sequence],
348        _logits: Vec<Tensor>,
349        _prefix_cacher: &mut PrefixCacheManagerV2,
350        _disable_eos_stop: bool,
351        _srng: Arc<std::sync::Mutex<Isaac64Rng>>,
352    ) -> Result<(), candle_core::Error> {
353        candle_core::bail!("`sample_causal_gen` is incompatible with `DiffusionPipeline`");
354    }
355    fn category(&self) -> ModelCategory {
356        ModelCategory::Diffusion
357    }
358}
359
// AnyMoE does not apply to diffusion pipelines; the trait's default no-op impls suffice.
impl AnyMoePipelineMixin for DiffusionPipeline {}