mistralrs_core/pipeline/
diffusion.rs

1use super::loaders::{DiffusionModelPaths, DiffusionModelPathsInner};
2use super::{
3    AnyMoePipelineMixin, Cache, CacheManagerMixin, DiffusionLoaderType, DiffusionModel,
4    DiffusionModelLoader, EitherCache, FluxLoader, ForwardInputsResult, GeneralMetadata,
5    IsqPipelineMixin, Loader, MetadataMixin, ModelCategory, ModelKind, ModelPaths,
6    PreProcessingMixin, Processor, TokenSource,
7};
8use crate::device_map::DeviceMapper;
9use crate::diffusion_models::processor::{DiffusionProcessor, ModelInputs};
10use crate::paged_attention::AttentionImplementation;
11use crate::pipeline::{ChatTemplate, Modalities, SupportedModality};
12use crate::prefix_cacher::PrefixCacheManagerV2;
13use crate::sequence::Sequence;
14use crate::utils::varbuilder_utils::DeviceForLoadTensor;
15use crate::utils::{tokens::get_token, varbuilder_utils::from_mmaped_safetensors};
16use crate::{DeviceMapSetting, PagedAttentionConfig, Pipeline, TryIntoDType};
17use anyhow::Result;
18use candle_core::{DType, Device, Tensor};
19use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
20use image::{DynamicImage, RgbImage};
21use indicatif::MultiProgress;
22use mistralrs_quant::log::once_log_info;
23use mistralrs_quant::IsqType;
24use rand_isaac::Isaac64Rng;
25use std::any::Any;
26use std::io;
27use std::sync::Arc;
28use tokenizers::Tokenizer;
29use tokio::sync::Mutex;
30use tracing::warn;
31
/// Pipeline wrapper around a loaded diffusion model.
///
/// Instances are produced by `DiffusionLoader::load_model_from_path`.
pub struct DiffusionPipeline {
    /// The loaded diffusion model; `forward_inputs` calls its `forward` method.
    model: Box<dyn DiffusionModel + Send + Sync>,
    /// Model identifier, returned verbatim by `MetadataMixin::name`.
    model_id: String,
    /// Shared pipeline metadata, handed out by `MetadataMixin::get_metadata`.
    metadata: Arc<GeneralMetadata>,
    /// Placeholder cache returned by `CacheManagerMixin::cache`; the loader builds the
    /// metadata with `no_kv_cache: true`, and all cache-manipulation hooks are no-ops.
    dummy_cache: EitherCache,
}
38
/// A loader for a diffusion model.
///
/// Device mapping and ISQ are rejected at load time, and PagedAttention is disabled
/// with a warning.
pub struct DiffusionLoader {
    /// Architecture-specific loader used to resolve file paths and build the model.
    inner: Box<dyn DiffusionModelLoader>,
    /// Model identifier; used as the Hugging Face repo ID in `load_model_from_hf`.
    model_id: String,
    /// Model kind; `load_model_from_path` only handles `ModelKind::Normal`.
    kind: ModelKind,
}
45
#[derive(Default)]
/// A builder for a [`DiffusionLoader`].
pub struct DiffusionLoaderBuilder {
    /// Model identifier; `build` unwraps it, so it must be `Some` by then.
    model_id: Option<String>,
    /// Model kind; `new` always sets this to `ModelKind::Normal`.
    kind: ModelKind,
}
52
53impl DiffusionLoaderBuilder {
54    pub fn new(model_id: Option<String>) -> Self {
55        Self {
56            model_id,
57            kind: ModelKind::Normal,
58        }
59    }
60
61    pub fn build(self, loader: DiffusionLoaderType) -> Box<dyn Loader> {
62        let loader: Box<dyn DiffusionModelLoader> = match loader {
63            DiffusionLoaderType::Flux => Box::new(FluxLoader { offload: false }),
64            DiffusionLoaderType::FluxOffloaded => Box::new(FluxLoader { offload: true }),
65        };
66        Box::new(DiffusionLoader {
67            inner: loader,
68            model_id: self.model_id.unwrap(),
69            kind: self.kind,
70        })
71    }
72}
73
74impl Loader for DiffusionLoader {
75    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
76    fn load_model_from_hf(
77        &self,
78        revision: Option<String>,
79        token_source: TokenSource,
80        dtype: &dyn TryIntoDType,
81        device: &Device,
82        silent: bool,
83        mapper: DeviceMapSetting,
84        in_situ_quant: Option<IsqType>,
85        paged_attn_config: Option<PagedAttentionConfig>,
86    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
87        let paths: anyhow::Result<Box<dyn ModelPaths>> = {
88            let api = ApiBuilder::new()
89                .with_progress(!silent)
90                .with_token(get_token(&token_source)?)
91                .build()?;
92            let revision = revision.unwrap_or("main".to_string());
93            let api = api.repo(Repo::with_revision(
94                self.model_id.clone(),
95                RepoType::Model,
96                revision.clone(),
97            ));
98            let model_id = std::path::Path::new(&self.model_id);
99            let filenames = self.inner.get_model_paths(&api, model_id)?;
100            let config_filenames = self.inner.get_config_filenames(&api, model_id)?;
101            Ok(Box::new(DiffusionModelPaths(DiffusionModelPathsInner {
102                config_filenames,
103                filenames,
104            })))
105        };
106        self.load_model_from_path(
107            &paths?,
108            dtype,
109            device,
110            silent,
111            mapper,
112            in_situ_quant,
113            paged_attn_config,
114        )
115    }
116
117    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
118    fn load_model_from_path(
119        &self,
120        paths: &Box<dyn ModelPaths>,
121        dtype: &dyn TryIntoDType,
122        device: &Device,
123        silent: bool,
124        mapper: DeviceMapSetting,
125        in_situ_quant: Option<IsqType>,
126        mut paged_attn_config: Option<PagedAttentionConfig>,
127    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
128        let paths = &paths
129            .as_ref()
130            .as_any()
131            .downcast_ref::<DiffusionModelPaths>()
132            .expect("Path downcast failed.")
133            .0;
134
135        if matches!(mapper, DeviceMapSetting::Map(_)) {
136            anyhow::bail!("Device mapping is not supported for diffusion models.")
137        }
138
139        if in_situ_quant.is_some() {
140            anyhow::bail!("ISQ is not supported for Diffusion models.");
141        }
142
143        if paged_attn_config.is_some() {
144            warn!("PagedAttention is not supported for Diffusion models, disabling it.");
145
146            paged_attn_config = None;
147        }
148
149        if crate::using_flash_attn() {
150            once_log_info("FlashAttention is enabled.");
151        }
152
153        let configs = paths
154            .config_filenames
155            .iter()
156            .map(std::fs::read_to_string)
157            .collect::<io::Result<Vec<_>>>()?;
158
159        #[cfg(feature = "cuda")]
160        if let Device::Cuda(dev) = &device {
161            unsafe { dev.disable_event_tracking() };
162        }
163
164        let mapper = DeviceMapSetting::dummy().into_mapper(usize::MAX, device, None)?;
165        let dtype = mapper.get_min_dtype(dtype)?;
166
167        let attention_mechanism = if paged_attn_config.is_some() {
168            AttentionImplementation::PagedAttention
169        } else {
170            AttentionImplementation::Eager
171        };
172
173        let model = match self.kind {
174            ModelKind::Normal => {
175                let vbs = paths
176                    .filenames
177                    .iter()
178                    .zip(self.inner.force_cpu_vb())
179                    .map(|(path, force_cpu)| {
180                        let dev = if force_cpu { &Device::Cpu } else { device };
181                        from_mmaped_safetensors(
182                            vec![path.clone()],
183                            Vec::new(),
184                            Some(dtype),
185                            dev,
186                            vec![None],
187                            silent,
188                            None,
189                            |_| true,
190                            Arc::new(|_| DeviceForLoadTensor::Base),
191                        )
192                    })
193                    .collect::<candle_core::Result<Vec<_>>>()?;
194
195                self.inner.load(
196                    configs,
197                    vbs,
198                    crate::pipeline::NormalLoadingMetadata {
199                        mapper,
200                        loading_isq: false,
201                        real_device: device.clone(),
202                        multi_progress: Arc::new(MultiProgress::new()),
203                        matformer_slicing_config: None,
204                    },
205                    attention_mechanism,
206                    silent,
207                )?
208            }
209            _ => unreachable!(),
210        };
211
212        let max_seq_len = model.max_seq_len();
213        Ok(Arc::new(Mutex::new(DiffusionPipeline {
214            model,
215            model_id: self.model_id.clone(),
216            metadata: Arc::new(GeneralMetadata {
217                max_seq_len,
218                llg_factory: None,
219                is_xlora: false,
220                no_prefix_cache: false,
221                num_hidden_layers: 1, // FIXME(EricLBuehler): we know this is only for caching, so its OK.
222                eos_tok: vec![],
223                kind: self.kind.clone(),
224                no_kv_cache: true, // NOTE(EricLBuehler): no cache for these.
225                activation_dtype: dtype,
226                sliding_window: None,
227                cache_config: None,
228                cache_engine: None,
229                model_metadata: None,
230                modalities: Modalities {
231                    input: vec![SupportedModality::Text],
232                    output: vec![SupportedModality::Vision],
233                },
234            }),
235            dummy_cache: EitherCache::Full(Cache::new(0, false)),
236        })))
237    }
238
239    fn get_id(&self) -> String {
240        self.model_id.to_string()
241    }
242
243    fn get_kind(&self) -> ModelKind {
244        self.kind.clone()
245    }
246}
247
248impl PreProcessingMixin for DiffusionPipeline {
249    fn get_processor(&self) -> Arc<dyn Processor> {
250        Arc::new(DiffusionProcessor)
251    }
252    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
253        None
254    }
255    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
256        None
257    }
258}
259
260impl IsqPipelineMixin for DiffusionPipeline {
261    fn re_isq_model(&mut self, _dtype: IsqType) -> Result<()> {
262        anyhow::bail!("Diffusion models do not support ISQ for now.")
263    }
264}
265
266impl CacheManagerMixin for DiffusionPipeline {
267    fn clone_in_cache(&self, _seqs: &mut [&mut Sequence]) {}
268    fn clone_out_cache(&self, _seqs: &mut [&mut Sequence]) {}
269    fn set_none_cache(
270        &self,
271        _seqs: &mut [&mut Sequence],
272        _reset_non_granular: bool,
273        _modify_draft_cache: bool,
274        _load_preallocated_cache: bool,
275    ) {
276    }
277    fn cache(&self) -> &EitherCache {
278        &self.dummy_cache
279    }
280}
281
282impl MetadataMixin for DiffusionPipeline {
283    fn device(&self) -> Device {
284        self.model.device().clone()
285    }
286    fn get_metadata(&self) -> Arc<GeneralMetadata> {
287        self.metadata.clone()
288    }
289    fn name(&self) -> String {
290        self.model_id.clone()
291    }
292    fn reset_non_granular_state(&self) {}
293    fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
294        None
295    }
296    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
297        None
298    }
299}
300
301#[async_trait::async_trait]
302impl Pipeline for DiffusionPipeline {
303    fn forward_inputs(
304        &mut self,
305        inputs: Box<dyn Any>,
306        return_raw_logits: bool,
307    ) -> candle_core::Result<ForwardInputsResult> {
308        assert!(!return_raw_logits);
309
310        let ModelInputs { prompts, params } = *inputs.downcast().expect("Downcast failed.");
311        let img = self.model.forward(prompts, params)?.to_dtype(DType::U8)?;
312        let (_b, c, h, w) = img.dims4()?;
313        let mut images = Vec::new();
314        for b_img in img.chunk(img.dim(0)?, 0)? {
315            let flattened = b_img.squeeze(0)?.permute((1, 2, 0))?.flatten_all()?;
316            if c != 3 {
317                candle_core::bail!("Expected 3 channels in image output");
318            }
319            #[allow(clippy::cast_possible_truncation)]
320            images.push(DynamicImage::ImageRgb8(
321                RgbImage::from_raw(w as u32, h as u32, flattened.to_vec1::<u8>()?).ok_or(
322                    candle_core::Error::Msg("RgbImage has invalid capacity.".to_string()),
323                )?,
324            ));
325        }
326        Ok(ForwardInputsResult::Image { images })
327    }
328    async fn sample_causal_gen(
329        &self,
330        _seqs: &mut [&mut Sequence],
331        _logits: Vec<Tensor>,
332        _prefix_cacher: &mut PrefixCacheManagerV2,
333        _disable_eos_stop: bool,
334        _srng: Arc<std::sync::Mutex<Isaac64Rng>>,
335    ) -> Result<(), candle_core::Error> {
336        candle_core::bail!("`sample_causal_gen` is incompatible with `DiffusionPipeline`");
337    }
338    fn category(&self) -> ModelCategory {
339        ModelCategory::Diffusion
340    }
341}
342
// Empty impl: `DiffusionPipeline` overrides nothing and relies entirely on whatever
// default methods `AnyMoePipelineMixin` provides.
impl AnyMoePipelineMixin for DiffusionPipeline {}