mistralrs/
diffusion_model.rs1use mistralrs_core::*;
2
3use crate::{best_device, Model};
4
5pub struct DiffusionModelBuilder {
7 pub(crate) model_id: String,
9 pub(crate) token_source: TokenSource,
10 pub(crate) hf_revision: Option<String>,
11
12 pub(crate) loader_type: DiffusionLoaderType,
14 pub(crate) dtype: ModelDType,
15 pub(crate) force_cpu: bool,
16
17 pub(crate) max_num_seqs: usize,
19 pub(crate) with_logging: bool,
20}
21
22impl DiffusionModelBuilder {
23 pub fn new(model_id: impl ToString, loader_type: DiffusionLoaderType) -> Self {
27 Self {
28 model_id: model_id.to_string(),
29 loader_type,
30 dtype: ModelDType::Auto,
31 force_cpu: false,
32 token_source: TokenSource::CacheToken,
33 hf_revision: None,
34 max_num_seqs: 32,
35 with_logging: false,
36 }
37 }
38
39 pub fn with_dtype(mut self, dtype: ModelDType) -> Self {
41 self.dtype = dtype;
42 self
43 }
44
45 pub fn with_force_cpu(mut self) -> Self {
47 self.force_cpu = true;
48 self
49 }
50
51 pub fn with_token_source(mut self, token_source: TokenSource) -> Self {
53 self.token_source = token_source;
54 self
55 }
56
57 pub fn with_hf_revision(mut self, revision: impl ToString) -> Self {
59 self.hf_revision = Some(revision.to_string());
60 self
61 }
62
63 pub fn with_max_num_seqs(mut self, max_num_seqs: usize) -> Self {
65 self.max_num_seqs = max_num_seqs;
66 self
67 }
68
69 pub fn with_logging(mut self) -> Self {
71 self.with_logging = true;
72 self
73 }
74
75 pub async fn build(self) -> anyhow::Result<Model> {
76 if self.with_logging {
77 initialize_logging();
78 }
79
80 let loader = DiffusionLoaderBuilder::new(Some(self.model_id)).build(self.loader_type);
81
82 let pipeline = loader.load_model_from_hf(
84 self.hf_revision,
85 self.token_source,
86 &self.dtype,
87 &best_device(self.force_cpu)?,
88 !self.with_logging,
89 DeviceMapSetting::Auto(AutoDeviceMapParams::default_text()),
90 None,
91 None,
92 )?;
93
94 let scheduler_method = SchedulerConfig::DefaultScheduler {
95 method: DefaultSchedulerMethod::Fixed(self.max_num_seqs.try_into()?),
96 };
97
98 let runner = MistralRsBuilder::new(pipeline, scheduler_method, false, None);
99
100 Ok(Model::new(runner.build().await))
101 }
102}