use super::text_models_inputs_processor::PagedAttentionMeta;
use super::{
    AdapterPaths, AnyMoePipelineMixin, Cache, CacheManagerMixin, EitherCache, ForwardInputsResult,
    GeneralMetadata, InputProcessorOutput, InputsProcessor, InputsProcessorType, IsqPipelineMixin,
    Loader, MessagesAction, MetadataMixin, ModelCategory, ModelKind, ModelPaths,
    PreProcessingMixin, Processor, TokenSource,
};
use crate::device_map::DeviceMapper;
use crate::pipeline::{ChatTemplate, Modalities, SupportedModality};
use crate::prefix_cacher::PrefixCacheManagerV2;
use crate::sequence::Sequence;
use crate::speech_models::{DiaConfig, DiaPipeline, SpeechGenerationOutput, SpeechLoaderType};
use crate::utils::varbuilder_utils::DeviceForLoadTensor;
use crate::utils::{tokens::get_token, varbuilder_utils::from_mmaped_safetensors};
use crate::{
    api_get_file, DeviceMapSetting, MessageContent, PagedAttentionConfig, Pipeline,
    SpeechGenerationConfig, TryIntoDType,
};
use anyhow::Result;
use candle_core::{Device, Tensor};
use candle_nn::VarBuilder;
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
use indexmap::IndexMap;
use mistralrs_quant::IsqType;
use rand_isaac::Isaac64Rng;
use regex::Regex;
use std::any::Any;
use std::path::PathBuf;
use std::sync::Arc;
use tokenizers::Tokenizer;
use tokio::sync::Mutex;
use tracing::info;

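/// Filesystem paths for a speech model: the safetensors weight files and the
/// model's `config.json`. The remaining `ModelPaths` accessors are unused for
/// speech models and are deliberately unreachable.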
#[derive(Clone, Debug)]
pub struct SpeechModelPaths {
    weights: Vec<PathBuf>,
    config: PathBuf,
}

impl ModelPaths for SpeechModelPaths {
    fn get_config_filename(&self) -> &PathBuf {
        &self.config
    }
    fn get_tokenizer_filename(&self) -> &PathBuf {
        unreachable!("Use `std::any::Any`.")
    }
    fn get_weight_filenames(&self) -> &[PathBuf] {
        &self.weights
    }
    fn get_template_filename(&self) -> &Option<PathBuf> {
        unreachable!("Use `std::any::Any`.")
    }
    fn get_gen_conf_filename(&self) -> Option<&PathBuf> {
        unreachable!("Use `std::any::Any`.")
    }
    fn get_preprocessor_config(&self) -> &Option<PathBuf> {
        unreachable!("Use `std::any::Any`.")
    }
    fn get_processor_config(&self) -> &Option<PathBuf> {
        unreachable!("Use `std::any::Any`.")
    }
    fn get_chat_template_explicit(&self) -> &Option<PathBuf> {
        unreachable!("Use `std::any::Any`.")
    }
    fn get_adapter_paths(&self) -> &AdapterPaths {
        unreachable!("Use `std::any::Any`.")
    }
}

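/// Message processor for speech models. Speech generation consumes a raw text
/// prompt rather than a chat transcript, so `process` always errors.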
pub struct SpeechProcessor;

impl Processor for SpeechProcessor {
    fn process(
        &self,
        _pipeline: &dyn Pipeline,
        _messages: Vec<IndexMap<String, MessageContent>>,
        _add_generation_prompt: bool,
        _add_special_tokens: bool,
        _enable_thinking: Option<bool>,
        _tools: Vec<crate::Tool>,
    ) -> Result<(Vec<u32>, String)> {
        anyhow::bail!(
            "SpeechProcessor::process should not be used. It does not expect chat messages."
        )
    }
    fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
        Arc::new(SpeechInputsProcessor)
    }
    fn get_special_tokens(&self) -> &[&'static str] {
        &[]
    }
    fn template_action(&self) -> MessagesAction {
        MessagesAction::FlattenOnlyText
    }
}

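/// Inputs processor that forwards each sequence's initial prompt verbatim;
/// no tokenization is performed for speech models.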
pub struct SpeechInputsProcessor;

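/// The only model input is the raw text prompt for each sequence.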
#[derive(Clone)]
pub struct ModelInputs {
    pub(crate) prompts: Vec<String>,
}

impl InputsProcessor for SpeechInputsProcessor {
    fn get_type(&self) -> InputsProcessorType {
        InputsProcessorType::Text
    }

    fn process_inputs(
        &self,
        _tokenizer: Option<Arc<Tokenizer>>,
        input_seqs: &mut [&mut Sequence],
        _is_prompt: bool,
        _is_xlora: bool,
        _device: &Device,
        _no_kv_cache: bool,
        _last_n_context_len: Option<(usize, usize)>,
        _return_raw_logits: bool,
        _other_config: Option<Arc<dyn Any>>,
        _paged_attn_metadata: Option<PagedAttentionMeta>,
        _mapper: Option<&dyn DeviceMapper>,
    ) -> Result<InputProcessorOutput> {
        let inputs = ModelInputs {
            prompts: input_seqs
                .iter()
                .map(|seq| seq.get_initial_prompt().to_string())
                .collect(),
        };
        Ok(InputProcessorOutput {
            inputs: Box::new(inputs),
            seq_indices: (0..input_seqs.len()).collect::<Vec<_>>(),
        })
    }
}

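/// Pipeline wrapping a speech generation model (currently Dia). It produces
/// PCM audio from text prompts and carries no tokenizer or KV cache.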
pub struct SpeechPipeline {
    model_id: String,
    model: DiaPipeline,
    metadata: Arc<GeneralMetadata>,
    dummy_cache: EitherCache,
    cfg: SpeechGenerationConfig,
}

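/// Loader for speech generation models. `dac_model_id` optionally overrides
/// the repository for the DAC (audio codec) weights; when `None`, a default
/// repo is selected per architecture.
///
/// A minimal construction sketch (the model ID below is illustrative):
///
/// ```ignore
/// let loader = SpeechLoader {
///     model_id: "nari-labs/Dia-1.6B".to_string(),
///     dac_model_id: None, // fall back to the default DAC repo for the arch
///     arch: SpeechLoaderType::Dia,
///     cfg: None, // use SpeechGenerationConfig::default(arch)
/// };
/// ```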
pub struct SpeechLoader {
    pub model_id: String,
    pub dac_model_id: Option<String>,
    pub arch: SpeechLoaderType,
    pub cfg: Option<SpeechGenerationConfig>,
}

impl Loader for SpeechLoader {
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let paths: anyhow::Result<Box<dyn ModelPaths>> = {
            let mut weights = Vec::new();

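            // Download the main speech model: a single safetensors file plus its config.json.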
            let config = {
                let api = ApiBuilder::new()
                    .with_progress(!silent)
                    .with_token(get_token(&token_source)?)
                    .build()?;
                let revision = revision.clone().unwrap_or("main".to_string());
                let api = api.repo(Repo::with_revision(
                    self.model_id.to_string(),
                    RepoType::Model,
                    revision.clone(),
                ));
                let model_id = std::path::Path::new(&self.model_id);

                let weight = api_get_file!(api, "model.safetensors", &model_id);
                let config = api_get_file!(api, "config.json", &model_id);
                weights.push(weight);
                config
            };

            {
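                // Fetch the DAC (audio codec) vocoder weights from a separate
                // repository. They are appended *after* the model weights,
                // an ordering `load_model_from_path` relies on.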
                let api = ApiBuilder::new()
                    .with_progress(!silent)
                    .with_token(get_token(&token_source)?)
                    .build()?;
                let revision = revision.unwrap_or("main".to_string());

                let dac_model = self
                    .dac_model_id
                    .clone()
                    .unwrap_or_else(|| match self.arch {
                        SpeechLoaderType::Dia => "EricB/dac_44khz".to_string(),
                    });

                let api = api.repo(Repo::with_revision(
                    dac_model.clone(),
                    RepoType::Model,
                    revision.clone(),
                ));
                let model_id = std::path::Path::new(&dac_model);

                let weight = api_get_file!(api, "model.safetensors", &model_id);
                weights.push(weight);
            }

            Ok(Box::new(SpeechModelPaths { weights, config }))
        };
        self.load_model_from_path(
            &paths?,
            dtype,
            device,
            silent,
            mapper,
            in_situ_quant,
            paged_attn_config,
        )
    }

    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        _paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let paths = &paths
            .as_ref()
            .as_any()
            .downcast_ref::<SpeechModelPaths>()
            .expect("Path downcast failed.");

        if matches!(mapper, DeviceMapSetting::Map(_)) {
            anyhow::bail!("Device mapping is not supported for speech models.")
        }

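        // Quantize in-place at load time; the ".*" regex applies ISQ to every layer.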
        mistralrs_quant::set_immediate_isq(in_situ_quant, vec![Regex::new(".*")?]);

        let cfg: DiaConfig = serde_json::from_str(&std::fs::read_to_string(&paths.config)?)?;

        #[cfg(feature = "cuda")]
        if let Device::Cuda(dev) = &device {
            unsafe { dev.disable_event_tracking() };
        }

        let mapper = DeviceMapSetting::dummy().into_mapper(usize::MAX, device, None)?;
        let dtype = mapper.get_min_dtype(dtype)?;

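        // All but the last safetensors file belong to the speech model itself;
        // the last one is the DAC vocoder appended by `load_model_from_hf`.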
        let model_weights = paths.weights[..paths.weights.len() - 1].to_vec();
        let vb = from_mmaped_safetensors(
            model_weights,
            Vec::new(),
            Some(dtype),
            device,
            vec![None],
            silent,
            None,
            |_| true,
            Arc::new(|_| DeviceForLoadTensor::Base),
        )?;

        let dac_vb = unsafe {
            VarBuilder::from_mmaped_safetensors(&[paths.weights.last().unwrap()], dtype, device)?
        };

        assert_eq!(self.arch, SpeechLoaderType::Dia);

        let model = DiaPipeline::new(&cfg, vb, dac_vb)?;

        Ok(Arc::new(Mutex::new(SpeechPipeline {
            model_id: self.model_id.clone(),
            model,
            metadata: Arc::new(GeneralMetadata {
                max_seq_len: 1024,
                llg_factory: None,
                is_xlora: false,
                no_prefix_cache: false,
                num_hidden_layers: 1,
                eos_tok: vec![],
                kind: ModelKind::Normal,
                no_kv_cache: true,
                activation_dtype: dtype,
                sliding_window: None,
                cache_config: None,
                cache_engine: None,
                model_metadata: None,
                modalities: Modalities {
                    input: vec![SupportedModality::Text],
                    output: vec![SupportedModality::Audio],
                },
            }),
            dummy_cache: EitherCache::Full(Cache::new(0, false)),
            cfg: self
                .cfg
                .unwrap_or_else(|| SpeechGenerationConfig::default(self.arch)),
        })))
    }

    fn get_id(&self) -> String {
        self.model_id.clone()
    }

    fn get_kind(&self) -> ModelKind {
        ModelKind::Normal
    }
}

impl PreProcessingMixin for SpeechPipeline {
    fn get_processor(&self) -> Arc<dyn Processor> {
        Arc::new(SpeechProcessor)
    }
    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
        None
    }
    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
        None
    }
}

impl IsqPipelineMixin for SpeechPipeline {
    fn re_isq_model(&mut self, _dtype: IsqType) -> Result<()> {
        anyhow::bail!("Speech models do not support ISQ for now.")
    }
}

impl CacheManagerMixin for SpeechPipeline {
    fn clone_in_cache(&self, _seqs: &mut [&mut Sequence]) {}
    fn clone_out_cache(&self, _seqs: &mut [&mut Sequence]) {}
    fn set_none_cache(
        &self,
        _seqs: &mut [&mut Sequence],
        _reset_non_granular: bool,
        _modify_draft_cache: bool,
        _load_preallocated_cache: bool,
    ) {
    }
    fn cache(&self) -> &EitherCache {
        &self.dummy_cache
    }
}

impl MetadataMixin for SpeechPipeline {
    fn device(&self) -> Device {
        self.model.device().clone()
    }
    fn get_metadata(&self) -> Arc<GeneralMetadata> {
        self.metadata.clone()
    }
    fn name(&self) -> String {
        self.model_id.clone()
    }
    fn reset_non_granular_state(&self) {}
    fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
        None
    }
    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
        None
    }
}

#[async_trait::async_trait]
impl Pipeline for SpeechPipeline {
    fn forward_inputs(
        &mut self,
        inputs: Box<dyn Any>,
        return_raw_logits: bool,
    ) -> candle_core::Result<ForwardInputsResult> {
        assert!(!return_raw_logits);

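        // Generate speech sequentially, one prompt at a time; each call yields
        // a complete PCM waveform with its sample rate and channel count.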
        let ModelInputs { prompts } = *inputs.downcast().expect("Downcast failed.");
        let mut pcms = Vec::new();
        let mut rates = Vec::new();
        let mut channels_all = Vec::new();
        for prompt in prompts {
            let SpeechGenerationOutput {
                pcm,
                rate,
                channels,
            } = self.model.generate(&prompt, &self.cfg)?;
            pcms.push(pcm);
            rates.push(rate);
            channels_all.push(channels);
        }

        Ok(ForwardInputsResult::Speech {
            pcms,
            rates,
            channels: channels_all,
        })
    }

    async fn sample_causal_gen(
        &self,
        _seqs: &mut [&mut Sequence],
        _logits: Vec<Tensor>,
        _prefix_cacher: &mut PrefixCacheManagerV2,
        _disable_eos_stop: bool,
        _srng: Arc<std::sync::Mutex<Isaac64Rng>>,
    ) -> Result<(), candle_core::Error> {
        candle_core::bail!("`sample_causal_gen` is incompatible with `SpeechPipeline`");
    }

    fn category(&self) -> ModelCategory {
        ModelCategory::Speech
    }
}

impl AnyMoePipelineMixin for SpeechPipeline {}