mistralrs/
diffusion_model.rs1use mistralrs_core::*;
2
3use crate::model_builder_trait::{build_diffusion_pipeline, build_model_from_pipeline};
4use crate::Model;
5
6pub struct DiffusionModelBuilder {
8 pub(crate) model_id: String,
10 pub(crate) token_source: TokenSource,
11 pub(crate) hf_revision: Option<String>,
12
13 pub(crate) loader_type: DiffusionLoaderType,
15 pub(crate) dtype: ModelDType,
16 pub(crate) force_cpu: bool,
17
18 pub(crate) max_num_seqs: usize,
20 pub(crate) with_logging: bool,
21}
22
23impl DiffusionModelBuilder {
24 pub fn new(model_id: impl ToString, loader_type: DiffusionLoaderType) -> Self {
28 Self {
29 model_id: model_id.to_string(),
30 loader_type,
31 dtype: ModelDType::Auto,
32 force_cpu: false,
33 token_source: TokenSource::CacheToken,
34 hf_revision: None,
35 max_num_seqs: 32,
36 with_logging: false,
37 }
38 }
39
40 pub fn with_dtype(mut self, dtype: ModelDType) -> Self {
42 self.dtype = dtype;
43 self
44 }
45
46 pub fn with_force_cpu(mut self) -> Self {
48 self.force_cpu = true;
49 self
50 }
51
52 pub fn with_token_source(mut self, token_source: TokenSource) -> Self {
54 self.token_source = token_source;
55 self
56 }
57
58 pub fn with_hf_revision(mut self, revision: impl ToString) -> Self {
60 self.hf_revision = Some(revision.to_string());
61 self
62 }
63
64 pub fn with_max_num_seqs(mut self, max_num_seqs: usize) -> Self {
66 self.max_num_seqs = max_num_seqs;
67 self
68 }
69
70 pub fn with_logging(mut self) -> Self {
72 self.with_logging = true;
73 self
74 }
75
76 pub async fn build(self) -> anyhow::Result<Model> {
77 let (pipeline, scheduler_config, add_model_config) = build_diffusion_pipeline(self).await?;
78 Ok(build_model_from_pipeline(pipeline, scheduler_config, add_model_config).await)
79 }
80}