// mistralrs_core/pipeline/diffusion.rs

use super::loaders::{DiffusionModelPaths, DiffusionModelPathsInner};
use super::{
    AnyMoePipelineMixin, Cache, CacheManagerMixin, DiffusionLoaderType, DiffusionModel,
    DiffusionModelLoader, EitherCache, FluxLoader, ForwardInputsResult, GeneralMetadata,
    IsqPipelineMixin, Loader, MetadataMixin, ModelCategory, ModelKind, ModelPaths,
    PreProcessingMixin, Processor, TokenSource,
};
use crate::device_map::DeviceMapper;
use crate::diffusion_models::processor::{DiffusionProcessor, ModelInputs};
use crate::paged_attention::AttentionImplementation;
use crate::pipeline::ChatTemplate;
use crate::prefix_cacher::PrefixCacheManagerV2;
use crate::sequence::Sequence;
use crate::utils::varbuilder_utils::DeviceForLoadTensor;
use crate::utils::{tokens::get_token, varbuilder_utils::from_mmaped_safetensors};
use crate::{DeviceMapSetting, PagedAttentionConfig, Pipeline, TryIntoDType};
use anyhow::Result;
use candle_core::{DType, Device, Tensor};
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
use image::{DynamicImage, RgbImage};
use indicatif::MultiProgress;
use mistralrs_quant::IsqType;
use rand_isaac::Isaac64Rng;
use std::any::Any;
use std::io;
use std::sync::Arc;
use tokenizers::Tokenizer;
use tokio::sync::Mutex;
use tracing::warn;
30
/// A loaded diffusion (image-generation) pipeline.
pub struct DiffusionPipeline {
    // The underlying diffusion model; boxed so different architectures (e.g. FLUX) share one pipeline type.
    model: Box<dyn DiffusionModel + Send + Sync>,
    // Hub model ID (or local path) this pipeline was loaded from.
    model_id: String,
    // Metadata shared with the engine; diffusion-specific fields are set in `load_model_from_path`.
    metadata: Arc<GeneralMetadata>,
    // Diffusion models keep no KV cache; this placeholder satisfies the `CacheManagerMixin` API.
    dummy_cache: EitherCache,
}
37
/// A loader for a diffusion (non-quantized) model.
pub struct DiffusionLoader {
    // Architecture-specific loader (e.g. FLUX) that knows which files to fetch and how to build the model.
    inner: Box<dyn DiffusionModelLoader>,
    model_id: String,
    config: DiffusionSpecificConfig,
    kind: ModelKind,
}
45
#[derive(Default)]
/// A builder for a loader for a diffusion (non-quantized) model.
pub struct DiffusionLoaderBuilder {
    // Optional at build-time, but `build` panics if it is still `None`.
    model_id: Option<String>,
    config: DiffusionSpecificConfig,
    kind: ModelKind,
}
53
#[derive(Clone, Default)]
/// Config specific to loading a diffusion model.
pub struct DiffusionSpecificConfig {
    /// Whether to use flash attention kernels where supported.
    pub use_flash_attn: bool,
}
59
60impl DiffusionLoaderBuilder {
61    pub fn new(config: DiffusionSpecificConfig, model_id: Option<String>) -> Self {
62        Self {
63            config,
64            model_id,
65            kind: ModelKind::Normal,
66        }
67    }
68
69    pub fn build(self, loader: DiffusionLoaderType) -> Box<dyn Loader> {
70        let loader: Box<dyn DiffusionModelLoader> = match loader {
71            DiffusionLoaderType::Flux => Box::new(FluxLoader { offload: false }),
72            DiffusionLoaderType::FluxOffloaded => Box::new(FluxLoader { offload: true }),
73        };
74        Box::new(DiffusionLoader {
75            inner: loader,
76            model_id: self.model_id.unwrap(),
77            config: self.config,
78            kind: self.kind,
79        })
80    }
81}
82
impl Loader for DiffusionLoader {
    /// Resolve weight/config files from the Hugging Face Hub (for the given revision),
    /// then delegate to [`Self::load_model_from_path`].
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let paths: anyhow::Result<Box<dyn ModelPaths>> = {
            // Authenticated Hub API client; progress bars are suppressed in silent mode.
            let api = ApiBuilder::new()
                .with_progress(!silent)
                .with_token(get_token(&token_source)?)
                .build()?;
            let revision = revision.unwrap_or("main".to_string());
            let api = api.repo(Repo::with_revision(
                self.model_id.clone(),
                RepoType::Model,
                revision.clone(),
            ));
            let model_id = std::path::Path::new(&self.model_id);
            // The architecture-specific inner loader decides which weight and
            // config files are required for this model.
            let filenames = self.inner.get_model_paths(&api, model_id)?;
            let config_filenames = self.inner.get_config_filenames(&api, model_id)?;
            Ok(Box::new(DiffusionModelPaths(DiffusionModelPathsInner {
                config_filenames,
                filenames,
            })))
        };
        self.load_model_from_path(
            &paths?,
            dtype,
            device,
            silent,
            mapper,
            in_situ_quant,
            paged_attn_config,
        )
    }

    /// Load the diffusion pipeline from already-resolved file paths.
    ///
    /// Device mapping and ISQ are rejected; PagedAttention is silently disabled
    /// with a warning. Returns the pipeline wrapped for shared async use.
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        mut paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        // Recover the concrete path type; this loader only ever produces DiffusionModelPaths.
        let paths = &paths
            .as_ref()
            .as_any()
            .downcast_ref::<DiffusionModelPaths>()
            .expect("Path downcast failed.")
            .0;

        if matches!(mapper, DeviceMapSetting::Map(_)) {
            anyhow::bail!("Device mapping is not supported for diffusion models.")
        }

        if in_situ_quant.is_some() {
            anyhow::bail!("ISQ is not supported for Diffusion models.");
        }

        if paged_attn_config.is_some() {
            warn!("PagedAttention is not supported for Diffusion models, disabling it.");

            paged_attn_config = None;
        }

        // Read every config file up-front; any I/O error aborts the load.
        let configs = paths
            .config_filenames
            .iter()
            .map(std::fs::read_to_string)
            .collect::<io::Result<Vec<_>>>()?;

        // Dummy mapper: everything stays on `device` (real device mapping was rejected above).
        let mapper = DeviceMapSetting::dummy().into_mapper(usize::MAX, device, None)?;
        let dtype = mapper.get_min_dtype(dtype)?;

        // NOTE: paged_attn_config was forced to None above, so this is always Eager today;
        // kept in this form in case PagedAttention support is added later.
        let attention_mechanism = if paged_attn_config.is_some() {
            AttentionImplementation::PagedAttention
        } else {
            AttentionImplementation::Eager
        };

        let model = match self.kind {
            ModelKind::Normal => {
                // One VarBuilder per weight file; the inner loader may force some
                // files onto CPU (e.g. for offloaded FLUX variants).
                let vbs = paths
                    .filenames
                    .iter()
                    .zip(self.inner.force_cpu_vb())
                    .map(|(path, force_cpu)| {
                        let dev = if force_cpu { &Device::Cpu } else { device };
                        from_mmaped_safetensors(
                            vec![path.clone()],
                            Vec::new(),
                            Some(dtype),
                            dev,
                            vec![None],
                            silent,
                            None,
                            |_| true,
                            Arc::new(|_| DeviceForLoadTensor::Base),
                        )
                    })
                    .collect::<candle_core::Result<Vec<_>>>()?;

                self.inner.load(
                    configs,
                    self.config.use_flash_attn,
                    vbs,
                    crate::pipeline::NormalLoadingMetadata {
                        mapper,
                        loading_isq: false,
                        real_device: device.clone(),
                        multi_progress: Arc::new(MultiProgress::new()),
                    },
                    attention_mechanism,
                    silent,
                )?
            }
            // Builder always sets ModelKind::Normal for diffusion models.
            _ => unreachable!(),
        };

        let max_seq_len = model.max_seq_len();
        Ok(Arc::new(Mutex::new(DiffusionPipeline {
            model,
            model_id: self.model_id.clone(),
            metadata: Arc::new(GeneralMetadata {
                max_seq_len,
                tok_env: None,
                is_xlora: false,
                no_prefix_cache: false,
                num_hidden_layers: 1, // FIXME(EricLBuehler): we know this is only for caching, so its OK.
                eos_tok: vec![],
                kind: self.kind.clone(),
                no_kv_cache: true, // NOTE(EricLBuehler): no cache for these.
                activation_dtype: dtype,
                sliding_window: None,
                cache_config: None,
                cache_engine: None,
                prompt_chunksize: None,
                model_metadata: None,
            }),
            dummy_cache: EitherCache::Full(Cache::new(0, false)),
        })))
    }

    fn get_id(&self) -> String {
        self.model_id.to_string()
    }

    fn get_kind(&self) -> ModelKind {
        self.kind.clone()
    }
}
244
impl PreProcessingMixin for DiffusionPipeline {
    fn get_processor(&self) -> Arc<dyn Processor> {
        Arc::new(DiffusionProcessor)
    }
    // Diffusion models take raw text prompts; there is no chat template.
    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
        None
    }
    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
        None
    }
}
256
impl IsqPipelineMixin for DiffusionPipeline {
    // ISQ re-quantization is rejected, mirroring the ISQ check in `load_model_from_path`.
    fn re_isq_model(&mut self, _dtype: IsqType) -> Result<()> {
        anyhow::bail!("Diffusion models do not support ISQ for now.")
    }
}
262
impl CacheManagerMixin for DiffusionPipeline {
    // Diffusion models maintain no KV cache, so all cache operations are no-ops.
    fn clone_in_cache(&self, _seqs: &mut [&mut Sequence]) {}
    fn clone_out_cache(&self, _seqs: &mut [&mut Sequence]) {}
    fn set_none_cache(
        &self,
        _seqs: &mut [&mut Sequence],
        _reset_non_granular: bool,
        _modify_draft_cache: bool,
        _load_preallocated_cache: bool,
    ) {
    }
    // Returns the zero-layer placeholder cache created at load time.
    fn cache(&self) -> &EitherCache {
        &self.dummy_cache
    }
}
278
impl MetadataMixin for DiffusionPipeline {
    fn device(&self) -> Device {
        self.model.device().clone()
    }
    fn get_metadata(&self) -> Arc<GeneralMetadata> {
        self.metadata.clone()
    }
    fn name(&self) -> String {
        self.model_id.clone()
    }
    fn reset_non_granular_state(&self) {}
    // Diffusion pipelines use DiffusionProcessor for input handling rather than a tokenizer.
    fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
        None
    }
    // Device mapping is unsupported for diffusion models (rejected at load time).
    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
        None
    }
}
297
298#[async_trait::async_trait]
299impl Pipeline for DiffusionPipeline {
300    fn forward_inputs(
301        &mut self,
302        inputs: Box<dyn Any>,
303        return_raw_logits: bool,
304    ) -> candle_core::Result<ForwardInputsResult> {
305        assert!(!return_raw_logits);
306
307        let ModelInputs { prompts, params } = *inputs.downcast().expect("Downcast failed.");
308        let img = self.model.forward(prompts, params)?.to_dtype(DType::U8)?;
309        let (_b, c, h, w) = img.dims4()?;
310        let mut images = Vec::new();
311        for b_img in img.chunk(img.dim(0)?, 0)? {
312            let flattened = b_img.squeeze(0)?.permute((1, 2, 0))?.flatten_all()?;
313            if c != 3 {
314                candle_core::bail!("Expected 3 channels in image output");
315            }
316            #[allow(clippy::cast_possible_truncation)]
317            images.push(DynamicImage::ImageRgb8(
318                RgbImage::from_raw(w as u32, h as u32, flattened.to_vec1::<u8>()?).ok_or(
319                    candle_core::Error::Msg("RgbImage has invalid capacity.".to_string()),
320                )?,
321            ));
322        }
323        Ok(ForwardInputsResult::Image { images })
324    }
325    async fn sample_causal_gen(
326        &self,
327        _seqs: &mut [&mut Sequence],
328        _logits: Vec<Tensor>,
329        _prefix_cacher: &mut PrefixCacheManagerV2,
330        _disable_eos_stop: bool,
331        _srng: Arc<std::sync::Mutex<Isaac64Rng>>,
332    ) -> Result<(), candle_core::Error> {
333        candle_core::bail!("`sample_causal_gen` is incompatible with `DiffusionPipeline`");
334    }
335    fn category(&self) -> ModelCategory {
336        ModelCategory::Diffusion
337    }
338}
339
// AnyMoE does not apply to diffusion models; rely on the mixin's default (unsupported) implementations.
impl AnyMoePipelineMixin for DiffusionPipeline {}