use super::text_models_inputs_processor::PagedAttentionMeta;
use super::{
    AdapterPaths, AnyMoePipelineMixin, Cache, CacheManagerMixin, EitherCache, ForwardInputsResult,
    GeneralMetadata, InputProcessorOutput, InputsProcessor, InputsProcessorType, IsqPipelineMixin,
    Loader, MessagesAction, MetadataMixin, ModelCategory, ModelKind, ModelPaths,
    PreProcessingMixin, Processor, TokenSource,
};
use crate::device_map::DeviceMapper;
use crate::pipeline::{ChatTemplate, EmbeddingModulePaths, Modalities, SupportedModality};
use crate::prefix_cacher::PrefixCacheManagerV2;
use crate::sequence::Sequence;
use crate::speech_models::{DiaConfig, DiaPipeline, SpeechGenerationOutput, SpeechLoaderType};
use crate::utils::varbuilder_utils::DeviceForLoadTensor;
use crate::utils::{tokens::get_token, varbuilder_utils::from_mmaped_safetensors};
use crate::{
    api_get_file, DeviceMapSetting, MessageContent, PagedAttentionConfig, Pipeline,
    SpeechGenerationConfig, TryIntoDType,
};
use anyhow::Result;
use candle_core::{Device, Tensor};
use candle_nn::VarBuilder;
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
use indexmap::IndexMap;
use mistralrs_quant::IsqType;
use rand_isaac::Isaac64Rng;
use regex::Regex;
use std::any::Any;
use std::path::PathBuf;
use std::sync::Arc;
use tokenizers::Tokenizer;
use tokio::sync::Mutex;
use tracing::info;

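/// Paths for a speech generation model.
///
/// `weights` holds every downloaded safetensors file: the speech model's own
/// weights first, with the DAC weights as the final entry. `config` is the
/// speech model's `config.json`. All other [`ModelPaths`] accessors are
/// unreachable; downcast via `std::any::Any` instead.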
#[derive(Clone, Debug)]
pub struct SpeechModelPaths {
    weights: Vec<PathBuf>,
    config: PathBuf,
}

impl ModelPaths for SpeechModelPaths {
    fn get_config_filename(&self) -> &PathBuf {
        &self.config
    }
    fn get_tokenizer_filename(&self) -> &PathBuf {
        unreachable!("Use `std::any::Any`.")
    }
    fn get_weight_filenames(&self) -> &[PathBuf] {
        &self.weights
    }
    fn get_template_filename(&self) -> &Option<PathBuf> {
        unreachable!("Use `std::any::Any`.")
    }
    fn get_gen_conf_filename(&self) -> Option<&PathBuf> {
        unreachable!("Use `std::any::Any`.")
    }
    fn get_preprocessor_config(&self) -> &Option<PathBuf> {
        unreachable!("Use `std::any::Any`.")
    }
    fn get_processor_config(&self) -> &Option<PathBuf> {
        unreachable!("Use `std::any::Any`.")
    }
    fn get_chat_template_explicit(&self) -> &Option<PathBuf> {
        unreachable!("Use `std::any::Any`.")
    }
    fn get_adapter_paths(&self) -> &AdapterPaths {
        unreachable!("Use `std::any::Any`.")
    }
    fn get_modules(&self) -> Option<&[EmbeddingModulePaths]> {
        unreachable!("Use `std::any::Any`.")
    }
}

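/// [`Processor`] for speech generation.
///
/// Speech models consume a raw text prompt rather than chat messages, so
/// `process` always bails; only `inputs_processor` is meaningful here.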
pub struct SpeechProcessor;

impl Processor for SpeechProcessor {
    fn process(
        &self,
        _pipeline: &dyn Pipeline,
        _messages: Vec<IndexMap<String, MessageContent>>,
        _add_generation_prompt: bool,
        _add_special_tokens: bool,
        _enable_thinking: Option<bool>,
        _tools: Vec<crate::Tool>,
    ) -> Result<(Vec<u32>, String)> {
        anyhow::bail!(
            "SpeechProcessor::process should not be used. It does not expect chat messages."
        )
    }
    fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
        Arc::new(SpeechInputsProcessor)
    }
    fn get_special_tokens(&self) -> &[&'static str] {
        &[]
    }
    fn template_action(&self) -> MessagesAction {
        MessagesAction::FlattenOnlyText
    }
}

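/// [`InputsProcessor`] for speech generation.
///
/// No tokenizer is involved: each sequence's initial text prompt is forwarded
/// verbatim as a [`ModelInputs`] batch.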
pub struct SpeechInputsProcessor;

#[derive(Clone)]
pub struct ModelInputs {
    pub(crate) prompts: Vec<String>,
}

impl InputsProcessor for SpeechInputsProcessor {
    fn get_type(&self) -> InputsProcessorType {
        InputsProcessorType::Text
    }

    fn process_inputs(
        &self,
        _tokenizer: Option<Arc<Tokenizer>>,
        input_seqs: &mut [&mut Sequence],
        _is_prompt: bool,
        _is_xlora: bool,
        _device: &Device,
        _no_kv_cache: bool,
        _last_n_context_len: Option<(usize, usize)>,
        _return_raw_logits: bool,
        _other_config: Option<Arc<dyn Any>>,
        _paged_attn_metadata: Option<PagedAttentionMeta>,
        _mapper: Option<&dyn DeviceMapper>,
    ) -> Result<InputProcessorOutput> {
        let inputs = ModelInputs {
            prompts: input_seqs
                .iter()
                .map(|seq| seq.get_initial_prompt().to_string())
                .collect(),
        };
        Ok(InputProcessorOutput {
            inputs: Box::new(inputs),
            seq_indices: (0..input_seqs.len()).collect::<Vec<_>>(),
        })
    }
}

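/// Pipeline for speech generation models (currently the Dia architecture).
///
/// There is no KV cache and no sampling step: `forward_inputs` runs the whole
/// text-to-speech generation and returns PCM audio directly.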
pub struct SpeechPipeline {
    model_id: String,
    model: DiaPipeline,
    metadata: Arc<GeneralMetadata>,
    dummy_cache: EitherCache,
    cfg: SpeechGenerationConfig,
}

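/// [`Loader`] for speech generation models.
///
/// A minimal construction sketch (the model ID below is illustrative, not
/// taken from this crate):
///
/// ```ignore
/// let loader = SpeechLoader {
///     // Hypothetical Hugging Face repo for a Dia checkpoint.
///     model_id: "nari-labs/Dia-1.6B".to_string(),
///     // `None` falls back to the default DAC repo for the architecture.
///     dac_model_id: None,
///     arch: SpeechLoaderType::Dia,
///     // `None` uses `SpeechGenerationConfig::default(arch)`.
///     cfg: None,
/// };
/// ```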
pub struct SpeechLoader {
    pub model_id: String,
    pub dac_model_id: Option<String>,
    pub arch: SpeechLoaderType,
    pub cfg: Option<SpeechGenerationConfig>,
}

impl Loader for SpeechLoader {
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let paths: anyhow::Result<Box<dyn ModelPaths>> = {
            let mut weights = Vec::new();

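            // First fetch the speech model's own weights and config from its repo.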
            let config = {
                let api = ApiBuilder::new()
                    .with_progress(!silent)
                    .with_token(get_token(&token_source)?)
                    .build()?;
                let revision = revision.clone().unwrap_or("main".to_string());
                let api = api.repo(Repo::with_revision(
                    self.model_id.to_string(),
                    RepoType::Model,
                    revision.clone(),
                ));
                let model_id = std::path::Path::new(&self.model_id);

                let weight = api_get_file!(api, "model.safetensors", &model_id);
                let config = api_get_file!(api, "config.json", &model_id);
                weights.push(weight);
                config
            };

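            // Then fetch the DAC audio codec weights, falling back to a default
            // repo for the architecture when `dac_model_id` is unset.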
            {
                let api = ApiBuilder::new()
                    .with_progress(!silent)
                    .with_token(get_token(&token_source)?)
                    .build()?;
                let revision = revision.unwrap_or("main".to_string());

                let dac_model = self
                    .dac_model_id
                    .clone()
                    .unwrap_or_else(|| match self.arch {
                        SpeechLoaderType::Dia => "EricB/dac_44khz".to_string(),
                    });

                let api = api.repo(Repo::with_revision(
                    dac_model.clone(),
                    RepoType::Model,
                    revision.clone(),
                ));
                let model_id = std::path::Path::new(&dac_model);

                let weight = api_get_file!(api, "model.safetensors", &model_id);
                weights.push(weight);
            }

            Ok(Box::new(SpeechModelPaths { weights, config }))
        };
        self.load_model_from_path(
            &paths?,
            dtype,
            device,
            silent,
            mapper,
            in_situ_quant,
            paged_attn_config,
        )
    }

    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        _paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let paths = &paths
            .as_ref()
            .as_any()
            .downcast_ref::<SpeechModelPaths>()
            .expect("Path downcast failed.");

        if matches!(mapper, DeviceMapSetting::Map(_)) {
            anyhow::bail!("Device mapping is not supported for speech models.")
        }

        mistralrs_quant::set_immediate_isq(in_situ_quant, vec![Regex::new(".*")?]);

        let cfg: DiaConfig = serde_json::from_str(&std::fs::read_to_string(&paths.config)?)?;

        #[cfg(feature = "cuda")]
        if let Device::Cuda(dev) = &device {
            unsafe { dev.disable_event_tracking() };
        }

        let mapper = DeviceMapSetting::dummy().into_mapper(usize::MAX, device, None)?;
        let dtype = mapper.get_min_dtype(dtype)?;

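        // Every weight file except the last belongs to the speech model; the
        // final entry is the DAC weights pushed during download.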
        let model_weights = paths.weights[..paths.weights.len() - 1].to_vec();
        let vb = from_mmaped_safetensors(
            model_weights,
            Vec::new(),
            Some(dtype),
            device,
            vec![None],
            silent,
            None,
            |_| true,
            Arc::new(|_| DeviceForLoadTensor::Base),
        )?;

        let dac_vb = unsafe {
            VarBuilder::from_mmaped_safetensors(&[paths.weights.last().unwrap()], dtype, device)?
        };

        assert_eq!(self.arch, SpeechLoaderType::Dia);

        let model = DiaPipeline::new(&cfg, vb, dac_vb)?;

        Ok(Arc::new(Mutex::new(SpeechPipeline {
            model_id: self.model_id.clone(),
            model,
            metadata: Arc::new(GeneralMetadata {
                max_seq_len: 1024,
                llg_factory: None,
                is_xlora: false,
                no_prefix_cache: false,
                num_hidden_layers: 1,
                eos_tok: vec![],
                kind: ModelKind::Normal,
                no_kv_cache: true,
                activation_dtype: dtype,
                sliding_window: None,
                cache_config: None,
                cache_engine: None,
                model_metadata: None,
                modalities: Modalities {
                    input: vec![SupportedModality::Text],
                    output: vec![SupportedModality::Audio],
                },
            }),
            dummy_cache: EitherCache::Full(Cache::new(0, false)),
            cfg: self
                .cfg
                .unwrap_or_else(|| SpeechGenerationConfig::default(self.arch)),
        })))
    }

    fn get_id(&self) -> String {
        self.model_id.clone()
    }

    fn get_kind(&self) -> ModelKind {
        ModelKind::Normal
    }
}

impl PreProcessingMixin for SpeechPipeline {
    fn get_processor(&self) -> Arc<dyn Processor> {
        Arc::new(SpeechProcessor)
    }
    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
        None
    }
    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
        None
    }
}

impl IsqPipelineMixin for SpeechPipeline {
    fn re_isq_model(&mut self, _dtype: IsqType) -> Result<()> {
        anyhow::bail!("Speech models do not support ISQ for now.")
    }
}

impl CacheManagerMixin for SpeechPipeline {
    fn clone_in_cache(&self, _seqs: &mut [&mut Sequence]) {}
    fn clone_out_cache(&self, _seqs: &mut [&mut Sequence]) {}
    fn set_none_cache(
        &self,
        _seqs: &mut [&mut Sequence],
        _reset_non_granular: bool,
        _modify_draft_cache: bool,
        _load_preallocated_cache: bool,
    ) {
    }
    fn cache(&self) -> &EitherCache {
        &self.dummy_cache
    }
}

impl MetadataMixin for SpeechPipeline {
    fn device(&self) -> Device {
        self.model.device().clone()
    }
    fn get_metadata(&self) -> Arc<GeneralMetadata> {
        self.metadata.clone()
    }
    fn name(&self) -> String {
        self.model_id.clone()
    }
    fn reset_non_granular_state(&self) {}
    fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
        None
    }
    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
        None
    }
}

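// Generation happens entirely inside `forward_inputs`: each prompt is synthesized
// in turn and the PCM samples, sample rates, and channel counts are returned per
// sequence. `sample_causal_gen` is unreachable since no logits are produced.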
#[async_trait::async_trait]
impl Pipeline for SpeechPipeline {
    fn forward_inputs(
        &mut self,
        inputs: Box<dyn Any>,
        return_raw_logits: bool,
    ) -> candle_core::Result<ForwardInputsResult> {
        assert!(!return_raw_logits);

        let ModelInputs { prompts } = *inputs.downcast().expect("Downcast failed.");
        let mut pcms = Vec::new();
        let mut rates = Vec::new();
        let mut channels_all = Vec::new();
        for prompt in prompts {
            let SpeechGenerationOutput {
                pcm,
                rate,
                channels,
            } = self.model.generate(&prompt, &self.cfg)?;
            pcms.push(pcm);
            rates.push(rate);
            channels_all.push(channels);
        }

        Ok(ForwardInputsResult::Speech {
            pcms,
            rates,
            channels: channels_all,
        })
    }

    async fn sample_causal_gen(
        &self,
        _seqs: &mut [&mut Sequence],
        _logits: Vec<Tensor>,
        _prefix_cacher: &mut PrefixCacheManagerV2,
        _disable_eos_stop: bool,
        _srng: Arc<std::sync::Mutex<Isaac64Rng>>,
    ) -> Result<(), candle_core::Error> {
        candle_core::bail!("`sample_causal_gen` is incompatible with `SpeechPipeline`");
    }

    fn category(&self) -> ModelCategory {
        ModelCategory::Speech
    }
}

impl AnyMoePipelineMixin for SpeechPipeline {}