mistralrs/
speech_model.rs1use mistralrs_core::*;
2
3use crate::{best_device, Model};
4
5pub struct SpeechModelBuilder {
7 pub(crate) model_id: String,
9 pub(crate) dac_model_id: Option<String>,
10 pub(crate) token_source: TokenSource,
11 pub(crate) hf_revision: Option<String>,
12 pub(crate) cfg: Option<SpeechGenerationConfig>,
13
14 pub(crate) loader_type: SpeechLoaderType,
16 pub(crate) dtype: ModelDType,
17 pub(crate) force_cpu: bool,
18
19 pub(crate) max_num_seqs: usize,
21 pub(crate) with_logging: bool,
22}
23
24impl SpeechModelBuilder {
25 pub fn new(model_id: impl ToString, loader_type: SpeechLoaderType) -> Self {
29 Self {
30 model_id: model_id.to_string(),
31 loader_type,
32 dtype: ModelDType::Auto,
33 force_cpu: false,
34 token_source: TokenSource::CacheToken,
35 hf_revision: None,
36 max_num_seqs: 32,
37 with_logging: false,
38 cfg: None,
39 dac_model_id: None,
40 }
41 }
42
43 pub fn with_dac_model_id(mut self, dac_model_id: String) -> Self {
46 self.dac_model_id = Some(dac_model_id);
47 self
48 }
49
50 pub fn with_dtype(mut self, dtype: ModelDType) -> Self {
52 self.dtype = dtype;
53 self
54 }
55
56 pub fn with_force_cpu(mut self) -> Self {
58 self.force_cpu = true;
59 self
60 }
61
62 pub fn with_token_source(mut self, token_source: TokenSource) -> Self {
64 self.token_source = token_source;
65 self
66 }
67
68 pub fn with_hf_revision(mut self, revision: impl ToString) -> Self {
70 self.hf_revision = Some(revision.to_string());
71 self
72 }
73
74 pub fn with_max_num_seqs(mut self, max_num_seqs: usize) -> Self {
76 self.max_num_seqs = max_num_seqs;
77 self
78 }
79
80 pub fn with_logging(mut self) -> Self {
82 self.with_logging = true;
83 self
84 }
85
86 pub async fn build(self) -> anyhow::Result<Model> {
87 if self.with_logging {
88 initialize_logging();
89 }
90
91 let loader = SpeechLoader {
92 model_id: self.model_id,
93 dac_model_id: self.dac_model_id,
94 arch: self.loader_type,
95 cfg: self.cfg,
96 };
97
98 let pipeline = loader.load_model_from_hf(
100 self.hf_revision,
101 self.token_source,
102 &self.dtype,
103 &best_device(self.force_cpu)?,
104 !self.with_logging,
105 DeviceMapSetting::Auto(AutoDeviceMapParams::default_text()),
106 None,
107 None,
108 )?;
109
110 let scheduler_method = SchedulerConfig::DefaultScheduler {
111 method: DefaultSchedulerMethod::Fixed(self.max_num_seqs.try_into()?),
112 };
113
114 let runner = MistralRsBuilder::new(pipeline, scheduler_method, false, None);
115
116 Ok(Model::new(runner.build().await))
117 }
118}