//! mistralrs_core/pipeline/diffusion.rs — diffusion (image-generation) pipeline and its loader.

1use super::loaders::{DiffusionModelPaths, DiffusionModelPathsInner};
2use super::{
3    AnyMoePipelineMixin, Cache, CacheManagerMixin, DiffusionLoaderType, DiffusionModel,
4    DiffusionModelLoader, EitherCache, FluxLoader, ForwardInputsResult, GeneralMetadata,
5    IsqPipelineMixin, Loader, MetadataMixin, ModelCategory, ModelKind, ModelPaths,
6    PreProcessingMixin, Processor, TokenSource,
7};
8use crate::device_map::DeviceMapper;
9use crate::diffusion_models::processor::{DiffusionProcessor, ModelInputs};
10use crate::paged_attention::AttentionImplementation;
11use crate::pipeline::ChatTemplate;
12use crate::prefix_cacher::PrefixCacheManagerV2;
13use crate::sequence::Sequence;
14use crate::utils::varbuilder_utils::DeviceForLoadTensor;
15use crate::utils::{tokens::get_token, varbuilder_utils::from_mmaped_safetensors};
16use crate::{DeviceMapSetting, PagedAttentionConfig, Pipeline, TryIntoDType};
17use anyhow::Result;
18use candle_core::{DType, Device, Tensor};
19use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
20use image::{DynamicImage, RgbImage};
21use indicatif::MultiProgress;
22use mistralrs_quant::log::once_log_info;
23use mistralrs_quant::IsqType;
24use rand_isaac::Isaac64Rng;
25use std::any::Any;
26use std::io;
27use std::sync::Arc;
28use tokenizers::Tokenizer;
29use tokio::sync::Mutex;
30use tracing::warn;
31
/// A pipeline wrapping a loaded diffusion (image-generation) model.
///
/// Diffusion pipelines have no tokenizer, chat template, or KV cache;
/// `dummy_cache` exists only to satisfy the `CacheManagerMixin` interface.
pub struct DiffusionPipeline {
    // The underlying image-generation model, as a boxed trait object.
    model: Box<dyn DiffusionModel + Send + Sync>,
    // Model ID (HF repo id or local path) this pipeline was loaded from.
    model_id: String,
    // Shared metadata; built with `no_kv_cache: true` for diffusion models.
    metadata: Arc<GeneralMetadata>,
    // Placeholder returned by `CacheManagerMixin::cache`; never populated.
    dummy_cache: EitherCache,
}
38
/// A loader for a diffusion (non-quantized) model.
pub struct DiffusionLoader {
    // Model-family-specific loader (currently the FLUX variants).
    inner: Box<dyn DiffusionModelLoader>,
    // HF repo id or local path of the model to load.
    model_id: String,
    // Kind of model being loaded (always `Normal` via the builder).
    kind: ModelKind,
}
45
#[derive(Default)]
/// A builder for a loader for a diffusion (non-quantized) model.
pub struct DiffusionLoaderBuilder {
    // Required at `build` time; `None` here will panic when building.
    model_id: Option<String>,
    kind: ModelKind,
}
52
53impl DiffusionLoaderBuilder {
54    pub fn new(model_id: Option<String>) -> Self {
55        Self {
56            model_id,
57            kind: ModelKind::Normal,
58        }
59    }
60
61    pub fn build(self, loader: DiffusionLoaderType) -> Box<dyn Loader> {
62        let loader: Box<dyn DiffusionModelLoader> = match loader {
63            DiffusionLoaderType::Flux => Box::new(FluxLoader { offload: false }),
64            DiffusionLoaderType::FluxOffloaded => Box::new(FluxLoader { offload: true }),
65        };
66        Box::new(DiffusionLoader {
67            inner: loader,
68            model_id: self.model_id.unwrap(),
69            kind: self.kind,
70        })
71    }
72}
73
impl Loader for DiffusionLoader {
    /// Resolve weight and config files from the Hugging Face Hub (or a local
    /// path), then delegate to [`Self::load_model_from_path`].
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let paths: anyhow::Result<Box<dyn ModelPaths>> = {
            let api = ApiBuilder::new()
                .with_progress(!silent)
                .with_token(get_token(&token_source)?)
                .build()?;
            // Default to the `main` revision when none was requested.
            let revision = revision.unwrap_or("main".to_string());
            let api = api.repo(Repo::with_revision(
                self.model_id.clone(),
                RepoType::Model,
                revision.clone(),
            ));
            let model_id = std::path::Path::new(&self.model_id);
            // The inner (model-family) loader decides exactly which weight and
            // config files are required.
            let filenames = self.inner.get_model_paths(&api, model_id)?;
            let config_filenames = self.inner.get_config_filenames(&api, model_id)?;
            Ok(Box::new(DiffusionModelPaths(DiffusionModelPathsInner {
                config_filenames,
                filenames,
            })))
        };
        self.load_model_from_path(
            &paths?,
            dtype,
            device,
            silent,
            mapper,
            in_situ_quant,
            paged_attn_config,
        )
    }

    /// Load the diffusion model from already-resolved file paths and build a
    /// ready-to-use [`DiffusionPipeline`].
    ///
    /// Device mapping and ISQ are rejected with errors; PagedAttention is
    /// disabled with a warning (diffusion models keep no KV cache).
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        mut paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        // `paths` must be the diffusion-specific variant produced by
        // `load_model_from_hf`; any other `ModelPaths` impl is a bug.
        let paths = &paths
            .as_ref()
            .as_any()
            .downcast_ref::<DiffusionModelPaths>()
            .expect("Path downcast failed.")
            .0;

        if matches!(mapper, DeviceMapSetting::Map(_)) {
            anyhow::bail!("Device mapping is not supported for diffusion models.")
        }

        if in_situ_quant.is_some() {
            anyhow::bail!("ISQ is not supported for Diffusion models.");
        }

        if paged_attn_config.is_some() {
            warn!("PagedAttention is not supported for Diffusion models, disabling it.");

            paged_attn_config = None;
        }

        if crate::using_flash_attn() {
            once_log_info("FlashAttention is enabled.");
        }

        // Read every JSON config file into memory; the inner loader parses them.
        let configs = paths
            .config_filenames
            .iter()
            .map(std::fs::read_to_string)
            .collect::<io::Result<Vec<_>>>()?;

        // Device mapping was rejected above, so a dummy (single-device) mapper
        // is used; it also determines the minimum supported dtype.
        let mapper = DeviceMapSetting::dummy().into_mapper(usize::MAX, device, None)?;
        let dtype = mapper.get_min_dtype(dtype)?;

        // `paged_attn_config` was forced to `None` above, so this is always
        // `Eager` today; kept in this form for symmetry with other pipelines.
        let attention_mechanism = if paged_attn_config.is_some() {
            AttentionImplementation::PagedAttention
        } else {
            AttentionImplementation::Eager
        };

        let model = match self.kind {
            ModelKind::Normal => {
                // One VarBuilder per weight file; files the inner loader marks
                // `force_cpu_vb` are loaded on CPU instead of `device`.
                let vbs = paths
                    .filenames
                    .iter()
                    .zip(self.inner.force_cpu_vb())
                    .map(|(path, force_cpu)| {
                        let dev = if force_cpu { &Device::Cpu } else { device };
                        from_mmaped_safetensors(
                            vec![path.clone()],
                            Vec::new(),
                            Some(dtype),
                            dev,
                            vec![None],
                            silent,
                            None,
                            |_| true,
                            Arc::new(|_| DeviceForLoadTensor::Base),
                        )
                    })
                    .collect::<candle_core::Result<Vec<_>>>()?;

                self.inner.load(
                    configs,
                    vbs,
                    crate::pipeline::NormalLoadingMetadata {
                        mapper,
                        loading_isq: false,
                        real_device: device.clone(),
                        multi_progress: Arc::new(MultiProgress::new()),
                    },
                    attention_mechanism,
                    silent,
                )?
            }
            // The builder only constructs `ModelKind::Normal` loaders.
            _ => unreachable!(),
        };

        let max_seq_len = model.max_seq_len();
        Ok(Arc::new(Mutex::new(DiffusionPipeline {
            model,
            model_id: self.model_id.clone(),
            metadata: Arc::new(GeneralMetadata {
                max_seq_len,
                llg_factory: None,
                is_xlora: false,
                no_prefix_cache: false,
                num_hidden_layers: 1, // FIXME(EricLBuehler): we know this is only for caching, so its OK.
                eos_tok: vec![],
                kind: self.kind.clone(),
                no_kv_cache: true, // NOTE(EricLBuehler): no cache for these.
                activation_dtype: dtype,
                sliding_window: None,
                cache_config: None,
                cache_engine: None,
                prompt_chunksize: None,
                model_metadata: None,
            }),
            dummy_cache: EitherCache::Full(Cache::new(0, false)),
        })))
    }

    /// The model ID this loader was created with.
    fn get_id(&self) -> String {
        self.model_id.to_string()
    }

    /// The kind of model this loader produces.
    fn get_kind(&self) -> ModelKind {
        self.kind.clone()
    }
}
238
239impl PreProcessingMixin for DiffusionPipeline {
240    fn get_processor(&self) -> Arc<dyn Processor> {
241        Arc::new(DiffusionProcessor)
242    }
243    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
244        None
245    }
246    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
247        None
248    }
249}
250
251impl IsqPipelineMixin for DiffusionPipeline {
252    fn re_isq_model(&mut self, _dtype: IsqType) -> Result<()> {
253        anyhow::bail!("Diffusion models do not support ISQ for now.")
254    }
255}
256
257impl CacheManagerMixin for DiffusionPipeline {
258    fn clone_in_cache(&self, _seqs: &mut [&mut Sequence]) {}
259    fn clone_out_cache(&self, _seqs: &mut [&mut Sequence]) {}
260    fn set_none_cache(
261        &self,
262        _seqs: &mut [&mut Sequence],
263        _reset_non_granular: bool,
264        _modify_draft_cache: bool,
265        _load_preallocated_cache: bool,
266    ) {
267    }
268    fn cache(&self) -> &EitherCache {
269        &self.dummy_cache
270    }
271}
272
273impl MetadataMixin for DiffusionPipeline {
274    fn device(&self) -> Device {
275        self.model.device().clone()
276    }
277    fn get_metadata(&self) -> Arc<GeneralMetadata> {
278        self.metadata.clone()
279    }
280    fn name(&self) -> String {
281        self.model_id.clone()
282    }
283    fn reset_non_granular_state(&self) {}
284    fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
285        None
286    }
287    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
288        None
289    }
290}
291
292#[async_trait::async_trait]
293impl Pipeline for DiffusionPipeline {
294    fn forward_inputs(
295        &mut self,
296        inputs: Box<dyn Any>,
297        return_raw_logits: bool,
298    ) -> candle_core::Result<ForwardInputsResult> {
299        assert!(!return_raw_logits);
300
301        let ModelInputs { prompts, params } = *inputs.downcast().expect("Downcast failed.");
302        let img = self.model.forward(prompts, params)?.to_dtype(DType::U8)?;
303        let (_b, c, h, w) = img.dims4()?;
304        let mut images = Vec::new();
305        for b_img in img.chunk(img.dim(0)?, 0)? {
306            let flattened = b_img.squeeze(0)?.permute((1, 2, 0))?.flatten_all()?;
307            if c != 3 {
308                candle_core::bail!("Expected 3 channels in image output");
309            }
310            #[allow(clippy::cast_possible_truncation)]
311            images.push(DynamicImage::ImageRgb8(
312                RgbImage::from_raw(w as u32, h as u32, flattened.to_vec1::<u8>()?).ok_or(
313                    candle_core::Error::Msg("RgbImage has invalid capacity.".to_string()),
314                )?,
315            ));
316        }
317        Ok(ForwardInputsResult::Image { images })
318    }
319    async fn sample_causal_gen(
320        &self,
321        _seqs: &mut [&mut Sequence],
322        _logits: Vec<Tensor>,
323        _prefix_cacher: &mut PrefixCacheManagerV2,
324        _disable_eos_stop: bool,
325        _srng: Arc<std::sync::Mutex<Isaac64Rng>>,
326    ) -> Result<(), candle_core::Error> {
327        candle_core::bail!("`sample_causal_gen` is incompatible with `DiffusionPipeline`");
328    }
329    fn category(&self) -> ModelCategory {
330        ModelCategory::Diffusion
331    }
332}
333
// Empty impl: diffusion pipelines rely entirely on the trait's default methods.
impl AnyMoePipelineMixin for DiffusionPipeline {}