// mistralrs/speech_model.rs

use mistralrs_core::*;

use crate::model_builder_trait::{build_model_from_pipeline, build_speech_pipeline};
use crate::Model;
6/// Configure a text model with the various parameters for loading, running, and other inference behaviors.
7pub struct SpeechModelBuilder {
8    // Loading model
9    pub(crate) model_id: String,
10    pub(crate) dac_model_id: Option<String>,
11    pub(crate) token_source: TokenSource,
12    pub(crate) hf_revision: Option<String>,
13    pub(crate) cfg: Option<SpeechGenerationConfig>,
14
15    // Model running
16    pub(crate) loader_type: SpeechLoaderType,
17    pub(crate) dtype: ModelDType,
18    pub(crate) force_cpu: bool,
19
20    // Other things
21    pub(crate) max_num_seqs: usize,
22    pub(crate) with_logging: bool,
23}
24
25impl SpeechModelBuilder {
26    /// A few defaults are applied here:
27    /// - Token source is from the cache (.cache/huggingface/token)
28    /// - Maximum number of sequences running is 32
29    pub fn new(model_id: impl ToString, loader_type: SpeechLoaderType) -> Self {
30        Self {
31            model_id: model_id.to_string(),
32            loader_type,
33            dtype: ModelDType::Auto,
34            force_cpu: false,
35            token_source: TokenSource::CacheToken,
36            hf_revision: None,
37            max_num_seqs: 32,
38            with_logging: false,
39            cfg: None,
40            dac_model_id: None,
41        }
42    }
43
44    /// DAC Model ID to load from. If not provided, this is automatically downloaded from the default path for the model.
45    /// This may be a HF hub repo or a local path.
46    pub fn with_dac_model_id(mut self, dac_model_id: String) -> Self {
47        self.dac_model_id = Some(dac_model_id);
48        self
49    }
50
51    /// Load the model in a certain dtype.
52    pub fn with_dtype(mut self, dtype: ModelDType) -> Self {
53        self.dtype = dtype;
54        self
55    }
56
57    /// Force usage of the CPU device. Do not use PagedAttention with this.
58    pub fn with_force_cpu(mut self) -> Self {
59        self.force_cpu = true;
60        self
61    }
62
63    /// Source of the Hugging Face token.
64    pub fn with_token_source(mut self, token_source: TokenSource) -> Self {
65        self.token_source = token_source;
66        self
67    }
68
69    /// Set the revision to use for a Hugging Face remote model.
70    pub fn with_hf_revision(mut self, revision: impl ToString) -> Self {
71        self.hf_revision = Some(revision.to_string());
72        self
73    }
74
75    /// Set the maximum number of sequences which can be run at once.
76    pub fn with_max_num_seqs(mut self, max_num_seqs: usize) -> Self {
77        self.max_num_seqs = max_num_seqs;
78        self
79    }
80
81    /// Enable logging.
82    pub fn with_logging(mut self) -> Self {
83        self.with_logging = true;
84        self
85    }
86
87    pub async fn build(self) -> anyhow::Result<Model> {
88        let (pipeline, scheduler_config, add_model_config) = build_speech_pipeline(self).await?;
89        Ok(build_model_from_pipeline(pipeline, scheduler_config, add_model_config).await)
90    }
91}