// mistralrs_core/pipeline/speech.rs

1use super::text_models_inputs_processor::PagedAttentionMeta;
2use super::{
3    AdapterPaths, AnyMoePipelineMixin, Cache, CacheManagerMixin, EitherCache, ForwardInputsResult,
4    GeneralMetadata, InputProcessorOutput, InputsProcessor, InputsProcessorType, IsqPipelineMixin,
5    Loader, MessagesAction, MetadataMixin, ModelCategory, ModelKind, ModelPaths,
6    PreProcessingMixin, Processor, TokenSource,
7};
8use crate::device_map::DeviceMapper;
9use crate::pipeline::{ChatTemplate, Modalities, SupportedModality};
10use crate::prefix_cacher::PrefixCacheManagerV2;
11use crate::sequence::Sequence;
12use crate::speech_models::{DiaConfig, DiaPipeline, SpeechGenerationOutput, SpeechLoaderType};
13use crate::utils::varbuilder_utils::DeviceForLoadTensor;
14use crate::utils::{tokens::get_token, varbuilder_utils::from_mmaped_safetensors};
15use crate::{
16    api_get_file, DeviceMapSetting, MessageContent, PagedAttentionConfig, Pipeline,
17    SpeechGenerationConfig, TryIntoDType,
18};
19use anyhow::Result;
20use candle_core::{Device, Tensor};
21use candle_nn::VarBuilder;
22use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
23use indexmap::IndexMap;
24use mistralrs_quant::IsqType;
25use rand_isaac::Isaac64Rng;
26use regex::Regex;
27use std::any::Any;
28use std::num::NonZeroUsize;
29use std::path::PathBuf;
30use std::sync::Arc;
31use tokenizers::Tokenizer;
32use tokio::sync::Mutex;
33use tracing::info;
34
/// On-disk artifacts for a speech model.
///
/// Invariant: `weights` stores the main model weight file(s) first, with the
/// DAC weights as the *final* entry — `SpeechLoader::load_model_from_path`
/// splits the last element off for the DAC.
#[derive(Clone, Debug)]
pub struct SpeechModelPaths {
    // Main model safetensors first, DAC safetensors last.
    weights: Vec<PathBuf>,
    // Path to the main model's `config.json`.
    config: PathBuf,
}
40
/// Only `get_config_filename` and `get_weight_filenames` are meaningful for
/// speech models. Every other accessor panics: consumers are expected to
/// downcast to `SpeechModelPaths` via `std::any::Any` (as
/// `SpeechLoader::load_model_from_path` does) rather than go through the
/// generic `ModelPaths` surface.
impl ModelPaths for SpeechModelPaths {
    fn get_config_filename(&self) -> &PathBuf {
        &self.config
    }
    fn get_tokenizer_filename(&self) -> &PathBuf {
        // Speech models have no tokenizer file.
        unreachable!("Use `std::any::Any`.")
    }
    fn get_weight_filenames(&self) -> &[PathBuf] {
        &self.weights
    }
    fn get_template_filename(&self) -> &Option<PathBuf> {
        unreachable!("Use `std::any::Any`.")
    }
    fn get_gen_conf_filename(&self) -> Option<&PathBuf> {
        unreachable!("Use `std::any::Any`.")
    }
    fn get_preprocessor_config(&self) -> &Option<PathBuf> {
        unreachable!("Use `std::any::Any`.")
    }
    fn get_processor_config(&self) -> &Option<PathBuf> {
        unreachable!("Use `std::any::Any`.")
    }
    fn get_chat_template_explicit(&self) -> &Option<PathBuf> {
        unreachable!("Use `std::any::Any`.")
    }
    fn get_adapter_paths(&self) -> &AdapterPaths {
        unreachable!("Use `std::any::Any`.")
    }
}
70
/// `Processor` for speech pipelines.
///
/// Speech generation consumes a raw prompt string, not a chat transcript, so
/// the chat-message entry point (`process`) always errors and the real work
/// happens in [`SpeechInputsProcessor`].
pub struct SpeechProcessor;

