// mistralrs/diffusion_model.rs

1use mistralrs_core::*;
2
3use crate::{best_device, Model};
4
/// Configure a diffusion model with the various parameters for loading, running, and other
/// inference behaviors.
pub struct DiffusionModelBuilder {
    // Loading model
    // Hugging Face model ID (or local path) handed to the loader.
    pub(crate) model_id: String,
    // Where the Hugging Face auth token is read from (defaults to the cached token).
    pub(crate) token_source: TokenSource,
    // Optional HF repo revision; `None` selects the loader's default branch.
    pub(crate) hf_revision: Option<String>,

    // Model running
    // Which diffusion architecture/loader to instantiate.
    pub(crate) loader_type: DiffusionLoaderType,
    // Target dtype for weights; `Auto` lets the loader pick.
    pub(crate) dtype: ModelDType,
    // Force CPU even when an accelerator is available.
    pub(crate) force_cpu: bool,
    // Enabled at build time when the `flash-attn` cargo feature is on.
    pub(crate) use_flash_attn: bool,

    // Other things
    // Upper bound on concurrently scheduled sequences (scheduler capacity).
    pub(crate) max_num_seqs: usize,
    // When set, `build()` initializes logging and disables the loader's silent mode.
    pub(crate) with_logging: bool,
}
22
23impl DiffusionModelBuilder {
24    /// A few defaults are applied here:
25    /// - Token source is from the cache (.cache/huggingface/token)
26    /// - Maximum number of sequences running is 32
27    pub fn new(model_id: impl ToString, loader_type: DiffusionLoaderType) -> Self {
28        Self {
29            model_id: model_id.to_string(),
30            use_flash_attn: cfg!(feature = "flash-attn"),
31            loader_type,
32            dtype: ModelDType::Auto,
33            force_cpu: false,
34            token_source: TokenSource::CacheToken,
35            hf_revision: None,
36            max_num_seqs: 32,
37            with_logging: false,
38        }
39    }
40
41    /// Load the model in a certain dtype.
42    pub fn with_dtype(mut self, dtype: ModelDType) -> Self {
43        self.dtype = dtype;
44        self
45    }
46
47    /// Force usage of the CPU device. Do not use PagedAttention with this.
48    pub fn with_force_cpu(mut self) -> Self {
49        self.force_cpu = true;
50        self
51    }
52
53    /// Source of the Hugging Face token.
54    pub fn with_token_source(mut self, token_source: TokenSource) -> Self {
55        self.token_source = token_source;
56        self
57    }
58
59    /// Set the revision to use for a Hugging Face remote model.
60    pub fn with_hf_revision(mut self, revision: impl ToString) -> Self {
61        self.hf_revision = Some(revision.to_string());
62        self
63    }
64
65    /// Set the maximum number of sequences which can be run at once.
66    pub fn with_max_num_seqs(mut self, max_num_seqs: usize) -> Self {
67        self.max_num_seqs = max_num_seqs;
68        self
69    }
70
71    /// Enable logging.
72    pub fn with_logging(mut self) -> Self {
73        self.with_logging = true;
74        self
75    }
76
77    pub async fn build(self) -> anyhow::Result<Model> {
78        let config = DiffusionSpecificConfig {
79            use_flash_attn: self.use_flash_attn,
80        };
81
82        if self.with_logging {
83            initialize_logging();
84        }
85
86        let loader =
87            DiffusionLoaderBuilder::new(config, Some(self.model_id)).build(self.loader_type);
88
89        // Load, into a Pipeline
90        let pipeline = loader.load_model_from_hf(
91            self.hf_revision,
92            self.token_source,
93            &self.dtype,
94            &best_device(self.force_cpu)?,
95            !self.with_logging,
96            DeviceMapSetting::Auto(AutoDeviceMapParams::default_text()),
97            None,
98            None,
99        )?;
100
101        let scheduler_method = SchedulerConfig::DefaultScheduler {
102            method: DefaultSchedulerMethod::Fixed(self.max_num_seqs.try_into()?),
103        };
104
105        let runner = MistralRsBuilder::new(pipeline, scheduler_method, false, None);
106
107        Ok(Model::new(runner.build()))
108    }
109}