1use super::text_models_inputs_processor::PagedAttentionMeta;
2use super::{
3 AdapterPaths, AnyMoePipelineMixin, Cache, CacheManagerMixin, EitherCache, ForwardInputsResult,
4 GeneralMetadata, InputProcessorOutput, InputsProcessor, InputsProcessorType, IsqPipelineMixin,
5 Loader, MessagesAction, MetadataMixin, ModelCategory, ModelKind, ModelPaths,
6 PreProcessingMixin, Processor, TokenSource,
7};
8use crate::device_map::DeviceMapper;
9use crate::pipeline::{ChatTemplate, EmbeddingModulePaths, Modalities, SupportedModality};
10use crate::prefix_cacher::PrefixCacheManagerV2;
11use crate::sequence::Sequence;
12use crate::speech_models::{DiaConfig, DiaPipeline, SpeechGenerationOutput, SpeechLoaderType};
13use crate::utils::progress::ProgressScopeGuard;
14use crate::utils::varbuilder_utils::DeviceForLoadTensor;
15use crate::utils::{tokens::get_token, varbuilder_utils::from_mmaped_safetensors};
16use crate::{
17 api_get_file, DeviceMapSetting, MessageContent, PagedAttentionConfig, Pipeline,
18 SpeechGenerationConfig, TryIntoDType,
19};
20use anyhow::Result;
21use candle_core::{Device, Tensor};
22use candle_nn::VarBuilder;
23use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
24use indexmap::IndexMap;
25use mistralrs_quant::IsqType;
26use rand_isaac::Isaac64Rng;
27use regex::Regex;
28use std::any::Any;
29use std::path::PathBuf;
30use std::sync::Arc;
31use tokenizers::Tokenizer;
32use tokio::sync::Mutex;
33use tracing::info;
34
/// Filesystem locations of a speech model's artifacts.
///
/// `weights` holds safetensors files in load order: the main model's weights
/// first, with the DAC (audio codec) weights as the final entry — see
/// `SpeechLoader::load_model_from_path`, which splits on that ordering.
#[derive(Clone, Debug)]
pub struct SpeechModelPaths {
    // Safetensors weight files; main model first, DAC codec last.
    weights: Vec<PathBuf>,
    // Path to the model's `config.json`.
    config: PathBuf,
}
40
// Only the config and weight accessors are meaningful for speech models.
// Every other accessor is intentionally `unreachable!`: consumers are expected
// to downcast to `SpeechModelPaths` via `std::any::Any` instead of going
// through the generic `ModelPaths` surface.
impl ModelPaths for SpeechModelPaths {
    fn get_config_filename(&self) -> &PathBuf {
        &self.config
    }
    fn get_tokenizer_filename(&self) -> &PathBuf {
        unreachable!("Use `std::any::Any`.")
    }
    fn get_weight_filenames(&self) -> &[PathBuf] {
        &self.weights
    }
    fn get_template_filename(&self) -> &Option<PathBuf> {
        unreachable!("Use `std::any::Any`.")
    }
    fn get_gen_conf_filename(&self) -> Option<&PathBuf> {
        unreachable!("Use `std::any::Any`.")
    }
    fn get_preprocessor_config(&self) -> &Option<PathBuf> {
        unreachable!("Use `std::any::Any`.")
    }
    fn get_processor_config(&self) -> &Option<PathBuf> {
        unreachable!("Use `std::any::Any`.")
    }
    fn get_chat_template_explicit(&self) -> &Option<PathBuf> {
        unreachable!("Use `std::any::Any`.")
    }
    fn get_adapter_paths(&self) -> &AdapterPaths {
        unreachable!("Use `std::any::Any`.")
    }
    fn get_modules(&self) -> Option<&[EmbeddingModulePaths]> {
        unreachable!("Use `std::any::Any`.")
    }
}
73
/// `Processor` implementation for speech pipelines. Speech models consume raw
/// text prompts rather than chat transcripts, so chat-message processing is
/// rejected (see `SpeechProcessor::process`).
pub struct SpeechProcessor;
75
impl Processor for SpeechProcessor {
    /// Always errors: speech generation takes raw prompts, not chat messages,
    /// so there is no tokenization step to perform here.
    fn process(
        &self,
        _pipeline: &dyn Pipeline,
        _messages: Vec<IndexMap<String, MessageContent>>,
        _add_generation_prompt: bool,
        _add_special_tokens: bool,
        _enable_thinking: Option<bool>,
        _tools: Vec<crate::Tool>,
    ) -> Result<(Vec<u32>, String)> {
        anyhow::bail!(
            "SpeechProcessor::process should not be used. It does not expect chat messages."
        )
    }
    fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
        Arc::new(SpeechInputsProcessor)
    }
    // No special tokens are required for speech generation.
    fn get_special_tokens(&self) -> &[&'static str] {
        &[]
    }
    fn template_action(&self) -> MessagesAction {
        MessagesAction::FlattenOnlyText
    }
}
101
/// `InputsProcessor` for speech pipelines: gathers each sequence's initial
/// prompt text into `ModelInputs` without any tokenization.
pub struct SpeechInputsProcessor;
/// Inputs handed to `SpeechPipeline::forward_inputs`: one raw text prompt per
/// input sequence, in sequence order.
#[derive(Clone)]
pub struct ModelInputs {
    pub(crate) prompts: Vec<String>,
}
108
109impl InputsProcessor for SpeechInputsProcessor {
110 fn get_type(&self) -> InputsProcessorType {
111 InputsProcessorType::Text
112 }
113
114 fn process_inputs(
115 &self,
116 _tokenizer: Option<Arc<Tokenizer>>,
117 input_seqs: &mut [&mut Sequence],
118 _is_prompt: bool,
119 _is_xlora: bool,
120 _device: &Device,
121 _no_kv_cache: bool,
122 _last_n_context_len: Option<(usize, usize)>,
123 _return_raw_logits: bool,
124 _other_config: Option<Arc<dyn Any>>,
125 _paged_attn_metadata: Option<PagedAttentionMeta>,
126 _mapper: Option<&dyn DeviceMapper>,
127 ) -> Result<InputProcessorOutput> {
128 let inputs = ModelInputs {
129 prompts: input_seqs
130 .iter()
131 .map(|seq| seq.get_initial_prompt().to_string())
132 .collect(),
133 };
134 Ok(InputProcessorOutput {
135 inputs: Box::new(inputs),
136 seq_indices: (0..input_seqs.len()).collect::<Vec<_>>(),
137 })
138 }
139}
140
/// Pipeline wrapping a Dia speech-generation model plus its generation config.
pub struct SpeechPipeline {
    // Model id used for `name()` / `get_id()` reporting.
    model_id: String,
    model: DiaPipeline,
    metadata: Arc<GeneralMetadata>,
    // Speech generation keeps no KV cache; this empty cache only exists to
    // satisfy `CacheManagerMixin::cache`.
    dummy_cache: EitherCache,
    cfg: SpeechGenerationConfig,
}
148
/// Loader for speech models (currently only the Dia architecture).
pub struct SpeechLoader {
    /// Hub model id (or local path) of the main speech model.
    pub model_id: String,
    /// Optional override for the DAC (audio codec) model repo; falls back to a
    /// per-architecture default when `None`.
    pub dac_model_id: Option<String>,
    /// Which speech architecture to load.
    pub arch: SpeechLoaderType,
    /// Optional generation config; `SpeechGenerationConfig::default(arch)` is
    /// used when `None`.
    pub cfg: Option<SpeechGenerationConfig>,
}
155
impl Loader for SpeechLoader {
    /// Downloads the speech model's weights and config — plus the DAC codec
    /// weights — from the Hugging Face hub, then defers to
    /// [`Self::load_model_from_path`].
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        // Scoped guard controlling progress reporting for the duration of the load.
        let _progress_guard = ProgressScopeGuard::new(silent);
        let paths: anyhow::Result<Box<dyn ModelPaths>> = {
            // Weight files accumulate in a fixed order relied upon by
            // `load_model_from_path`: main model first, DAC codec last.
            let mut weights = Vec::new();

            // Fetch the main model's single safetensors file and its config.json.
            let config = {
                let api = ApiBuilder::new()
                    .with_progress(!silent)
                    .with_token(get_token(&token_source)?)
                    .build()?;
                let revision = revision.clone().unwrap_or("main".to_string());
                let api = api.repo(Repo::with_revision(
                    self.model_id.to_string(),
                    RepoType::Model,
                    revision.clone(),
                ));
                let model_id = std::path::Path::new(&self.model_id);

                let weight = api_get_file!(api, "model.safetensors", &model_id);
                let config = api_get_file!(api, "config.json", &model_id);
                weights.push(weight);
                config
            };

            // Fetch the DAC (audio codec) weights, using the per-architecture
            // default repo unless an explicit `dac_model_id` was provided.
            {
                let api = ApiBuilder::new()
                    .with_progress(!silent)
                    .with_token(get_token(&token_source)?)
                    .build()?;
                let revision = revision.unwrap_or("main".to_string());

                let dac_model = self
                    .dac_model_id
                    .clone()
                    .unwrap_or_else(|| match self.arch {
                        SpeechLoaderType::Dia => "EricB/dac_44khz".to_string(),
                    });

                let api = api.repo(Repo::with_revision(
                    dac_model.clone(),
                    RepoType::Model,
                    revision.clone(),
                ));
                let model_id = std::path::Path::new(&dac_model);

                let weight = api_get_file!(api, "model.safetensors", &model_id);
                weights.push(weight);
            }

            Ok(Box::new(SpeechModelPaths { weights, config }))
        };
        self.load_model_from_path(
            &paths?,
            dtype,
            device,
            silent,
            mapper,
            in_situ_quant,
            paged_attn_config,
        )
    }

    /// Builds a [`SpeechPipeline`] from already-resolved [`SpeechModelPaths`].
    ///
    /// Device mapping is unsupported (a `Map` setting is rejected) and paged
    /// attention is ignored: speech generation keeps no KV cache.
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        _paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let _progress_guard = ProgressScopeGuard::new(silent);
        // Recover the concrete path type; the generic `ModelPaths` accessors
        // (other than config/weights) are unreachable for speech models.
        let paths = &paths
            .as_ref()
            .as_any()
            .downcast_ref::<SpeechModelPaths>()
            .expect("Path downcast failed.");

        if matches!(mapper, DeviceMapSetting::Map(_)) {
            anyhow::bail!("Device mapping is not supported for speech models.")
        }

        // Apply in-situ quantization (when requested) to all tensor names —
        // the `.*` regex matches everything.
        mistralrs_quant::set_immediate_isq(in_situ_quant, vec![Regex::new(".*")?]);

        let cfg: DiaConfig = serde_json::from_str(&std::fs::read_to_string(&paths.config)?)?;

        // NOTE(review): presumably disables CUDA event tracking as a perf
        // workaround for this pipeline — confirm against candle's CUDA device API.
        #[cfg(feature = "cuda")]
        if let Device::Cuda(dev) = &device {
            unsafe { dev.disable_event_tracking() };
        }

        // Single-device ("dummy") mapper, used here only to resolve the
        // minimum activation dtype for the target device.
        let mapper = DeviceMapSetting::dummy().into_mapper(usize::MAX, device, None)?;
        let dtype = mapper.get_min_dtype(dtype)?;

        // All but the final weight file belong to the main model; the last
        // entry is the DAC codec (ordering established in `load_model_from_hf`).
        let model_weights = paths.weights[..paths.weights.len() - 1].to_vec();
        let vb = from_mmaped_safetensors(
            model_weights,
            Vec::new(),
            Some(dtype),
            device,
            vec![None],
            silent,
            None,
            |_| true,
            Arc::new(|_| DeviceForLoadTensor::Base),
        )?;

        // SAFETY-NOTE(review): mmapped load assumes the safetensors file is not
        // mutated while mapped, per candle's `from_mmaped_safetensors` contract.
        let dac_vb = unsafe {
            VarBuilder::from_mmaped_safetensors(&[paths.weights.last().unwrap()], dtype, device)?
        };

        // Dia is currently the only supported speech architecture.
        assert_eq!(self.arch, SpeechLoaderType::Dia);

        let model = DiaPipeline::new(&cfg, vb, dac_vb)?;

        Ok(Arc::new(Mutex::new(SpeechPipeline {
            model_id: self.model_id.clone(),
            model,
            metadata: Arc::new(GeneralMetadata {
                max_seq_len: 1024,
                llg_factory: None,
                is_xlora: false,
                no_prefix_cache: false,
                num_hidden_layers: 1,
                eos_tok: vec![],
                kind: ModelKind::Normal,
                // Speech generation does not maintain a KV cache.
                no_kv_cache: true,
                activation_dtype: dtype,
                sliding_window: None,
                cache_config: None,
                cache_engine: None,
                model_metadata: None,
                modalities: Modalities {
                    input: vec![SupportedModality::Text],
                    output: vec![SupportedModality::Audio],
                },
            }),
            // Empty placeholder cache to satisfy `CacheManagerMixin`.
            dummy_cache: EitherCache::Full(Cache::new(0, false)),
            cfg: self
                .cfg
                .unwrap_or_else(|| SpeechGenerationConfig::default(self.arch)),
        })))
    }

    fn get_id(&self) -> String {
        self.model_id.clone()
    }

    fn get_kind(&self) -> ModelKind {
        ModelKind::Normal
    }
}
328
impl PreProcessingMixin for SpeechPipeline {
    fn get_processor(&self) -> Arc<dyn Processor> {
        Arc::new(SpeechProcessor)
    }
    // Speech models have no chat template or extra input-processor config.
    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
        None
    }
    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
        None
    }
}
340
impl IsqPipelineMixin for SpeechPipeline {
    // Re-quantizing an already-loaded speech model is not supported; ISQ can
    // only be applied at load time (see `SpeechLoader::load_model_from_path`).
    fn re_isq_model(&mut self, _dtype: IsqType) -> Result<()> {
        anyhow::bail!("Speech models do not support ISQ for now.")
    }
}
346
impl CacheManagerMixin for SpeechPipeline {
    // Speech generation keeps no KV cache, so every cache-management hook is a
    // deliberate no-op and `cache()` returns an empty placeholder.
    fn clone_in_cache(&self, _seqs: &mut [&mut Sequence]) {}
    fn clone_out_cache(&self, _seqs: &mut [&mut Sequence]) {}
    fn set_none_cache(
        &self,
        _seqs: &mut [&mut Sequence],
        _reset_non_granular: bool,
        _modify_draft_cache: bool,
        _load_preallocated_cache: bool,
    ) {
    }
    fn cache(&self) -> &EitherCache {
        &self.dummy_cache
    }
}
362
impl MetadataMixin for SpeechPipeline {
    fn device(&self) -> Device {
        self.model.device().clone()
    }
    fn get_metadata(&self) -> Arc<GeneralMetadata> {
        self.metadata.clone()
    }
    fn name(&self) -> String {
        self.model_id.clone()
    }
    fn reset_non_granular_state(&self) {}
    // Speech models operate on raw prompt text and have no tokenizer.
    fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
        None
    }
    // Device mapping is not supported for speech models.
    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
        None
    }
}
381
382#[async_trait::async_trait]
383impl Pipeline for SpeechPipeline {
384 fn forward_inputs(
385 &mut self,
386 inputs: Box<dyn Any>,
387 return_raw_logits: bool,
388 ) -> candle_core::Result<ForwardInputsResult> {
389 assert!(!return_raw_logits);
390
391 let ModelInputs { prompts } = *inputs.downcast().expect("Downcast failed.");
392 let mut pcms = Vec::new();
393 let mut rates = Vec::new();
394 let mut channels_all = Vec::new();
395 for prompt in prompts {
396 let SpeechGenerationOutput {
397 pcm,
398 rate,
399 channels,
400 } = self.model.generate(&prompt, &self.cfg)?;
401 pcms.push(pcm);
402 rates.push(rate);
403 channels_all.push(channels);
404 }
405
406 Ok(ForwardInputsResult::Speech {
407 pcms,
408 rates,
409 channels: channels_all,
410 })
411 }
412
413 async fn sample_causal_gen(
414 &self,
415 _seqs: &mut [&mut Sequence],
416 _logits: Vec<Tensor>,
417 _prefix_cacher: &mut PrefixCacheManagerV2,
418 _disable_eos_stop: bool,
419 _srng: Arc<std::sync::Mutex<Isaac64Rng>>,
420 ) -> Result<(), candle_core::Error> {
421 candle_core::bail!("`sample_causal_gen` is incompatible with `SpeechPipeline`");
422 }
423
424 fn category(&self) -> ModelCategory {
425 ModelCategory::Speech
426 }
427}
428
// Speech pipelines do not participate in AnyMoE; rely on the trait's defaults.
impl AnyMoePipelineMixin for SpeechPipeline {}