impl Processor for SpeechProcessor {
    /// Always fails: speech models do not accept chat messages.
    fn process(
        &self,
        _pipeline: &dyn Pipeline,
        _messages: Vec<IndexMap<String, MessageContent>>,
        _add_generation_prompt: bool,
        _add_special_tokens: bool,
        _enable_thinking: Option<bool>,
        _tools: Vec<crate::Tool>,
    ) -> Result<(Vec<u32>, String)> {
        anyhow::bail!(
            "SpeechProcessor::process should not be used. It does not expect chat messages."
        )
    }
    fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
        Arc::new(SpeechInputsProcessor)
    }
    fn get_special_tokens(&self) -> &[&'static str] {
        // No special tokens: there is no tokenizer for speech models.
        &[]
    }
    fn template_action(&self) -> MessagesAction {
        // Just a default — templates are never applied for speech models.
        MessagesAction::FlattenOnlyText
    }
}
98
/// Inputs processor that forwards each sequence's initial prompt text
/// verbatim — no tokenization is performed.
pub struct SpeechInputsProcessor;

/// The batched inputs handed to `SpeechPipeline::forward_inputs`:
/// one raw prompt string per input sequence, in sequence order.
#[derive(Clone)]
pub struct ModelInputs {
    pub(crate) prompts: Vec<String>,
}
105
106impl InputsProcessor for SpeechInputsProcessor {
107    fn get_type(&self) -> InputsProcessorType {
108        InputsProcessorType::Text
109    }
110
111    fn process_inputs(
112        &self,
113        _tokenizer: Option<Arc<Tokenizer>>,
114        input_seqs: &mut [&mut Sequence],
115        _is_prompt: bool,
116        _is_xlora: bool,
117        _device: &Device,
118        _no_kv_cache: bool,
119        _last_n_context_len: Option<(usize, usize)>,
120        _return_raw_logits: bool,
121        _other_config: Option<Arc<dyn Any>>,
122        _paged_attn_metadata: Option<PagedAttentionMeta>,
123        prompt_chunksize: Option<NonZeroUsize>,
124        _mapper: Option<&dyn DeviceMapper>,
125    ) -> Box<dyn Iterator<Item = Result<InputProcessorOutput>>> {
126        let make_value = if prompt_chunksize.is_some() {
127            return Box::new(std::iter::once(Err(anyhow::Error::msg(
128                "Prompt batching is unsupported for speech models",
129            ))));
130        } else {
131            || {
132                let inputs = ModelInputs {
133                    prompts: input_seqs
134                        .iter()
135                        .map(|seq| seq.get_initial_prompt().to_string())
136                        .collect(),
137                };
138                Ok(InputProcessorOutput {
139                    inputs: Box::new(inputs),
140                    seq_indices: (0..input_seqs.len()).collect::<Vec<_>>(),
141                })
142            }
143        };
144        Box::new(std::iter::once(make_value()))
145    }
146}
147
/// A loaded speech-generation pipeline (currently Dia only).
pub struct SpeechPipeline {
    // Original model id, reported by `MetadataMixin::name`.
    model_id: String,
    model: DiaPipeline,
    metadata: Arc<GeneralMetadata>,
    // Placeholder returned by `CacheManagerMixin::cache`; speech models run
    // with `no_kv_cache = true` so this is never populated.
    dummy_cache: EitherCache,
    // Generation settings used for every `forward_inputs` call.
    cfg: SpeechGenerationConfig,
}
155
/// `Loader` for speech models.
pub struct SpeechLoader {
    /// HF Hub repo id (or local path) of the main model.
    pub model_id: String,
    /// Optional DAC weights repo; when `None`, a per-architecture default is
    /// used (e.g. `EricB/dac_44khz` for Dia).
    pub dac_model_id: Option<String>,
    /// Target architecture; only `SpeechLoaderType::Dia` is supported today.
    pub arch: SpeechLoaderType,
    /// Optional generation config; falls back to the arch default at load time.
    pub cfg: Option<SpeechGenerationConfig>,
}
162
impl Loader for SpeechLoader {
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        // Resolve local files (downloading from the HF Hub as needed), then
        // delegate construction to `load_model_from_path`.
        let paths: anyhow::Result<Box<dyn ModelPaths>> = {
            // Main weights first, DAC is the final one.
            let mut weights = Vec::new();

            // Main model: fetch `model.safetensors` and `config.json`.
            let config = {
                let api = ApiBuilder::new()
                    .with_progress(!silent)
                    .with_token(get_token(&token_source)?)
                    .build()?;
                let revision = revision.clone().unwrap_or("main".to_string());
                let api = api.repo(Repo::with_revision(
                    self.model_id.to_string(),
                    RepoType::Model,
                    revision.clone(),
                ));
                let model_id = std::path::Path::new(&self.model_id);

                let weight = api_get_file!(api, "model.safetensors", &model_id);
                let config = api_get_file!(api, "config.json", &model_id);
                weights.push(weight);
                config
            };

            // DAC model: weights only, pushed last so `load_model_from_path`
            // can split them off the end of `weights`.
            {
                let api = ApiBuilder::new()
                    .with_progress(!silent)
                    .with_token(get_token(&token_source)?)
                    .build()?;
                let revision = revision.unwrap_or("main".to_string());

                // Apply default here: each architecture has a known companion
                // DAC repo used when the caller did not specify one.
                let dac_model = self
                    .dac_model_id
                    .clone()
                    .unwrap_or_else(|| match self.arch {
                        SpeechLoaderType::Dia => "EricB/dac_44khz".to_string(),
                    });

                let api = api.repo(Repo::with_revision(
                    dac_model.clone(),
                    RepoType::Model,
                    revision.clone(),
                ));
                let model_id = std::path::Path::new(&dac_model);

                let weight = api_get_file!(api, "model.safetensors", &model_id);
                weights.push(weight);
            }

            Ok(Box::new(SpeechModelPaths { weights, config }))
        };
        self.load_model_from_path(
            &paths?,
            dtype,
            device,
            silent,
            mapper,
            in_situ_quant,
            paged_attn_config,
        )
    }

    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        _paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        // Speech paths are only usable via their concrete type; see the
        // `ModelPaths` impl above.
        let paths = &paths
            .as_ref()
            .as_any()
            .downcast_ref::<SpeechModelPaths>()
            .expect("Path downcast failed.");

        if matches!(mapper, DeviceMapSetting::Map(_)) {
            anyhow::bail!("Device mapping is not supported for speech models.")
        }

        // `".*"` matches every tensor name, so ISQ (if requested) is applied
        // model-wide as tensors are loaded.
        mistralrs_quant::set_immediate_isq(in_situ_quant, vec![Regex::new(".*")?]);

        let cfg: DiaConfig = serde_json::from_str(&std::fs::read_to_string(&paths.config)?)?;

        // Dummy single-device mapper — real device mapping was rejected above;
        // it is only used here to pick the minimum activation dtype.
        let mapper = DeviceMapSetting::dummy().into_mapper(usize::MAX, device, None)?;
        let dtype = mapper.get_min_dtype(dtype)?;

        // Last weight is the dac.
        let model_weights = paths.weights[..paths.weights.len() - 1].to_vec();
        let vb = from_mmaped_safetensors(
            model_weights,
            Vec::new(),
            Some(dtype),
            device,
            vec![None],
            silent,
            None,
            |_| true,
            Arc::new(|_| DeviceForLoadTensor::Base),
        )?;

        // SAFETY: memory-maps the DAC safetensors file. Sound as long as the
        // file is not modified while mapped; it was just fetched and is only
        // read here.
        let dac_vb = unsafe {
            VarBuilder::from_mmaped_safetensors(&[paths.weights.last().unwrap()], dtype, device)?
        };

        // Only Dia is supported for now.
        assert_eq!(self.arch, SpeechLoaderType::Dia);

        let model = DiaPipeline::new(&cfg, vb, dac_vb)?;

        Ok(Arc::new(Mutex::new(SpeechPipeline {
            model_id: self.model_id.clone(),
            model,
            metadata: Arc::new(GeneralMetadata {
                max_seq_len: 1024,
                llg_factory: None,
                is_xlora: false,
                no_prefix_cache: false,
                num_hidden_layers: 1, // FIXME(EricLBuehler): we know this is only for caching, so its OK.
                eos_tok: vec![],
                kind: ModelKind::Normal,
                no_kv_cache: true, // NOTE(EricLBuehler): no cache for these.
                activation_dtype: dtype,
                sliding_window: None,
                cache_config: None,
                cache_engine: None,
                prompt_chunksize: None,
                model_metadata: None,
                modalities: Modalities {
                    input: vec![SupportedModality::Text],
                    output: vec![SupportedModality::Audio],
                },
            }),
            dummy_cache: EitherCache::Full(Cache::new(0, false)),
            // Fall back to the architecture's default generation config when
            // the caller did not supply one.
            cfg: self
                .cfg
                .unwrap_or_else(|| SpeechGenerationConfig::default(self.arch)),
        })))
    }

    fn get_id(&self) -> String {
        self.model_id.clone()
    }

    fn get_kind(&self) -> ModelKind {
        ModelKind::Normal
    }
}
329
impl PreProcessingMixin for SpeechPipeline {
    fn get_processor(&self) -> Arc<dyn Processor> {
        Arc::new(SpeechProcessor)
    }
    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
        // Speech models do not use chat templates.
        None
    }
    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
        None
    }
}
341
impl IsqPipelineMixin for SpeechPipeline {
    /// Re-quantization after load is not supported for speech models.
    fn re_isq_model(&mut self, _dtype: IsqType) -> Result<()> {
        anyhow::bail!("Speech models do not support ISQ for now.")
    }
}
347
/// All cache operations are no-ops: the pipeline is created with
/// `no_kv_cache = true`, so there is never anything to clone or reset.
impl CacheManagerMixin for SpeechPipeline {
    fn clone_in_cache(&self, _seqs: &mut [&mut Sequence]) {}
    fn clone_out_cache(&self, _seqs: &mut [&mut Sequence]) {}
    fn set_none_cache(
        &self,
        _seqs: &mut [&mut Sequence],
        _reset_non_granular: bool,
        _modify_draft_cache: bool,
        _load_preallocated_cache: bool,
    ) {
    }
    fn cache(&self) -> &EitherCache {
        // Empty placeholder cache, built in `load_model_from_path`.
        &self.dummy_cache
    }
}
363
impl MetadataMixin for SpeechPipeline {
    fn device(&self) -> Device {
        self.model.device().clone()
    }
    fn get_metadata(&self) -> Arc<GeneralMetadata> {
        self.metadata.clone()
    }
    fn name(&self) -> String {
        self.model_id.clone()
    }
    fn reset_non_granular_state(&self) {}
    fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
        // Speech models have no tokenizer.
        None
    }
    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
        // Device mapping is rejected at load time, so there is no mapper.
        None
    }
}
382
383#[async_trait::async_trait]
384impl Pipeline for SpeechPipeline {
385    fn forward_inputs(
386        &mut self,
387        inputs: Box<dyn Any>,
388        return_raw_logits: bool,
389    ) -> candle_core::Result<ForwardInputsResult> {
390        assert!(!return_raw_logits);
391
392        let ModelInputs { prompts } = *inputs.downcast().expect("Downcast failed.");
393        let mut pcms = Vec::new();
394        let mut rates = Vec::new();
395        let mut channels_all = Vec::new();
396        for prompt in prompts {
397            let SpeechGenerationOutput {
398                pcm,
399                rate,
400                channels,
401            } = self.model.generate(&prompt, &self.cfg)?;
402            pcms.push(pcm);
403            rates.push(rate);
404            channels_all.push(channels);
405        }
406
407        Ok(ForwardInputsResult::Speech {
408            pcms,
409            rates,
410            channels: channels_all,
411        })
412    }
413
414    async fn sample_causal_gen(
415        &self,
416        _seqs: &mut [&mut Sequence],
417        _logits: Vec<Tensor>,
418        _prefix_cacher: &mut PrefixCacheManagerV2,
419        _disable_eos_stop: bool,
420        _srng: Arc<std::sync::Mutex<Isaac64Rng>>,
421    ) -> Result<(), candle_core::Error> {
422        candle_core::bail!("`sample_causal_gen` is incompatible with `SpeechPipeline`");
423    }
424
425    fn category(&self) -> ModelCategory {
426        ModelCategory::Speech
427    }
428}
429
// Rely entirely on `AnyMoePipelineMixin`'s default method implementations.
impl AnyMoePipelineMixin for SpeechPipeline {}