// mistralrs/diffusion_model.rs

1use mistralrs_core::*;
2
3use crate::{best_device, Model};
4
5/// Configure a text model with the various parameters for loading, running, and other inference behaviors.
6pub struct DiffusionModelBuilder {
7    // Loading model
8    pub(crate) model_id: String,
9    pub(crate) token_source: TokenSource,
10    pub(crate) hf_revision: Option<String>,
11
12    // Model running
13    pub(crate) loader_type: DiffusionLoaderType,
14    pub(crate) dtype: ModelDType,
15    pub(crate) force_cpu: bool,
16
17    // Other things
18    pub(crate) max_num_seqs: usize,
19    pub(crate) with_logging: bool,
20}
21
22impl DiffusionModelBuilder {
23    /// A few defaults are applied here:
24    /// - Token source is from the cache (.cache/huggingface/token)
25    /// - Maximum number of sequences running is 32
26    pub fn new(model_id: impl ToString, loader_type: DiffusionLoaderType) -> Self {
27        Self {
28            model_id: model_id.to_string(),
29            loader_type,
30            dtype: ModelDType::Auto,
31            force_cpu: false,
32            token_source: TokenSource::CacheToken,
33            hf_revision: None,
34            max_num_seqs: 32,
35            with_logging: false,
36        }
37    }
38
39    /// Load the model in a certain dtype.
40    pub fn with_dtype(mut self, dtype: ModelDType) -> Self {
41        self.dtype = dtype;
42        self
43    }
44
45    /// Force usage of the CPU device. Do not use PagedAttention with this.
46    pub fn with_force_cpu(mut self) -> Self {
47        self.force_cpu = true;
48        self
49    }
50
51    /// Source of the Hugging Face token.
52    pub fn with_token_source(mut self, token_source: TokenSource) -> Self {
53        self.token_source = token_source;
54        self
55    }
56
57    /// Set the revision to use for a Hugging Face remote model.
58    pub fn with_hf_revision(mut self, revision: impl ToString) -> Self {
59        self.hf_revision = Some(revision.to_string());
60        self
61    }
62
63    /// Set the maximum number of sequences which can be run at once.
64    pub fn with_max_num_seqs(mut self, max_num_seqs: usize) -> Self {
65        self.max_num_seqs = max_num_seqs;
66        self
67    }
68
69    /// Enable logging.
70    pub fn with_logging(mut self) -> Self {
71        self.with_logging = true;
72        self
73    }
74
75    pub async fn build(self) -> anyhow::Result<Model> {
76        if self.with_logging {
77            initialize_logging();
78        }
79
80        let loader = DiffusionLoaderBuilder::new(Some(self.model_id)).build(self.loader_type);
81
82        // Load, into a Pipeline
83        let pipeline = loader.load_model_from_hf(
84            self.hf_revision,
85            self.token_source,
86            &self.dtype,
87            &best_device(self.force_cpu)?,
88            !self.with_logging,
89            DeviceMapSetting::Auto(AutoDeviceMapParams::default_text()),
90            None,
91            None,
92        )?;
93
94        let scheduler_method = SchedulerConfig::DefaultScheduler {
95            method: DefaultSchedulerMethod::Fixed(self.max_num_seqs.try_into()?),
96        };
97
98        let runner = MistralRsBuilder::new(pipeline, scheduler_method, false, None);
99
100        Ok(Model::new(runner.build().await))
101    }
102}