mistralrs/
diffusion_model.rs1use mistralrs_core::*;
2
3use crate::{best_device, Model};
4
5pub struct DiffusionModelBuilder {
7 pub(crate) model_id: String,
9 pub(crate) token_source: TokenSource,
10 pub(crate) hf_revision: Option<String>,
11
12 pub(crate) loader_type: DiffusionLoaderType,
14 pub(crate) dtype: ModelDType,
15 pub(crate) force_cpu: bool,
16 pub(crate) use_flash_attn: bool,
17
18 pub(crate) max_num_seqs: usize,
20 pub(crate) with_logging: bool,
21}
22
23impl DiffusionModelBuilder {
24 pub fn new(model_id: impl ToString, loader_type: DiffusionLoaderType) -> Self {
28 Self {
29 model_id: model_id.to_string(),
30 use_flash_attn: cfg!(feature = "flash-attn"),
31 loader_type,
32 dtype: ModelDType::Auto,
33 force_cpu: false,
34 token_source: TokenSource::CacheToken,
35 hf_revision: None,
36 max_num_seqs: 32,
37 with_logging: false,
38 }
39 }
40
41 pub fn with_dtype(mut self, dtype: ModelDType) -> Self {
43 self.dtype = dtype;
44 self
45 }
46
47 pub fn with_force_cpu(mut self) -> Self {
49 self.force_cpu = true;
50 self
51 }
52
53 pub fn with_token_source(mut self, token_source: TokenSource) -> Self {
55 self.token_source = token_source;
56 self
57 }
58
59 pub fn with_hf_revision(mut self, revision: impl ToString) -> Self {
61 self.hf_revision = Some(revision.to_string());
62 self
63 }
64
65 pub fn with_max_num_seqs(mut self, max_num_seqs: usize) -> Self {
67 self.max_num_seqs = max_num_seqs;
68 self
69 }
70
71 pub fn with_logging(mut self) -> Self {
73 self.with_logging = true;
74 self
75 }
76
77 pub async fn build(self) -> anyhow::Result<Model> {
78 let config = DiffusionSpecificConfig {
79 use_flash_attn: self.use_flash_attn,
80 };
81
82 if self.with_logging {
83 initialize_logging();
84 }
85
86 let loader =
87 DiffusionLoaderBuilder::new(config, Some(self.model_id)).build(self.loader_type);
88
89 let pipeline = loader.load_model_from_hf(
91 self.hf_revision,
92 self.token_source,
93 &self.dtype,
94 &best_device(self.force_cpu)?,
95 !self.with_logging,
96 DeviceMapSetting::Auto(AutoDeviceMapParams::default_text()),
97 None,
98 None,
99 )?;
100
101 let scheduler_method = SchedulerConfig::DefaultScheduler {
102 method: DefaultSchedulerMethod::Fixed(self.max_num_seqs.try_into()?),
103 };
104
105 let runner = MistralRsBuilder::new(pipeline, scheduler_method, false, None);
106
107 Ok(Model::new(runner.build()))
108 }
109}