// mistralrs/speech_model.rs

use mistralrs_core::*;

use crate::model_builder_trait::{build_model_from_pipeline, build_speech_pipeline};
use crate::Model;
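/// Builder for a speech generation model, configuring how the model is located, loaded,
/// and scheduled before it is turned into a [`Model`].
///
/// The example below is a minimal sketch: the model id and the `SpeechLoaderType::Dia`
/// variant are illustrative assumptions, not the only supported configuration.
///
/// ```no_run
/// use mistralrs::{SpeechLoaderType, SpeechModelBuilder};
///
/// #[tokio::main]
/// async fn main() -> anyhow::Result<()> {
///     let model = SpeechModelBuilder::new("nari-labs/Dia-1.6B", SpeechLoaderType::Dia)
///         .with_logging()
///         .build()
///         .await?;
///     // `model` is now ready to receive speech generation requests.
///     Ok(())
/// }
/// ```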
pub struct SpeechModelBuilder {
    pub(crate) model_id: String,
    pub(crate) dac_model_id: Option<String>,
    pub(crate) token_source: TokenSource,
    pub(crate) hf_revision: Option<String>,
    pub(crate) cfg: Option<SpeechGenerationConfig>,

    pub(crate) loader_type: SpeechLoaderType,
    pub(crate) dtype: ModelDType,
    pub(crate) force_cpu: bool,

    pub(crate) max_num_seqs: usize,
    pub(crate) with_logging: bool,
}

impl SpeechModelBuilder {
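    /// Create a builder for the given model id (typically a Hugging Face model id) and speech
    /// loader type. Defaults: `ModelDType::Auto`, `TokenSource::CacheToken`, no pinned revision,
    /// and up to 32 concurrent sequences.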
    pub fn new(model_id: impl ToString, loader_type: SpeechLoaderType) -> Self {
        Self {
            model_id: model_id.to_string(),
            loader_type,
            dtype: ModelDType::Auto,
            force_cpu: false,
            token_source: TokenSource::CacheToken,
            hf_revision: None,
            max_num_seqs: 32,
            with_logging: false,
            cfg: None,
            dac_model_id: None,
        }
    }

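    /// Set the model id of the DAC (audio codec) component loaded alongside the speech model.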
    pub fn with_dac_model_id(mut self, dac_model_id: String) -> Self {
        self.dac_model_id = Some(dac_model_id);
        self
    }

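    /// Set the data type used to load the model weights. Defaults to `ModelDType::Auto`.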
    pub fn with_dtype(mut self, dtype: ModelDType) -> Self {
        self.dtype = dtype;
        self
    }

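    /// Force the model to run on the CPU.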
    pub fn with_force_cpu(mut self) -> Self {
        self.force_cpu = true;
        self
    }

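    /// Set where the Hugging Face token is read from. Defaults to `TokenSource::CacheToken`.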
    pub fn with_token_source(mut self, token_source: TokenSource) -> Self {
        self.token_source = token_source;
        self
    }

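    /// Set the Hugging Face revision (branch, tag, or commit) to download from.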
    pub fn with_hf_revision(mut self, revision: impl ToString) -> Self {
        self.hf_revision = Some(revision.to_string());
        self
    }

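    /// Set the maximum number of sequences that may run concurrently. Defaults to 32.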
    pub fn with_max_num_seqs(mut self, max_num_seqs: usize) -> Self {
        self.max_num_seqs = max_num_seqs;
        self
    }

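    /// Enable logging.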
    pub fn with_logging(mut self) -> Self {
        self.with_logging = true;
        self
    }

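    /// Build the speech pipeline and return a ready-to-use [`Model`], consuming the builder.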
    pub async fn build(self) -> anyhow::Result<Model> {
        let (pipeline, scheduler_config, add_model_config) = build_speech_pipeline(self).await?;
        Ok(build_model_from_pipeline(pipeline, scheduler_config, add_model_config).await)
    }
}