// mistralrs_core/pipeline/speech.rs

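//! Speech generation pipeline: accepts a raw text prompt and produces PCM
//! audio. Currently backed by the Dia model paired with a DAC audio codec.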
use super::text_models_inputs_processor::PagedAttentionMeta;
use super::{
    AdapterPaths, AnyMoePipelineMixin, Cache, CacheManagerMixin, EitherCache, ForwardInputsResult,
    GeneralMetadata, InputProcessorOutput, InputsProcessor, InputsProcessorType, IsqPipelineMixin,
    Loader, MessagesAction, MetadataMixin, ModelCategory, ModelKind, ModelPaths,
    PreProcessingMixin, Processor, TokenSource,
};
use crate::device_map::DeviceMapper;
use crate::pipeline::{ChatTemplate, Modalities, SupportedModality};
use crate::prefix_cacher::PrefixCacheManagerV2;
use crate::sequence::Sequence;
use crate::speech_models::{DiaConfig, DiaPipeline, SpeechGenerationOutput, SpeechLoaderType};
use crate::utils::varbuilder_utils::DeviceForLoadTensor;
use crate::utils::{tokens::get_token, varbuilder_utils::from_mmaped_safetensors};
use crate::{
    api_get_file, DeviceMapSetting, MessageContent, PagedAttentionConfig, Pipeline,
    SpeechGenerationConfig, TryIntoDType,
};
use anyhow::Result;
use candle_core::{Device, Tensor};
use candle_nn::VarBuilder;
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
use indexmap::IndexMap;
use mistralrs_quant::IsqType;
use rand_isaac::Isaac64Rng;
use regex::Regex;
use std::any::Any;
use std::path::PathBuf;
use std::sync::Arc;
use tokenizers::Tokenizer;
use tokio::sync::Mutex;

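/// Locations of the files needed to load a speech model: safetensors weight
/// files plus the model's `config.json`. The main model weights come first in
/// `weights`; the DAC (audio codec) weights are the final entry.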
#[derive(Clone, Debug)]
pub struct SpeechModelPaths {
    weights: Vec<PathBuf>,
    config: PathBuf,
}

impl ModelPaths for SpeechModelPaths {
    fn get_config_filename(&self) -> &PathBuf {
        &self.config
    }
    fn get_tokenizer_filename(&self) -> &PathBuf {
        unreachable!("Use `std::any::Any`.")
    }
    fn get_weight_filenames(&self) -> &[PathBuf] {
        &self.weights
    }
    fn get_template_filename(&self) -> &Option<PathBuf> {
        unreachable!("Use `std::any::Any`.")
    }
    fn get_gen_conf_filename(&self) -> Option<&PathBuf> {
        unreachable!("Use `std::any::Any`.")
    }
    fn get_preprocessor_config(&self) -> &Option<PathBuf> {
        unreachable!("Use `std::any::Any`.")
    }
    fn get_processor_config(&self) -> &Option<PathBuf> {
        unreachable!("Use `std::any::Any`.")
    }
    fn get_chat_template_explicit(&self) -> &Option<PathBuf> {
        unreachable!("Use `std::any::Any`.")
    }
    fn get_adapter_paths(&self) -> &AdapterPaths {
        unreachable!("Use `std::any::Any`.")
    }
}

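/// `Processor` implementation for speech requests. Speech generation consumes
/// a raw text prompt rather than chat messages, so `process` always errors.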
pub struct SpeechProcessor;

impl Processor for SpeechProcessor {
    fn process(
        &self,
        _pipeline: &dyn Pipeline,
        _messages: Vec<IndexMap<String, MessageContent>>,
        _add_generation_prompt: bool,
        _add_special_tokens: bool,
        _enable_thinking: Option<bool>,
        _tools: Vec<crate::Tool>,
    ) -> Result<(Vec<u32>, String)> {
        anyhow::bail!(
            "SpeechProcessor::process should not be used. It does not expect chat messages."
        )
    }
    fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
        Arc::new(SpeechInputsProcessor)
    }
    fn get_special_tokens(&self) -> &[&'static str] {
        &[]
    }
    fn template_action(&self) -> MessagesAction {
        // Just a default
        MessagesAction::FlattenOnlyText
    }
}

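/// Inputs processor that forwards each sequence's initial prompt verbatim; no
/// tokenization is performed.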
pub struct SpeechInputsProcessor;

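/// Raw prompts collected from the scheduled sequences, one per sequence.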
#[derive(Clone)]
pub struct ModelInputs {
    pub(crate) prompts: Vec<String>,
}

impl InputsProcessor for SpeechInputsProcessor {
    fn get_type(&self) -> InputsProcessorType {
        InputsProcessorType::Text
    }

    fn process_inputs(
        &self,
        _tokenizer: Option<Arc<Tokenizer>>,
        input_seqs: &mut [&mut Sequence],
        _is_prompt: bool,
        _is_xlora: bool,
        _device: &Device,
        _no_kv_cache: bool,
        _last_n_context_len: Option<(usize, usize)>,
        _return_raw_logits: bool,
        _other_config: Option<Arc<dyn Any>>,
        _paged_attn_metadata: Option<PagedAttentionMeta>,
        _mapper: Option<&dyn DeviceMapper>,
    ) -> Result<InputProcessorOutput> {
        let inputs = ModelInputs {
            prompts: input_seqs
                .iter()
                .map(|seq| seq.get_initial_prompt().to_string())
                .collect(),
        };
        Ok(InputProcessorOutput {
            inputs: Box::new(inputs),
            seq_indices: (0..input_seqs.len()).collect::<Vec<_>>(),
        })
    }
}

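/// Pipeline wrapping a loaded speech model (currently `DiaPipeline`) that
/// generates PCM audio from text prompts.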
pub struct SpeechPipeline {
    model_id: String,
    model: DiaPipeline,
    metadata: Arc<GeneralMetadata>,
    dummy_cache: EitherCache,
    cfg: SpeechGenerationConfig,
}

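/// Loader for speech models. Only the Dia architecture is supported for now,
/// and `dac_model_id` falls back to a per-architecture default when `None`.
///
/// A minimal construction sketch (the model id shown is illustrative):
///
/// ```ignore
/// let loader = SpeechLoader {
///     model_id: "nari-labs/Dia-1.6B".to_string(), // illustrative id
///     dac_model_id: None, // fall back to the default DAC weights
///     arch: SpeechLoaderType::Dia,
///     cfg: None, // use SpeechGenerationConfig::default(arch)
/// };
/// ```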
pub struct SpeechLoader {
    pub model_id: String,
    pub dac_model_id: Option<String>,
    pub arch: SpeechLoaderType,
    pub cfg: Option<SpeechGenerationConfig>,
}

impl Loader for SpeechLoader {
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let paths: anyhow::Result<Box<dyn ModelPaths>> = {
            // Main weights first, DAC is the final one.
            let mut weights = Vec::new();

            // Main model
            let config = {
                let api = ApiBuilder::new()
                    .with_progress(!silent)
                    .with_token(get_token(&token_source)?)
                    .build()?;
                let revision = revision.clone().unwrap_or("main".to_string());
                let api = api.repo(Repo::with_revision(
                    self.model_id.to_string(),
                    RepoType::Model,
                    revision.clone(),
                ));
                let model_id = std::path::Path::new(&self.model_id);

                let weight = api_get_file!(api, "model.safetensors", &model_id);
                let config = api_get_file!(api, "config.json", &model_id);
                weights.push(weight);
                config
            };

            // DAC model
            {
                let api = ApiBuilder::new()
                    .with_progress(!silent)
                    .with_token(get_token(&token_source)?)
                    .build()?;
                let revision = revision.unwrap_or("main".to_string());

                // Apply default here
                let dac_model = self
                    .dac_model_id
                    .clone()
                    .unwrap_or_else(|| match self.arch {
                        SpeechLoaderType::Dia => "EricB/dac_44khz".to_string(),
                    });

                let api = api.repo(Repo::with_revision(
                    dac_model.clone(),
                    RepoType::Model,
                    revision.clone(),
                ));
                let model_id = std::path::Path::new(&dac_model);

                let weight = api_get_file!(api, "model.safetensors", &model_id);
                weights.push(weight);
            }

            Ok(Box::new(SpeechModelPaths { weights, config }))
        };
        self.load_model_from_path(
            &paths?,
            dtype,
            device,
            silent,
            mapper,
            in_situ_quant,
            paged_attn_config,
        )
    }

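    /// Load from already-downloaded files: applies immediate ISQ if requested,
    /// memory-maps the main and DAC weights, and assembles the `SpeechPipeline`.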
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        _paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let paths = &paths
            .as_ref()
            .as_any()
            .downcast_ref::<SpeechModelPaths>()
            .expect("Path downcast failed.");

        if matches!(mapper, DeviceMapSetting::Map(_)) {
            anyhow::bail!("Device mapping is not supported for speech models.")
        }

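        // Request in-situ quantization for every weight (match-all regex).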
        mistralrs_quant::set_immediate_isq(in_situ_quant, vec![Regex::new(".*")?]);

        let cfg: DiaConfig = serde_json::from_str(&std::fs::read_to_string(&paths.config)?)?;

        #[cfg(feature = "cuda")]
        if let Device::Cuda(dev) = &device {
            unsafe { dev.disable_event_tracking() };
        }

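        // Device mapping was rejected above, so a dummy mapper suffices; it is
        // only used here to resolve the minimum supported dtype.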
        let mapper = DeviceMapSetting::dummy().into_mapper(usize::MAX, device, None)?;
        let dtype = mapper.get_min_dtype(dtype)?;

        // Last weight is the DAC.
        let model_weights = paths.weights[..paths.weights.len() - 1].to_vec();
        let vb = from_mmaped_safetensors(
            model_weights,
            Vec::new(),
            Some(dtype),
            device,
            vec![None],
            silent,
            None,
            |_| true,
            Arc::new(|_| DeviceForLoadTensor::Base),
        )?;

        let dac_vb = unsafe {
            VarBuilder::from_mmaped_safetensors(&[paths.weights.last().unwrap()], dtype, device)?
        };

        // Only Dia is supported for now.
        assert_eq!(self.arch, SpeechLoaderType::Dia);

        let model = DiaPipeline::new(&cfg, vb, dac_vb)?;

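        // Speech generation keeps no KV cache and produces audio rather than
        // tokens, so `eos_tok` is empty and the caching fields are stubbed out.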
        Ok(Arc::new(Mutex::new(SpeechPipeline {
            model_id: self.model_id.clone(),
            model,
            metadata: Arc::new(GeneralMetadata {
                max_seq_len: 1024,
                llg_factory: None,
                is_xlora: false,
                no_prefix_cache: false,
                num_hidden_layers: 1, // FIXME(EricLBuehler): we know this is only for caching, so it's OK.
                eos_tok: vec![],
                kind: ModelKind::Normal,
                no_kv_cache: true, // NOTE(EricLBuehler): no cache for these.
                activation_dtype: dtype,
                sliding_window: None,
                cache_config: None,
                cache_engine: None,
                model_metadata: None,
                modalities: Modalities {
                    input: vec![SupportedModality::Text],
                    output: vec![SupportedModality::Audio],
                },
            }),
            dummy_cache: EitherCache::Full(Cache::new(0, false)),
            cfg: self
                .cfg
                .unwrap_or_else(|| SpeechGenerationConfig::default(self.arch)),
        })))
    }

    fn get_id(&self) -> String {
        self.model_id.clone()
    }

    fn get_kind(&self) -> ModelKind {
        ModelKind::Normal
    }
}

impl PreProcessingMixin for SpeechPipeline {
    fn get_processor(&self) -> Arc<dyn Processor> {
        Arc::new(SpeechProcessor)
    }
    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
        None
    }
    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
        None
    }
}

impl IsqPipelineMixin for SpeechPipeline {
    fn re_isq_model(&mut self, _dtype: IsqType) -> Result<()> {
        anyhow::bail!("Speech models do not support ISQ for now.")
    }
}

impl CacheManagerMixin for SpeechPipeline {
    fn clone_in_cache(&self, _seqs: &mut [&mut Sequence]) {}
    fn clone_out_cache(&self, _seqs: &mut [&mut Sequence]) {}
    fn set_none_cache(
        &self,
        _seqs: &mut [&mut Sequence],
        _reset_non_granular: bool,
        _modify_draft_cache: bool,
        _load_preallocated_cache: bool,
    ) {
    }
    fn cache(&self) -> &EitherCache {
        &self.dummy_cache
    }
}

impl MetadataMixin for SpeechPipeline {
    fn device(&self) -> Device {
        self.model.device().clone()
    }
    fn get_metadata(&self) -> Arc<GeneralMetadata> {
        self.metadata.clone()
    }
    fn name(&self) -> String {
        self.model_id.clone()
    }
    fn reset_non_granular_state(&self) {}
    fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
        None
    }
    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
        None
    }
}

#[async_trait::async_trait]
impl Pipeline for SpeechPipeline {
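    /// Run Dia generation for each prompt, collecting one PCM buffer (with
    /// its sample rate and channel count) per sequence.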
    fn forward_inputs(
        &mut self,
        inputs: Box<dyn Any>,
        return_raw_logits: bool,
    ) -> candle_core::Result<ForwardInputsResult> {
        assert!(!return_raw_logits);

        let ModelInputs { prompts } = *inputs.downcast().expect("Downcast failed.");
        let mut pcms = Vec::new();
        let mut rates = Vec::new();
        let mut channels_all = Vec::new();
        for prompt in prompts {
            let SpeechGenerationOutput {
                pcm,
                rate,
                channels,
            } = self.model.generate(&prompt, &self.cfg)?;
            pcms.push(pcm);
            rates.push(rate);
            channels_all.push(channels);
        }

        Ok(ForwardInputsResult::Speech {
            pcms,
            rates,
            channels: channels_all,
        })
    }

    async fn sample_causal_gen(
        &self,
        _seqs: &mut [&mut Sequence],
        _logits: Vec<Tensor>,
        _prefix_cacher: &mut PrefixCacheManagerV2,
        _disable_eos_stop: bool,
        _srng: Arc<std::sync::Mutex<Isaac64Rng>>,
    ) -> Result<(), candle_core::Error> {
        candle_core::bail!("`sample_causal_gen` is incompatible with `SpeechPipeline`");
    }

    fn category(&self) -> ModelCategory {
        ModelCategory::Speech
    }
}

impl AnyMoePipelineMixin for SpeechPipeline {}