mistralrs/
diffusion_model.rs

1use mistralrs_core::*;
2
3use crate::model_builder_trait::{build_diffusion_pipeline, build_model_from_pipeline};
4use crate::Model;
5
/// Configure a diffusion model with the various parameters for loading, running, and other inference behaviors.
pub struct DiffusionModelBuilder {
    // Loading model
    /// Hugging Face model ID, or a local path to the model.
    pub(crate) model_id: String,
    /// Where to obtain the Hugging Face token from (defaults to the on-disk cache).
    pub(crate) token_source: TokenSource,
    /// Optional Hugging Face repo revision; `None` uses the repo default.
    pub(crate) hf_revision: Option<String>,

    // Model running
    /// Which diffusion architecture/loader to use for this model.
    pub(crate) loader_type: DiffusionLoaderType,
    /// Dtype to load and run the model in (`Auto` lets the loader decide).
    pub(crate) dtype: ModelDType,
    /// Force inference on the CPU device instead of an accelerator.
    pub(crate) force_cpu: bool,

    // Other things
    /// Maximum number of sequences that may run concurrently.
    pub(crate) max_num_seqs: usize,
    /// Whether to enable logging during loading and inference.
    pub(crate) with_logging: bool,
}
22
23impl DiffusionModelBuilder {
24    /// A few defaults are applied here:
25    /// - Token source is from the cache (.cache/huggingface/token)
26    /// - Maximum number of sequences running is 32
27    pub fn new(model_id: impl ToString, loader_type: DiffusionLoaderType) -> Self {
28        Self {
29            model_id: model_id.to_string(),
30            loader_type,
31            dtype: ModelDType::Auto,
32            force_cpu: false,
33            token_source: TokenSource::CacheToken,
34            hf_revision: None,
35            max_num_seqs: 32,
36            with_logging: false,
37        }
38    }
39
40    /// Load the model in a certain dtype.
41    pub fn with_dtype(mut self, dtype: ModelDType) -> Self {
42        self.dtype = dtype;
43        self
44    }
45
46    /// Force usage of the CPU device. Do not use PagedAttention with this.
47    pub fn with_force_cpu(mut self) -> Self {
48        self.force_cpu = true;
49        self
50    }
51
52    /// Source of the Hugging Face token.
53    pub fn with_token_source(mut self, token_source: TokenSource) -> Self {
54        self.token_source = token_source;
55        self
56    }
57
58    /// Set the revision to use for a Hugging Face remote model.
59    pub fn with_hf_revision(mut self, revision: impl ToString) -> Self {
60        self.hf_revision = Some(revision.to_string());
61        self
62    }
63
64    /// Set the maximum number of sequences which can be run at once.
65    pub fn with_max_num_seqs(mut self, max_num_seqs: usize) -> Self {
66        self.max_num_seqs = max_num_seqs;
67        self
68    }
69
70    /// Enable logging.
71    pub fn with_logging(mut self) -> Self {
72        self.with_logging = true;
73        self
74    }
75
76    pub async fn build(self) -> anyhow::Result<Model> {
77        let (pipeline, scheduler_config, add_model_config) = build_diffusion_pipeline(self).await?;
78        Ok(build_model_from_pipeline(pipeline, scheduler_config, add_model_config).await)
79    }
80}