// mistralrs_core/pipeline/speech.rs

1use super::text_models_inputs_processor::PagedAttentionMeta;
2use super::{
3    AdapterPaths, AnyMoePipelineMixin, Cache, CacheManagerMixin, EitherCache, ForwardInputsResult,
4    GeneralMetadata, InputProcessorOutput, InputsProcessor, InputsProcessorType, IsqPipelineMixin,
5    Loader, MessagesAction, MetadataMixin, ModelCategory, ModelKind, ModelPaths,
6    PreProcessingMixin, Processor, TokenSource,
7};
8use crate::device_map::DeviceMapper;
9use crate::pipeline::{ChatTemplate, EmbeddingModulePaths, Modalities, SupportedModality};
10use crate::prefix_cacher::PrefixCacheManagerV2;
11use crate::sequence::Sequence;
12use crate::speech_models::{DiaConfig, DiaPipeline, SpeechGenerationOutput, SpeechLoaderType};
13use crate::utils::progress::ProgressScopeGuard;
14use crate::utils::varbuilder_utils::DeviceForLoadTensor;
15use crate::utils::{tokens::get_token, varbuilder_utils::from_mmaped_safetensors};
16use crate::{
17    api_get_file, DeviceMapSetting, MessageContent, PagedAttentionConfig, Pipeline,
18    SpeechGenerationConfig, TryIntoDType,
19};
20use anyhow::Result;
21use candle_core::{Device, Tensor};
22use candle_nn::VarBuilder;
23use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
24use indexmap::IndexMap;
25use mistralrs_quant::IsqType;
26use rand_isaac::Isaac64Rng;
27use regex::Regex;
28use std::any::Any;
29use std::path::PathBuf;
30use std::sync::Arc;
31use tokenizers::Tokenizer;
32use tokio::sync::Mutex;
33use tracing::info;
34
/// Filesystem locations of a speech model's downloaded artifacts.
///
/// `weights` is ordered: the main model safetensors come first and the DAC
/// (audio codec) weights are the final entry — `SpeechLoader` pushes them in
/// that order and `load_model_from_path` relies on it when splitting.
#[derive(Clone, Debug)]
pub struct SpeechModelPaths {
    // Safetensors files; main model weights first, DAC weights last.
    weights: Vec<PathBuf>,
    // Path to the main model's `config.json`.
    config: PathBuf,
}
40
/// Only the config and weight accessors are meaningful for speech models.
/// Every other accessor is `unreachable!`: consumers are expected to downcast
/// the `dyn ModelPaths` to `SpeechModelPaths` via `Any` instead (as
/// `SpeechLoader::load_model_from_path` does).
impl ModelPaths for SpeechModelPaths {
    fn get_config_filename(&self) -> &PathBuf {
        &self.config
    }
    fn get_tokenizer_filename(&self) -> &PathBuf {
        unreachable!("Use `std::any::Any`.")
    }
    fn get_weight_filenames(&self) -> &[PathBuf] {
        &self.weights
    }
    fn get_template_filename(&self) -> &Option<PathBuf> {
        unreachable!("Use `std::any::Any`.")
    }
    fn get_gen_conf_filename(&self) -> Option<&PathBuf> {
        unreachable!("Use `std::any::Any`.")
    }
    fn get_preprocessor_config(&self) -> &Option<PathBuf> {
        unreachable!("Use `std::any::Any`.")
    }
    fn get_processor_config(&self) -> &Option<PathBuf> {
        unreachable!("Use `std::any::Any`.")
    }
    fn get_chat_template_explicit(&self) -> &Option<PathBuf> {
        unreachable!("Use `std::any::Any`.")
    }
    fn get_adapter_paths(&self) -> &AdapterPaths {
        unreachable!("Use `std::any::Any`.")
    }
    fn get_modules(&self) -> Option<&[EmbeddingModulePaths]> {
        unreachable!("Use `std::any::Any`.")
    }
}
73
/// Stateless `Processor` for speech pipelines; see the `Processor` impl below.
pub struct SpeechProcessor;
75
/// Speech generation does not consume chat messages: `process` always bails,
/// and the real work happens in `SpeechInputsProcessor::process_inputs`,
/// which reads each sequence's raw initial prompt.
impl Processor for SpeechProcessor {
    /// Always an error — speech pipelines take raw prompts, not templated chat
    /// messages, so there is no token/text pair to produce here.
    fn process(
        &self,
        _pipeline: &dyn Pipeline,
        _messages: Vec<IndexMap<String, MessageContent>>,
        _add_generation_prompt: bool,
        _add_special_tokens: bool,
        _enable_thinking: Option<bool>,
        _tools: Vec<crate::Tool>,
    ) -> Result<(Vec<u32>, String)> {
        anyhow::bail!(
            "SpeechProcessor::process should not be used. It does not expect chat messages."
        )
    }
    fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
        Arc::new(SpeechInputsProcessor)
    }
    /// No special tokens — there is no tokenizer in this pipeline.
    fn get_special_tokens(&self) -> &[&'static str] {
        &[]
    }
    fn template_action(&self) -> MessagesAction {
        // Just a default
        MessagesAction::FlattenOnlyText
    }
}
101
/// Stateless inputs processor that forwards raw prompt strings to the model.
pub struct SpeechInputsProcessor;
103
/// Inputs handed to `SpeechPipeline::forward_inputs` (via `Box<dyn Any>`):
/// one raw text prompt per input sequence, in sequence order.
#[derive(Clone)]
pub struct ModelInputs {
    // Initial prompt text of each sequence, in the same order as `seq_indices`.
    pub(crate) prompts: Vec<String>,
}
108
109impl InputsProcessor for SpeechInputsProcessor {
110    fn get_type(&self) -> InputsProcessorType {
111        InputsProcessorType::Text
112    }
113
114    fn process_inputs(
115        &self,
116        _tokenizer: Option<Arc<Tokenizer>>,
117        input_seqs: &mut [&mut Sequence],
118        _is_prompt: bool,
119        _is_xlora: bool,
120        _device: &Device,
121        _no_kv_cache: bool,
122        _last_n_context_len: Option<(usize, usize)>,
123        _return_raw_logits: bool,
124        _other_config: Option<Arc<dyn Any>>,
125        _paged_attn_metadata: Option<PagedAttentionMeta>,
126        _mapper: Option<&dyn DeviceMapper>,
127    ) -> Result<InputProcessorOutput> {
128        let inputs = ModelInputs {
129            prompts: input_seqs
130                .iter()
131                .map(|seq| seq.get_initial_prompt().to_string())
132                .collect(),
133        };
134        Ok(InputProcessorOutput {
135            inputs: Box::new(inputs),
136            seq_indices: (0..input_seqs.len()).collect::<Vec<_>>(),
137        })
138    }
139}
140
/// Pipeline wrapping a Dia text-to-speech model.
pub struct SpeechPipeline {
    // Hub model id; also used as the pipeline's `name()`.
    model_id: String,
    model: DiaPipeline,
    metadata: Arc<GeneralMetadata>,
    // Placeholder to satisfy `CacheManagerMixin::cache` — no KV cache is used.
    dummy_cache: EitherCache,
    // Generation settings applied to every prompt.
    cfg: SpeechGenerationConfig,
}
148
/// Loader for speech models (currently only `SpeechLoaderType::Dia`).
pub struct SpeechLoader {
    pub model_id: String,
    // DAC (audio codec) model id; defaults per-architecture when `None`.
    pub dac_model_id: Option<String>,
    pub arch: SpeechLoaderType,
    // Generation config; falls back to the architecture default when `None`.
    pub cfg: Option<SpeechGenerationConfig>,
}
155
impl Loader for SpeechLoader {
    /// Downloads the main model (`model.safetensors` + `config.json`) and the
    /// DAC codec weights from the Hugging Face Hub, then defers to
    /// [`Self::load_model_from_path`].
    ///
    /// The weights vector is ordered main-model-first, DAC-last;
    /// `load_model_from_path` relies on this ordering.
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        // Guard scopes progress reporting to this load (suppressed when `silent`).
        let _progress_guard = ProgressScopeGuard::new(silent);
        let paths: anyhow::Result<Box<dyn ModelPaths>> = {
            // Main weights first, DAC is the final one.
            let mut weights = Vec::new();

            // Main model
            let config = {
                let api = ApiBuilder::new()
                    .with_progress(!silent)
                    .with_token(get_token(&token_source)?)
                    .build()?;
                let revision = revision.clone().unwrap_or("main".to_string());
                let api = api.repo(Repo::with_revision(
                    self.model_id.to_string(),
                    RepoType::Model,
                    revision.clone(),
                ));
                let model_id = std::path::Path::new(&self.model_id);

                // `api_get_file!` resolves against a local path first, then the Hub.
                let weight = api_get_file!(api, "model.safetensors", &model_id);
                let config = api_get_file!(api, "config.json", &model_id);
                weights.push(weight);
                config
            };

            // DAC model
            {
                let api = ApiBuilder::new()
                    .with_progress(!silent)
                    .with_token(get_token(&token_source)?)
                    .build()?;
                // Same revision is applied to the DAC repo as to the main model.
                let revision = revision.unwrap_or("main".to_string());

                // Apply default here
                let dac_model = self
                    .dac_model_id
                    .clone()
                    .unwrap_or_else(|| match self.arch {
                        SpeechLoaderType::Dia => "EricB/dac_44khz".to_string(),
                    });

                let api = api.repo(Repo::with_revision(
                    dac_model.clone(),
                    RepoType::Model,
                    revision.clone(),
                ));
                let model_id = std::path::Path::new(&dac_model);

                let weight = api_get_file!(api, "model.safetensors", &model_id);
                weights.push(weight);
            }

            Ok(Box::new(SpeechModelPaths { weights, config }))
        };
        self.load_model_from_path(
            &paths?,
            dtype,
            device,
            silent,
            mapper,
            in_situ_quant,
            paged_attn_config,
        )
    }

    /// Builds a [`SpeechPipeline`] from already-resolved paths.
    ///
    /// Expects `paths` to downcast to [`SpeechModelPaths`] (panics otherwise).
    /// Device mapping and paged attention are unsupported; KV caching is
    /// disabled in the resulting metadata.
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        _paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let _progress_guard = ProgressScopeGuard::new(silent);
        let paths = &paths
            .as_ref()
            .as_any()
            .downcast_ref::<SpeechModelPaths>()
            .expect("Path downcast failed.");

        if matches!(mapper, DeviceMapSetting::Map(_)) {
            anyhow::bail!("Device mapping is not supported for speech models.")
        }

        // ISQ applies to all tensors (regex ".*") when a quant type is given.
        mistralrs_quant::set_immediate_isq(in_situ_quant, vec![Regex::new(".*")?]);

        let cfg: DiaConfig = serde_json::from_str(&std::fs::read_to_string(&paths.config)?)?;

        // NOTE(review): presumably a perf/compat tweak for CUDA event handling —
        // confirm against the candle CUDA backend before relying on it.
        #[cfg(feature = "cuda")]
        if let Device::Cuda(dev) = &device {
            unsafe { dev.disable_event_tracking() };
        }

        // Dummy (single-device) mapper; used only to resolve the minimum dtype.
        let mapper = DeviceMapSetting::dummy().into_mapper(usize::MAX, device, None)?;
        let dtype = mapper.get_min_dtype(dtype)?;

        // Last weight is the dac.
        let model_weights = paths.weights[..paths.weights.len() - 1].to_vec();
        let vb = from_mmaped_safetensors(
            model_weights,
            Vec::new(),
            Some(dtype),
            device,
            vec![None],
            silent,
            None,
            |_| true,
            Arc::new(|_| DeviceForLoadTensor::Base),
        )?;

        // SAFETY: relies on the safetensors file not being mutated while mmapped,
        // per `VarBuilder::from_mmaped_safetensors`' contract.
        let dac_vb = unsafe {
            VarBuilder::from_mmaped_safetensors(&[paths.weights.last().unwrap()], dtype, device)?
        };

        // Only Dia is supported for now.
        assert_eq!(self.arch, SpeechLoaderType::Dia);

        let model = DiaPipeline::new(&cfg, vb, dac_vb)?;

        Ok(Arc::new(Mutex::new(SpeechPipeline {
            model_id: self.model_id.clone(),
            model,
            metadata: Arc::new(GeneralMetadata {
                max_seq_len: 1024,
                llg_factory: None,
                is_xlora: false,
                no_prefix_cache: false,
                num_hidden_layers: 1, // FIXME(EricLBuehler): we know this is only for caching, so its OK.
                eos_tok: vec![],
                kind: ModelKind::Normal,
                no_kv_cache: true, // NOTE(EricLBuehler): no cache for these.
                activation_dtype: dtype,
                sliding_window: None,
                cache_config: None,
                cache_engine: None,
                model_metadata: None,
                modalities: Modalities {
                    input: vec![SupportedModality::Text],
                    output: vec![SupportedModality::Audio],
                },
            }),
            dummy_cache: EitherCache::Full(Cache::new(0, false)),
            cfg: self
                .cfg
                .unwrap_or_else(|| SpeechGenerationConfig::default(self.arch)),
        })))
    }

    fn get_id(&self) -> String {
        self.model_id.clone()
    }

    fn get_kind(&self) -> ModelKind {
        ModelKind::Normal
    }
}
328
/// No chat template or extra processor config — prompts are passed through raw.
impl PreProcessingMixin for SpeechPipeline {
    fn get_processor(&self) -> Arc<dyn Processor> {
        Arc::new(SpeechProcessor)
    }
    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
        None
    }
    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
        None
    }
}
340
impl IsqPipelineMixin for SpeechPipeline {
    /// Re-quantization after load is not supported; initial ISQ (if any) is
    /// applied once during `load_model_from_path`.
    fn re_isq_model(&mut self, _dtype: IsqType) -> Result<()> {
        anyhow::bail!("Speech models do not support ISQ for now.")
    }
}
346
/// KV-cache management is a no-op: the pipeline's metadata sets
/// `no_kv_cache: true`, so only a dummy cache is exposed.
impl CacheManagerMixin for SpeechPipeline {
    fn clone_in_cache(&self, _seqs: &mut [&mut Sequence]) {}
    fn clone_out_cache(&self, _seqs: &mut [&mut Sequence]) {}
    fn set_none_cache(
        &self,
        _seqs: &mut [&mut Sequence],
        _reset_non_granular: bool,
        _modify_draft_cache: bool,
        _load_preallocated_cache: bool,
    ) {
    }
    fn cache(&self) -> &EitherCache {
        &self.dummy_cache
    }
}
362
impl MetadataMixin for SpeechPipeline {
    fn device(&self) -> Device {
        self.model.device().clone()
    }
    fn get_metadata(&self) -> Arc<GeneralMetadata> {
        self.metadata.clone()
    }
    fn name(&self) -> String {
        self.model_id.clone()
    }
    fn reset_non_granular_state(&self) {}
    /// No tokenizer: speech generation operates on raw prompt strings.
    fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
        None
    }
    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
        None
    }
}
381
382#[async_trait::async_trait]
383impl Pipeline for SpeechPipeline {
384    fn forward_inputs(
385        &mut self,
386        inputs: Box<dyn Any>,
387        return_raw_logits: bool,
388    ) -> candle_core::Result<ForwardInputsResult> {
389        assert!(!return_raw_logits);
390
391        let ModelInputs { prompts } = *inputs.downcast().expect("Downcast failed.");
392        let mut pcms = Vec::new();
393        let mut rates = Vec::new();
394        let mut channels_all = Vec::new();
395        for prompt in prompts {
396            let SpeechGenerationOutput {
397                pcm,
398                rate,
399                channels,
400            } = self.model.generate(&prompt, &self.cfg)?;
401            pcms.push(pcm);
402            rates.push(rate);
403            channels_all.push(channels);
404        }
405
406        Ok(ForwardInputsResult::Speech {
407            pcms,
408            rates,
409            channels: channels_all,
410        })
411    }
412
413    async fn sample_causal_gen(
414        &self,
415        _seqs: &mut [&mut Sequence],
416        _logits: Vec<Tensor>,
417        _prefix_cacher: &mut PrefixCacheManagerV2,
418        _disable_eos_stop: bool,
419        _srng: Arc<std::sync::Mutex<Isaac64Rng>>,
420    ) -> Result<(), candle_core::Error> {
421        candle_core::bail!("`sample_causal_gen` is incompatible with `SpeechPipeline`");
422    }
423
424    fn category(&self) -> ModelCategory {
425        ModelCategory::Speech
426    }
427}
428
// AnyMoE is not applicable to speech models; the default (no-op) impl suffices.
impl AnyMoePipelineMixin for SpeechPipeline {}