mistralrs_core/pipeline/
diffusion.rs

1use super::loaders::{DiffusionModelPaths, DiffusionModelPathsInner};
2use super::{
3    AnyMoePipelineMixin, Cache, CacheManagerMixin, DiffusionLoaderType, DiffusionModel,
4    DiffusionModelLoader, EitherCache, FluxLoader, ForwardInputsResult, GeneralMetadata,
5    IsqPipelineMixin, Loader, MetadataMixin, ModelCategory, ModelKind, ModelPaths,
6    PreProcessingMixin, Processor, TokenSource,
7};
8use crate::device_map::DeviceMapper;
9use crate::diffusion_models::processor::{DiffusionProcessor, ModelInputs};
10use crate::paged_attention::AttentionImplementation;
11use crate::pipeline::{ChatTemplate, Modalities, SupportedModality};
12use crate::prefix_cacher::PrefixCacheManagerV2;
13use crate::sequence::Sequence;
14use crate::utils::varbuilder_utils::DeviceForLoadTensor;
15use crate::utils::{tokens::get_token, varbuilder_utils::from_mmaped_safetensors};
16use crate::{DeviceMapSetting, PagedAttentionConfig, Pipeline, TryIntoDType};
17use anyhow::Result;
18use candle_core::{DType, Device, Tensor};
19use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
20use image::{DynamicImage, RgbImage};
21use indicatif::MultiProgress;
22use mistralrs_quant::log::once_log_info;
23use mistralrs_quant::IsqType;
24use rand_isaac::Isaac64Rng;
25use std::any::Any;
26use std::io;
27use std::sync::Arc;
28use tokenizers::Tokenizer;
29use tokio::sync::Mutex;
30use tracing::warn;
31
/// Pipeline wrapping a loaded diffusion model for image generation.
pub struct DiffusionPipeline {
    // The underlying diffusion model; boxed trait object so any diffusion backend fits.
    model: Box<dyn DiffusionModel + Send + Sync>,
    // Model ID (or local path) this pipeline was loaded from; reported by `name()`.
    model_id: String,
    // Shared metadata consumed by the engine; built once at load time.
    metadata: Arc<GeneralMetadata>,
    // Placeholder cache: diffusion models keep no KV cache, but `CacheManagerMixin::cache`
    // must still return a reference to one.
    dummy_cache: EitherCache,
}
38
/// A loader for a diffusion (non-quantized) model.
pub struct DiffusionLoader {
    // Architecture-specific loader (e.g. [`FluxLoader`]) that knows which files to
    // fetch and how to construct the model.
    inner: Box<dyn DiffusionModelLoader>,
    // Model ID (or local path) to load from.
    model_id: String,
    // Model kind; only `ModelKind::Normal` is handled in `load_model_from_path`.
    kind: ModelKind,
}
45
46#[derive(Default)]
47/// A builder for a loader for a vision (non-quantized) model.
48pub struct DiffusionLoaderBuilder {
49    model_id: Option<String>,
50    kind: ModelKind,
51}
52
53impl DiffusionLoaderBuilder {
54    pub fn new(model_id: Option<String>) -> Self {
55        Self {
56            model_id,
57            kind: ModelKind::Normal,
58        }
59    }
60
61    pub fn build(self, loader: DiffusionLoaderType) -> Box<dyn Loader> {
62        let loader: Box<dyn DiffusionModelLoader> = match loader {
63            DiffusionLoaderType::Flux => Box::new(FluxLoader { offload: false }),
64            DiffusionLoaderType::FluxOffloaded => Box::new(FluxLoader { offload: true }),
65        };
66        Box::new(DiffusionLoader {
67            inner: loader,
68            model_id: self.model_id.unwrap(),
69            kind: self.kind,
70        })
71    }
72}
73
impl Loader for DiffusionLoader {
    /// Resolve the model's weight and config files from the Hugging Face Hub
    /// (or the local cache), then defer to [`Self::load_model_from_path`].
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let paths: anyhow::Result<Box<dyn ModelPaths>> = {
            let api = ApiBuilder::new()
                .with_progress(!silent)
                .with_token(get_token(&token_source)?)
                .build()?;
            // Default to the `main` branch when no revision is requested.
            let revision = revision.unwrap_or("main".to_string());
            let api = api.repo(Repo::with_revision(
                self.model_id.clone(),
                RepoType::Model,
                revision.clone(),
            ));
            let model_id = std::path::Path::new(&self.model_id);
            // The architecture-specific inner loader decides which weight and
            // config files this model needs.
            let filenames = self.inner.get_model_paths(&api, model_id)?;
            let config_filenames = self.inner.get_config_filenames(&api, model_id)?;
            Ok(Box::new(DiffusionModelPaths(DiffusionModelPathsInner {
                config_filenames,
                filenames,
            })))
        };
        self.load_model_from_path(
            &paths?,
            dtype,
            device,
            silent,
            mapper,
            in_situ_quant,
            paged_attn_config,
        )
    }

    /// Build the pipeline from already-resolved file paths.
    ///
    /// Diffusion models reject device mapping and ISQ (both bail), and
    /// PagedAttention is silently disabled with a warning.
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        mut paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        // These paths must have been produced for a diffusion model
        // (see `load_model_from_hf`); anything else is a caller bug.
        let paths = &paths
            .as_ref()
            .as_any()
            .downcast_ref::<DiffusionModelPaths>()
            .expect("Path downcast failed.")
            .0;

        if matches!(mapper, DeviceMapSetting::Map(_)) {
            anyhow::bail!("Device mapping is not supported for diffusion models.")
        }

        if in_situ_quant.is_some() {
            anyhow::bail!("ISQ is not supported for Diffusion models.");
        }

        if paged_attn_config.is_some() {
            warn!("PagedAttention is not supported for Diffusion models, disabling it.");

            paged_attn_config = None;
        }

        if crate::using_flash_attn() {
            once_log_info("FlashAttention is enabled.");
        }

        // Read every config file to a string; the inner loader parses them.
        let configs = paths
            .config_filenames
            .iter()
            .map(std::fs::read_to_string)
            .collect::<io::Result<Vec<_>>>()?;

        // Dummy mapper: everything lives on `device` (real mapping was rejected above).
        let mapper = DeviceMapSetting::dummy().into_mapper(usize::MAX, device, None)?;
        let dtype = mapper.get_min_dtype(dtype)?;

        // `paged_attn_config` was forced to `None` above, so this resolves to `Eager`.
        let attention_mechanism = if paged_attn_config.is_some() {
            AttentionImplementation::PagedAttention
        } else {
            AttentionImplementation::Eager
        };

        let model = match self.kind {
            ModelKind::Normal => {
                // One VarBuilder per weight file; the inner loader may pin some
                // components to the CPU (e.g. for the offloaded FLUX variant).
                let vbs = paths
                    .filenames
                    .iter()
                    .zip(self.inner.force_cpu_vb())
                    .map(|(path, force_cpu)| {
                        let dev = if force_cpu { &Device::Cpu } else { device };
                        from_mmaped_safetensors(
                            vec![path.clone()],
                            Vec::new(),
                            Some(dtype),
                            dev,
                            vec![None],
                            silent,
                            None,
                            |_| true,
                            Arc::new(|_| DeviceForLoadTensor::Base),
                        )
                    })
                    .collect::<candle_core::Result<Vec<_>>>()?;

                self.inner.load(
                    configs,
                    vbs,
                    crate::pipeline::NormalLoadingMetadata {
                        mapper,
                        loading_isq: false,
                        real_device: device.clone(),
                        multi_progress: Arc::new(MultiProgress::new()),
                    },
                    attention_mechanism,
                    silent,
                )?
            }
            // Only `ModelKind::Normal` is produced by `DiffusionLoaderBuilder`.
            _ => unreachable!(),
        };

        let max_seq_len = model.max_seq_len();
        Ok(Arc::new(Mutex::new(DiffusionPipeline {
            model,
            model_id: self.model_id.clone(),
            metadata: Arc::new(GeneralMetadata {
                max_seq_len,
                llg_factory: None,
                is_xlora: false,
                no_prefix_cache: false,
                num_hidden_layers: 1, // FIXME(EricLBuehler): we know this is only for caching, so its OK.
                eos_tok: vec![],
                kind: self.kind.clone(),
                no_kv_cache: true, // NOTE(EricLBuehler): no cache for these.
                activation_dtype: dtype,
                sliding_window: None,
                cache_config: None,
                cache_engine: None,
                prompt_chunksize: None,
                model_metadata: None,
                // Text prompts in, images out.
                modalities: Modalities {
                    input: vec![SupportedModality::Text],
                    output: vec![SupportedModality::Vision],
                },
            }),
            dummy_cache: EitherCache::Full(Cache::new(0, false)),
        })))
    }

    fn get_id(&self) -> String {
        self.model_id.to_string()
    }

    fn get_kind(&self) -> ModelKind {
        self.kind.clone()
    }
}
242
impl PreProcessingMixin for DiffusionPipeline {
    /// Diffusion prompts are converted into `ModelInputs` by [`DiffusionProcessor`].
    fn get_processor(&self) -> Arc<dyn Processor> {
        Arc::new(DiffusionProcessor)
    }
    // No chat template: these models do not do conversational text generation.
    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
        None
    }
    // No extra processor configuration is needed.
    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
        None
    }
}
254
impl IsqPipelineMixin for DiffusionPipeline {
    /// In-situ re-quantization is unsupported for diffusion models; always errors.
    fn re_isq_model(&mut self, _dtype: IsqType) -> Result<()> {
        anyhow::bail!("Diffusion models do not support ISQ for now.")
    }
}
260
// Diffusion models keep no KV cache, so every cache operation is a no-op and
// `cache` hands back the placeholder created at load time.
impl CacheManagerMixin for DiffusionPipeline {
    fn clone_in_cache(&self, _seqs: &mut [&mut Sequence]) {}
    fn clone_out_cache(&self, _seqs: &mut [&mut Sequence]) {}
    fn set_none_cache(
        &self,
        _seqs: &mut [&mut Sequence],
        _reset_non_granular: bool,
        _modify_draft_cache: bool,
        _load_preallocated_cache: bool,
    ) {
    }
    fn cache(&self) -> &EitherCache {
        &self.dummy_cache
    }
}
276
impl MetadataMixin for DiffusionPipeline {
    fn device(&self) -> Device {
        self.model.device().clone()
    }
    fn get_metadata(&self) -> Arc<GeneralMetadata> {
        self.metadata.clone()
    }
    fn name(&self) -> String {
        self.model_id.clone()
    }
    // Nothing to reset: this pipeline keeps no non-granular state.
    fn reset_non_granular_state(&self) {}
    // No tokenizer: prompts are handled by the diffusion processor instead.
    fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
        None
    }
    // No device mapper: device mapping is rejected at load time.
    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
        None
    }
}
295
296#[async_trait::async_trait]
297impl Pipeline for DiffusionPipeline {
298    fn forward_inputs(
299        &mut self,
300        inputs: Box<dyn Any>,
301        return_raw_logits: bool,
302    ) -> candle_core::Result<ForwardInputsResult> {
303        assert!(!return_raw_logits);
304
305        let ModelInputs { prompts, params } = *inputs.downcast().expect("Downcast failed.");
306        let img = self.model.forward(prompts, params)?.to_dtype(DType::U8)?;
307        let (_b, c, h, w) = img.dims4()?;
308        let mut images = Vec::new();
309        for b_img in img.chunk(img.dim(0)?, 0)? {
310            let flattened = b_img.squeeze(0)?.permute((1, 2, 0))?.flatten_all()?;
311            if c != 3 {
312                candle_core::bail!("Expected 3 channels in image output");
313            }
314            #[allow(clippy::cast_possible_truncation)]
315            images.push(DynamicImage::ImageRgb8(
316                RgbImage::from_raw(w as u32, h as u32, flattened.to_vec1::<u8>()?).ok_or(
317                    candle_core::Error::Msg("RgbImage has invalid capacity.".to_string()),
318                )?,
319            ));
320        }
321        Ok(ForwardInputsResult::Image { images })
322    }
323    async fn sample_causal_gen(
324        &self,
325        _seqs: &mut [&mut Sequence],
326        _logits: Vec<Tensor>,
327        _prefix_cacher: &mut PrefixCacheManagerV2,
328        _disable_eos_stop: bool,
329        _srng: Arc<std::sync::Mutex<Isaac64Rng>>,
330    ) -> Result<(), candle_core::Error> {
331        candle_core::bail!("`sample_causal_gen` is incompatible with `DiffusionPipeline`");
332    }
333    fn category(&self) -> ModelCategory {
334        ModelCategory::Diffusion
335    }
336}
337
338impl AnyMoePipelineMixin for DiffusionPipeline {}