mistralrs_core/pipeline/diffusion.rs

use super::loaders::{DiffusionModelPaths, DiffusionModelPathsInner};
use super::{
    AnyMoePipelineMixin, Cache, CacheManagerMixin, DiffusionLoaderType, DiffusionModel,
    DiffusionModelLoader, EitherCache, FluxLoader, ForwardInputsResult, GeneralMetadata,
    IsqPipelineMixin, Loader, MetadataMixin, ModelCategory, ModelKind, ModelPaths,
    PreProcessingMixin, Processor, TokenSource,
};
use crate::device_map::DeviceMapper;
use crate::diffusion_models::processor::{DiffusionProcessor, ModelInputs};
use crate::paged_attention::AttentionImplementation;
use crate::pipeline::{ChatTemplate, Modalities, SupportedModality};
use crate::prefix_cacher::PrefixCacheManagerV2;
use crate::sequence::Sequence;
use crate::utils::varbuilder_utils::DeviceForLoadTensor;
use crate::utils::{
    progress::{new_multi_progress, ProgressScopeGuard},
    tokens::get_token,
    varbuilder_utils::from_mmaped_safetensors,
};
use crate::{DeviceMapSetting, PagedAttentionConfig, Pipeline, TryIntoDType};
use anyhow::Result;
use candle_core::{DType, Device, Tensor};
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
use image::{DynamicImage, RgbImage};
use mistralrs_quant::log::once_log_info;
use mistralrs_quant::IsqType;
use rand_isaac::Isaac64Rng;
use std::any::Any;
use std::io;
use std::sync::Arc;
use tokenizers::Tokenizer;
use tokio::sync::Mutex;
use tracing::warn;

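/// A loaded diffusion image-generation pipeline. It owns the underlying
/// [`DiffusionModel`] and its [`GeneralMetadata`]; `dummy_cache` exists only
/// because the [`Pipeline`] trait requires a cache, which diffusion models do
/// not use.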
pub struct DiffusionPipeline {
    model: Box<dyn DiffusionModel + Send + Sync>,
    model_id: String,
    metadata: Arc<GeneralMetadata>,
    dummy_cache: EitherCache,
}

/// A loader for a diffusion (non-quantized) model.
pub struct DiffusionLoader {
    inner: Box<dyn DiffusionModelLoader>,
    model_id: String,
    kind: ModelKind,
}

/// A builder for a diffusion (non-quantized) model loader.
#[derive(Default)]
pub struct DiffusionLoaderBuilder {
    model_id: Option<String>,
    kind: ModelKind,
}

impl DiffusionLoaderBuilder {
    pub fn new(model_id: Option<String>) -> Self {
        Self {
            model_id,
            kind: ModelKind::Normal,
        }
    }

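    /// Construct the [`Loader`] for the given diffusion architecture. A
    /// minimal usage sketch (the model id below is an illustrative
    /// placeholder; marked `ignore` since it requires model files):
    ///
    /// ```ignore
    /// let loader = DiffusionLoaderBuilder::new(Some("black-forest-labs/FLUX.1-schnell".to_string()))
    ///     .build(DiffusionLoaderType::Flux);
    /// ```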
    pub fn build(self, loader: DiffusionLoaderType) -> Box<dyn Loader> {
        let loader: Box<dyn DiffusionModelLoader> = match loader {
            DiffusionLoaderType::Flux => Box::new(FluxLoader { offload: false }),
            DiffusionLoaderType::FluxOffloaded => Box::new(FluxLoader { offload: true }),
        };
        Box::new(DiffusionLoader {
            inner: loader,
            model_id: self.model_id.expect("A model ID must be set."),
            kind: self.kind,
        })
    }
}

impl Loader for DiffusionLoader {
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let _progress_guard = ProgressScopeGuard::new(silent);
        let paths: anyhow::Result<Box<dyn ModelPaths>> = {
            let api = ApiBuilder::new()
                .with_progress(!silent)
                .with_token(get_token(&token_source)?)
                .build()?;
            let revision = revision.unwrap_or("main".to_string());
            let api = api.repo(Repo::with_revision(
                self.model_id.clone(),
                RepoType::Model,
                revision.clone(),
            ));
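            // The inner loader resolves the model and config files, fetching
            // them from the Hub (or reading them locally, if `model_id` is a
            // filesystem path) as needed.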
            let model_id = std::path::Path::new(&self.model_id);
            let filenames = self.inner.get_model_paths(&api, model_id)?;
            let config_filenames = self.inner.get_config_filenames(&api, model_id)?;
            Ok(Box::new(DiffusionModelPaths(DiffusionModelPathsInner {
                config_filenames,
                filenames,
            })))
        };
        self.load_model_from_path(
            &paths?,
            dtype,
            device,
            silent,
            mapper,
            in_situ_quant,
            paged_attn_config,
        )
    }

    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        mut paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let _progress_guard = ProgressScopeGuard::new(silent);
        let paths = &paths
            .as_ref()
            .as_any()
            .downcast_ref::<DiffusionModelPaths>()
            .expect("Path downcast failed.")
            .0;

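        // Diffusion pipelines support neither device mapping nor ISQ;
        // PagedAttention is silently disabled rather than rejected.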
        if matches!(mapper, DeviceMapSetting::Map(_)) {
            anyhow::bail!("Device mapping is not supported for diffusion models.")
        }

        if in_situ_quant.is_some() {
            anyhow::bail!("ISQ is not supported for diffusion models.");
        }

        if paged_attn_config.is_some() {
            warn!("PagedAttention is not supported for diffusion models, disabling it.");

            paged_attn_config = None;
        }

        if crate::using_flash_attn() {
            once_log_info("FlashAttention is enabled.");
        }

        let configs = paths
            .config_filenames
            .iter()
            .map(std::fs::read_to_string)
            .collect::<io::Result<Vec<_>>>()?;

        #[cfg(feature = "cuda")]
        if let Device::Cuda(dev) = &device {
            unsafe { dev.disable_event_tracking() };
        }

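        // No real device mapping happens here; a dummy mapper is built only to
        // resolve the minimum dtype the target device supports.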
        let mapper = DeviceMapSetting::dummy().into_mapper(usize::MAX, device, None)?;
        let dtype = mapper.get_min_dtype(dtype)?;

        let attention_mechanism = if paged_attn_config.is_some() {
            AttentionImplementation::PagedAttention
        } else {
            AttentionImplementation::Eager
        };

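        // Load each safetensors file into its own VarBuilder. Files the inner
        // loader flags via `force_cpu_vb` (e.g. when offloading is enabled)
        // are loaded onto the CPU instead of the target device.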
        let model = match self.kind {
            ModelKind::Normal => {
                let vbs = paths
                    .filenames
                    .iter()
                    .zip(self.inner.force_cpu_vb())
                    .map(|(path, force_cpu)| {
                        let dev = if force_cpu { &Device::Cpu } else { device };
                        from_mmaped_safetensors(
                            vec![path.clone()],
                            Vec::new(),
                            Some(dtype),
                            dev,
                            vec![None],
                            silent,
                            None,
                            |_| true,
                            Arc::new(|_| DeviceForLoadTensor::Base),
                        )
                    })
                    .collect::<candle_core::Result<Vec<_>>>()?;

                self.inner.load(
                    configs,
                    vbs,
                    crate::pipeline::NormalLoadingMetadata {
                        mapper,
                        loading_isq: false,
                        real_device: device.clone(),
                        multi_progress: Arc::new(new_multi_progress()),
                        matformer_slicing_config: None,
                    },
                    attention_mechanism,
                    silent,
                )?
            }
            _ => unreachable!(),
        };

216
217        let max_seq_len = model.max_seq_len();
218        Ok(Arc::new(Mutex::new(DiffusionPipeline {
219            model,
220            model_id: self.model_id.clone(),
221            metadata: Arc::new(GeneralMetadata {
222                max_seq_len,
223                llg_factory: None,
224                is_xlora: false,
225                no_prefix_cache: false,
226                num_hidden_layers: 1, // FIXME(EricLBuehler): we know this is only for caching, so its OK.
227                eos_tok: vec![],
228                kind: self.kind.clone(),
229                no_kv_cache: true, // NOTE(EricLBuehler): no cache for these.
230                activation_dtype: dtype,
231                sliding_window: None,
232                cache_config: None,
233                cache_engine: None,
234                model_metadata: None,
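                // Diffusion models take text prompts in and produce images.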
                modalities: Modalities {
                    input: vec![SupportedModality::Text],
                    output: vec![SupportedModality::Vision],
                },
            }),
            dummy_cache: EitherCache::Full(Cache::new(0, false)),
        })))
    }

    fn get_id(&self) -> String {
        self.model_id.to_string()
    }

    fn get_kind(&self) -> ModelKind {
        self.kind.clone()
    }
}

impl PreProcessingMixin for DiffusionPipeline {
    fn get_processor(&self) -> Arc<dyn Processor> {
        Arc::new(DiffusionProcessor)
    }
    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
        None
    }
    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
        None
    }
}

impl IsqPipelineMixin for DiffusionPipeline {
    fn re_isq_model(&mut self, _dtype: IsqType) -> Result<()> {
        anyhow::bail!("Diffusion models do not support ISQ for now.")
    }
}

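// Diffusion models keep no KV cache, so the cache-manager hooks are no-ops and
// `cache` returns the dummy created at load time.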
impl CacheManagerMixin for DiffusionPipeline {
    fn clone_in_cache(&self, _seqs: &mut [&mut Sequence]) {}
    fn clone_out_cache(&self, _seqs: &mut [&mut Sequence]) {}
    fn set_none_cache(
        &self,
        _seqs: &mut [&mut Sequence],
        _reset_non_granular: bool,
        _modify_draft_cache: bool,
        _load_preallocated_cache: bool,
    ) {
    }
    fn cache(&self) -> &EitherCache {
        &self.dummy_cache
    }
}

impl MetadataMixin for DiffusionPipeline {
    fn device(&self) -> Device {
        self.model.device().clone()
    }
    fn get_metadata(&self) -> Arc<GeneralMetadata> {
        self.metadata.clone()
    }
    fn name(&self) -> String {
        self.model_id.clone()
    }
    fn reset_non_granular_state(&self) {}
    fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
        None
    }
    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
        None
    }
}

#[async_trait::async_trait]
impl Pipeline for DiffusionPipeline {
    fn forward_inputs(
        &mut self,
        inputs: Box<dyn Any>,
        return_raw_logits: bool,
    ) -> candle_core::Result<ForwardInputsResult> {
        assert!(!return_raw_logits);

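        // The model emits a `(batch, channel, height, width)` u8 tensor; each
        // batch element is permuted to HWC order, flattened, and wrapped in an
        // `RgbImage`.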
        let ModelInputs { prompts, params } = *inputs.downcast().expect("Downcast failed.");
        let img = self.model.forward(prompts, params)?.to_dtype(DType::U8)?;
        let (_b, c, h, w) = img.dims4()?;
        if c != 3 {
            candle_core::bail!("Expected 3 channels in image output");
        }
        let mut images = Vec::new();
        for b_img in img.chunk(img.dim(0)?, 0)? {
            let flattened = b_img.squeeze(0)?.permute((1, 2, 0))?.flatten_all()?;
            #[allow(clippy::cast_possible_truncation)]
            images.push(DynamicImage::ImageRgb8(
                RgbImage::from_raw(w as u32, h as u32, flattened.to_vec1::<u8>()?).ok_or(
                    candle_core::Error::Msg("RgbImage has invalid capacity.".to_string()),
                )?,
            ));
        }
        Ok(ForwardInputsResult::Image { images })
    }
    async fn sample_causal_gen(
        &self,
        _seqs: &mut [&mut Sequence],
        _logits: Vec<Tensor>,
        _prefix_cacher: &mut PrefixCacheManagerV2,
        _disable_eos_stop: bool,
        _srng: Arc<std::sync::Mutex<Isaac64Rng>>,
    ) -> Result<(), candle_core::Error> {
        candle_core::bail!("`sample_causal_gen` is incompatible with `DiffusionPipeline`");
    }
    fn category(&self) -> ModelCategory {
        ModelCategory::Diffusion
    }
}

impl AnyMoePipelineMixin for DiffusionPipeline {}