// mistralrs_core/pipeline/speech.rs

1use super::text_models_inputs_processor::PagedAttentionMeta;
2use super::{
3    AdapterPaths, AnyMoePipelineMixin, Cache, CacheManagerMixin, EitherCache, ForwardInputsResult,
4    GeneralMetadata, InputProcessorOutput, InputsProcessor, InputsProcessorType, IsqPipelineMixin,
5    Loader, MessagesAction, MetadataMixin, ModelCategory, ModelKind, ModelPaths,
6    PreProcessingMixin, Processor, TokenSource,
7};
8use crate::device_map::DeviceMapper;
9use crate::pipeline::{ChatTemplate, EmbeddingModulePaths, Modalities, SupportedModality};
10use crate::prefix_cacher::PrefixCacheManagerV2;
11use crate::sequence::Sequence;
12use crate::speech_models::{DiaConfig, DiaPipeline, SpeechGenerationOutput, SpeechLoaderType};
13use crate::utils::varbuilder_utils::DeviceForLoadTensor;
14use crate::utils::{tokens::get_token, varbuilder_utils::from_mmaped_safetensors};
15use crate::{
16    api_get_file, DeviceMapSetting, MessageContent, PagedAttentionConfig, Pipeline,
17    SpeechGenerationConfig, TryIntoDType,
18};
19use anyhow::Result;
20use candle_core::{Device, Tensor};
21use candle_nn::VarBuilder;
22use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
23use indexmap::IndexMap;
24use mistralrs_quant::IsqType;
25use rand_isaac::Isaac64Rng;
26use regex::Regex;
27use std::any::Any;
28use std::path::PathBuf;
29use std::sync::Arc;
30use tokenizers::Tokenizer;
31use tokio::sync::Mutex;
32use tracing::info;
33
/// File paths gathered for a speech model: the safetensors weight files plus
/// the model's `config.json`. Only these two are meaningful for speech
/// pipelines; all other `ModelPaths` accessors are intentionally unsupported.
#[derive(Clone, Debug)]
pub struct SpeechModelPaths {
    // Main model weights come first; the DAC (audio codec) weights are the
    // final entry (see `SpeechLoader::load_model_from_hf`).
    weights: Vec<PathBuf>,
    config: PathBuf,
}
39
impl ModelPaths for SpeechModelPaths {
    fn get_config_filename(&self) -> &PathBuf {
        &self.config
    }
    // Speech models carry no tokenizer, chat template, processor configs, or
    // adapters. Consumers are expected to downcast to `SpeechModelPaths` via
    // `std::any::Any` rather than call these accessors, so each one panics
    // with `unreachable!` to surface misuse immediately.
    fn get_tokenizer_filename(&self) -> &PathBuf {
        unreachable!("Use `std::any::Any`.")
    }
    fn get_weight_filenames(&self) -> &[PathBuf] {
        &self.weights
    }
    fn get_template_filename(&self) -> &Option<PathBuf> {
        unreachable!("Use `std::any::Any`.")
    }
    fn get_gen_conf_filename(&self) -> Option<&PathBuf> {
        unreachable!("Use `std::any::Any`.")
    }
    fn get_preprocessor_config(&self) -> &Option<PathBuf> {
        unreachable!("Use `std::any::Any`.")
    }
    fn get_processor_config(&self) -> &Option<PathBuf> {
        unreachable!("Use `std::any::Any`.")
    }
    fn get_chat_template_explicit(&self) -> &Option<PathBuf> {
        unreachable!("Use `std::any::Any`.")
    }
    fn get_adapter_paths(&self) -> &AdapterPaths {
        unreachable!("Use `std::any::Any`.")
    }
    fn get_modules(&self) -> Option<&[EmbeddingModulePaths]> {
        unreachable!("Use `std::any::Any`.")
    }
}
72
/// Processor for speech pipelines. Speech generation does not consume chat
/// messages, so `process` always errors; only `inputs_processor` is useful.
pub struct SpeechProcessor;
74
impl Processor for SpeechProcessor {
    /// Always fails: speech pipelines take raw prompt strings, not chat
    /// messages, so there is nothing to template or tokenize here.
    fn process(
        &self,
        _pipeline: &dyn Pipeline,
        _messages: Vec<IndexMap<String, MessageContent>>,
        _add_generation_prompt: bool,
        _add_special_tokens: bool,
        _enable_thinking: Option<bool>,
        _tools: Vec<crate::Tool>,
    ) -> Result<(Vec<u32>, String)> {
        anyhow::bail!(
            "SpeechProcessor::process should not be used. It does not expect chat messages."
        )
    }
    fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
        Arc::new(SpeechInputsProcessor)
    }
    // No special tokens: there is no tokenizer in this pipeline.
    fn get_special_tokens(&self) -> &[&'static str] {
        &[]
    }
    fn template_action(&self) -> MessagesAction {
        // Just a default
        MessagesAction::FlattenOnlyText
    }
}
100
/// Inputs processor that forwards each sequence's raw initial prompt string
/// to the model without any tokenization.
pub struct SpeechInputsProcessor;
102
#[derive(Clone)]
pub struct ModelInputs {
    // Raw text prompts, one per input sequence (each sequence's initial prompt).
    pub(crate) prompts: Vec<String>,
}
107
108impl InputsProcessor for SpeechInputsProcessor {
109    fn get_type(&self) -> InputsProcessorType {
110        InputsProcessorType::Text
111    }
112
113    fn process_inputs(
114        &self,
115        _tokenizer: Option<Arc<Tokenizer>>,
116        input_seqs: &mut [&mut Sequence],
117        _is_prompt: bool,
118        _is_xlora: bool,
119        _device: &Device,
120        _no_kv_cache: bool,
121        _last_n_context_len: Option<(usize, usize)>,
122        _return_raw_logits: bool,
123        _other_config: Option<Arc<dyn Any>>,
124        _paged_attn_metadata: Option<PagedAttentionMeta>,
125        _mapper: Option<&dyn DeviceMapper>,
126    ) -> Result<InputProcessorOutput> {
127        let inputs = ModelInputs {
128            prompts: input_seqs
129                .iter()
130                .map(|seq| seq.get_initial_prompt().to_string())
131                .collect(),
132        };
133        Ok(InputProcessorOutput {
134            inputs: Box::new(inputs),
135            seq_indices: (0..input_seqs.len()).collect::<Vec<_>>(),
136        })
137    }
138}
139
/// Pipeline wrapping a text-to-speech model (currently only Dia).
pub struct SpeechPipeline {
    model_id: String,
    model: DiaPipeline,
    metadata: Arc<GeneralMetadata>,
    // Placeholder cache: speech models run without a KV cache, but the
    // `CacheManagerMixin` trait still requires something to return.
    dummy_cache: EitherCache,
    cfg: SpeechGenerationConfig,
}
147
/// Loader for speech models.
pub struct SpeechLoader {
    /// Hub model ID (or local path) of the main speech model.
    pub model_id: String,
    /// Optional override for the DAC (audio codec) model ID; a per-arch
    /// default is applied when `None`.
    pub dac_model_id: Option<String>,
    /// Which speech architecture to load (only Dia is supported currently).
    pub arch: SpeechLoaderType,
    /// Optional generation config; falls back to the arch default when `None`.
    pub cfg: Option<SpeechGenerationConfig>,
}
154
impl Loader for SpeechLoader {
    /// Resolves the main model weights + `config.json` and the DAC codec
    /// weights from the Hugging Face Hub (or local cache), then delegates to
    /// [`Self::load_model_from_path`].
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let paths: anyhow::Result<Box<dyn ModelPaths>> = {
            // Main weights first, DAC is the final one.
            let mut weights = Vec::new();

            // Main model
            let config = {
                let api = ApiBuilder::new()
                    .with_progress(!silent)
                    .with_token(get_token(&token_source)?)
                    .build()?;
                let revision = revision.clone().unwrap_or("main".to_string());
                let api = api.repo(Repo::with_revision(
                    self.model_id.to_string(),
                    RepoType::Model,
                    revision.clone(),
                ));
                let model_id = std::path::Path::new(&self.model_id);

                // `api_get_file!` falls back to a local path when `model_id`
                // points at a directory on disk.
                let weight = api_get_file!(api, "model.safetensors", &model_id);
                let config = api_get_file!(api, "config.json", &model_id);
                weights.push(weight);
                config
            };

            // DAC model
            {
                let api = ApiBuilder::new()
                    .with_progress(!silent)
                    .with_token(get_token(&token_source)?)
                    .build()?;
                let revision = revision.unwrap_or("main".to_string());

                // Apply default here
                let dac_model = self
                    .dac_model_id
                    .clone()
                    .unwrap_or_else(|| match self.arch {
                        SpeechLoaderType::Dia => "EricB/dac_44khz".to_string(),
                    });

                let api = api.repo(Repo::with_revision(
                    dac_model.clone(),
                    RepoType::Model,
                    revision.clone(),
                ));
                let model_id = std::path::Path::new(&dac_model);

                let weight = api_get_file!(api, "model.safetensors", &model_id);
                weights.push(weight);
            }

            Ok(Box::new(SpeechModelPaths { weights, config }))
        };
        self.load_model_from_path(
            &paths?,
            dtype,
            device,
            silent,
            mapper,
            in_situ_quant,
            paged_attn_config,
        )
    }

    /// Builds the `SpeechPipeline` from already-resolved paths: parses the Dia
    /// config, mmaps the main weights and DAC weights, and assembles metadata.
    ///
    /// Device mapping and paged attention are not supported for speech models.
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        _paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        // `SpeechModelPaths` is the only supported paths type here; see the
        // `unreachable!`s on its `ModelPaths` impl.
        let paths = &paths
            .as_ref()
            .as_any()
            .downcast_ref::<SpeechModelPaths>()
            .expect("Path downcast failed.");

        if matches!(mapper, DeviceMapSetting::Map(_)) {
            anyhow::bail!("Device mapping is not supported for speech models.")
        }

        // Apply ISQ to every tensor (the `.*` regex matches all names).
        mistralrs_quant::set_immediate_isq(in_situ_quant, vec![Regex::new(".*")?]);

        let cfg: DiaConfig = serde_json::from_str(&std::fs::read_to_string(&paths.config)?)?;

        #[cfg(feature = "cuda")]
        if let Device::Cuda(dev) = &device {
            // NOTE(review): presumably a perf/compat tweak for this workload —
            // confirm against candle's CUDA event-tracking docs.
            unsafe { dev.disable_event_tracking() };
        }

        // Dummy single-device mapper, only used to resolve the minimum dtype.
        let mapper = DeviceMapSetting::dummy().into_mapper(usize::MAX, device, None)?;
        let dtype = mapper.get_min_dtype(dtype)?;

        // Last weight is the dac.
        let model_weights = paths.weights[..paths.weights.len() - 1].to_vec();
        let vb = from_mmaped_safetensors(
            model_weights,
            Vec::new(),
            Some(dtype),
            device,
            vec![None],
            silent,
            None,
            |_| true,
            Arc::new(|_| DeviceForLoadTensor::Base),
        )?;

        // SAFETY: mmapped safetensors require the underlying files to remain
        // unmodified for the lifetime of the VarBuilder.
        let dac_vb = unsafe {
            VarBuilder::from_mmaped_safetensors(&[paths.weights.last().unwrap()], dtype, device)?
        };

        // Only Dia is supported for now.
        assert_eq!(self.arch, SpeechLoaderType::Dia);

        let model = DiaPipeline::new(&cfg, vb, dac_vb)?;

        Ok(Arc::new(Mutex::new(SpeechPipeline {
            model_id: self.model_id.clone(),
            model,
            metadata: Arc::new(GeneralMetadata {
                max_seq_len: 1024,
                llg_factory: None,
                is_xlora: false,
                no_prefix_cache: false,
                num_hidden_layers: 1, // FIXME(EricLBuehler): we know this is only for caching, so its OK.
                eos_tok: vec![],
                kind: ModelKind::Normal,
                no_kv_cache: true, // NOTE(EricLBuehler): no cache for these.
                activation_dtype: dtype,
                sliding_window: None,
                cache_config: None,
                cache_engine: None,
                model_metadata: None,
                modalities: Modalities {
                    input: vec![SupportedModality::Text],
                    output: vec![SupportedModality::Audio],
                },
            }),
            dummy_cache: EitherCache::Full(Cache::new(0, false)),
            cfg: self
                .cfg
                .unwrap_or_else(|| SpeechGenerationConfig::default(self.arch)),
        })))
    }

    fn get_id(&self) -> String {
        self.model_id.clone()
    }

    fn get_kind(&self) -> ModelKind {
        ModelKind::Normal
    }
}
325
impl PreProcessingMixin for SpeechPipeline {
    fn get_processor(&self) -> Arc<dyn Processor> {
        Arc::new(SpeechProcessor)
    }
    // No chat template or extra processor config: prompts are raw strings.
    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
        None
    }
    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
        None
    }
}
337
impl IsqPipelineMixin for SpeechPipeline {
    // Re-quantization after load is not implemented for speech models
    // (immediate ISQ at load time is still applied by the loader).
    fn re_isq_model(&mut self, _dtype: IsqType) -> Result<()> {
        anyhow::bail!("Speech models do not support ISQ for now.")
    }
}
343
impl CacheManagerMixin for SpeechPipeline {
    // Speech models run without a KV cache (`no_kv_cache: true` in metadata),
    // so all cache management operations are no-ops.
    fn clone_in_cache(&self, _seqs: &mut [&mut Sequence]) {}
    fn clone_out_cache(&self, _seqs: &mut [&mut Sequence]) {}
    fn set_none_cache(
        &self,
        _seqs: &mut [&mut Sequence],
        _reset_non_granular: bool,
        _modify_draft_cache: bool,
        _load_preallocated_cache: bool,
    ) {
    }
    // An empty placeholder cache to satisfy the trait.
    fn cache(&self) -> &EitherCache {
        &self.dummy_cache
    }
}
359
impl MetadataMixin for SpeechPipeline {
    fn device(&self) -> Device {
        self.model.device().clone()
    }
    fn get_metadata(&self) -> Arc<GeneralMetadata> {
        self.metadata.clone()
    }
    fn name(&self) -> String {
        self.model_id.clone()
    }
    fn reset_non_granular_state(&self) {}
    // No tokenizer and no device mapper: speech pipelines consume raw text
    // and run on a single device.
    fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
        None
    }
    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
        None
    }
}
378
379#[async_trait::async_trait]
380impl Pipeline for SpeechPipeline {
381    fn forward_inputs(
382        &mut self,
383        inputs: Box<dyn Any>,
384        return_raw_logits: bool,
385    ) -> candle_core::Result<ForwardInputsResult> {
386        assert!(!return_raw_logits);
387
388        let ModelInputs { prompts } = *inputs.downcast().expect("Downcast failed.");
389        let mut pcms = Vec::new();
390        let mut rates = Vec::new();
391        let mut channels_all = Vec::new();
392        for prompt in prompts {
393            let SpeechGenerationOutput {
394                pcm,
395                rate,
396                channels,
397            } = self.model.generate(&prompt, &self.cfg)?;
398            pcms.push(pcm);
399            rates.push(rate);
400            channels_all.push(channels);
401        }
402
403        Ok(ForwardInputsResult::Speech {
404            pcms,
405            rates,
406            channels: channels_all,
407        })
408    }
409
410    async fn sample_causal_gen(
411        &self,
412        _seqs: &mut [&mut Sequence],
413        _logits: Vec<Tensor>,
414        _prefix_cacher: &mut PrefixCacheManagerV2,
415        _disable_eos_stop: bool,
416        _srng: Arc<std::sync::Mutex<Isaac64Rng>>,
417    ) -> Result<(), candle_core::Error> {
418        candle_core::bail!("`sample_causal_gen` is incompatible with `SpeechPipeline`");
419    }
420
421    fn category(&self) -> ModelCategory {
422        ModelCategory::Speech
423    }
424}
425
// Speech models do not participate in AnyMoE; the trait's default impls apply.
impl AnyMoePipelineMixin for SpeechPipeline {}