mistralrs/
speech_model.rs

1use mistralrs_core::*;
2
3use crate::{best_device, Model};
4
5/// Configure a text model with the various parameters for loading, running, and other inference behaviors.
6pub struct SpeechModelBuilder {
7    // Loading model
8    pub(crate) model_id: String,
9    pub(crate) dac_model_id: Option<String>,
10    pub(crate) token_source: TokenSource,
11    pub(crate) hf_revision: Option<String>,
12    pub(crate) cfg: Option<SpeechGenerationConfig>,
13
14    // Model running
15    pub(crate) loader_type: SpeechLoaderType,
16    pub(crate) dtype: ModelDType,
17    pub(crate) force_cpu: bool,
18
19    // Other things
20    pub(crate) max_num_seqs: usize,
21    pub(crate) with_logging: bool,
22}
23
24impl SpeechModelBuilder {
25    /// A few defaults are applied here:
26    /// - Token source is from the cache (.cache/huggingface/token)
27    /// - Maximum number of sequences running is 32
28    pub fn new(model_id: impl ToString, loader_type: SpeechLoaderType) -> Self {
29        Self {
30            model_id: model_id.to_string(),
31            loader_type,
32            dtype: ModelDType::Auto,
33            force_cpu: false,
34            token_source: TokenSource::CacheToken,
35            hf_revision: None,
36            max_num_seqs: 32,
37            with_logging: false,
38            cfg: None,
39            dac_model_id: None,
40        }
41    }
42
43    /// DAC Model ID to load from. If not provided, this is automatically downloaded from the default path for the model.
44    /// This may be a HF hub repo or a local path.
45    pub fn with_dac_model_id(mut self, dac_model_id: String) -> Self {
46        self.dac_model_id = Some(dac_model_id);
47        self
48    }
49
50    /// Load the model in a certain dtype.
51    pub fn with_dtype(mut self, dtype: ModelDType) -> Self {
52        self.dtype = dtype;
53        self
54    }
55
56    /// Force usage of the CPU device. Do not use PagedAttention with this.
57    pub fn with_force_cpu(mut self) -> Self {
58        self.force_cpu = true;
59        self
60    }
61
62    /// Source of the Hugging Face token.
63    pub fn with_token_source(mut self, token_source: TokenSource) -> Self {
64        self.token_source = token_source;
65        self
66    }
67
68    /// Set the revision to use for a Hugging Face remote model.
69    pub fn with_hf_revision(mut self, revision: impl ToString) -> Self {
70        self.hf_revision = Some(revision.to_string());
71        self
72    }
73
74    /// Set the maximum number of sequences which can be run at once.
75    pub fn with_max_num_seqs(mut self, max_num_seqs: usize) -> Self {
76        self.max_num_seqs = max_num_seqs;
77        self
78    }
79
80    /// Enable logging.
81    pub fn with_logging(mut self) -> Self {
82        self.with_logging = true;
83        self
84    }
85
86    pub async fn build(self) -> anyhow::Result<Model> {
87        if self.with_logging {
88            initialize_logging();
89        }
90
91        let loader = SpeechLoader {
92            model_id: self.model_id,
93            dac_model_id: self.dac_model_id,
94            arch: self.loader_type,
95            cfg: self.cfg,
96        };
97
98        // Load, into a Pipeline
99        let pipeline = loader.load_model_from_hf(
100            self.hf_revision,
101            self.token_source,
102            &self.dtype,
103            &best_device(self.force_cpu)?,
104            !self.with_logging,
105            DeviceMapSetting::Auto(AutoDeviceMapParams::default_text()),
106            None,
107            None,
108        )?;
109
110        let scheduler_method = SchedulerConfig::DefaultScheduler {
111            method: DefaultSchedulerMethod::Fixed(self.max_num_seqs.try_into()?),
112        };
113
114        let runner = MistralRsBuilder::new(pipeline, scheduler_method, false, None);
115
116        Ok(Model::new(runner.build().await))
117    }
118}