use super::text_models_inputs_processor::PagedAttentionMeta;
use super::{
    AdapterPaths, AnyMoePipelineMixin, Cache, CacheManagerMixin, EitherCache, ForwardInputsResult,
    GeneralMetadata, InputProcessorOutput, InputsProcessor, InputsProcessorType, IsqPipelineMixin,
    Loader, MessagesAction, MetadataMixin, ModelCategory, ModelKind, ModelPaths,
    PreProcessingMixin, Processor, TokenSource,
};
use crate::device_map::DeviceMapper;
use crate::pipeline::{ChatTemplate, Modalities, SupportedModality};
use crate::prefix_cacher::PrefixCacheManagerV2;
use crate::sequence::Sequence;
use crate::speech_models::{DiaConfig, DiaPipeline, SpeechGenerationOutput, SpeechLoaderType};
use crate::utils::varbuilder_utils::DeviceForLoadTensor;
use crate::utils::{tokens::get_token, varbuilder_utils::from_mmaped_safetensors};
use crate::{
    api_get_file, DeviceMapSetting, MessageContent, PagedAttentionConfig, Pipeline,
    SpeechGenerationConfig, TryIntoDType,
};
use anyhow::Result;
use candle_core::{Device, Tensor};
use candle_nn::VarBuilder;
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
use indexmap::IndexMap;
use mistralrs_quant::IsqType;
use rand_isaac::Isaac64Rng;
use regex::Regex;
use std::any::Any;
use std::num::NonZeroUsize;
use std::path::PathBuf;
use std::sync::Arc;
use tokenizers::Tokenizer;
use tokio::sync::Mutex;
use tracing::info;

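/// [`ModelPaths`] for a speech model: the `config.json` path plus an ordered list of
/// safetensors files, with the main model weights first and the DAC codec weights last.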
#[derive(Clone, Debug)]
pub struct SpeechModelPaths {
    weights: Vec<PathBuf>,
    config: PathBuf,
}

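// Only the config and weight accessors are meaningful here; everything else is reached
// by downcasting to `SpeechModelPaths` via `std::any::Any`, hence the `unreachable!`s.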
impl ModelPaths for SpeechModelPaths {
    fn get_config_filename(&self) -> &PathBuf {
        &self.config
    }
    fn get_tokenizer_filename(&self) -> &PathBuf {
        unreachable!("Use `std::any::Any`.")
    }
    fn get_weight_filenames(&self) -> &[PathBuf] {
        &self.weights
    }
    fn get_template_filename(&self) -> &Option<PathBuf> {
        unreachable!("Use `std::any::Any`.")
    }
    fn get_gen_conf_filename(&self) -> Option<&PathBuf> {
        unreachable!("Use `std::any::Any`.")
    }
    fn get_preprocessor_config(&self) -> &Option<PathBuf> {
        unreachable!("Use `std::any::Any`.")
    }
    fn get_processor_config(&self) -> &Option<PathBuf> {
        unreachable!("Use `std::any::Any`.")
    }
    fn get_chat_template_explicit(&self) -> &Option<PathBuf> {
        unreachable!("Use `std::any::Any`.")
    }
    fn get_adapter_paths(&self) -> &AdapterPaths {
        unreachable!("Use `std::any::Any`.")
    }
}

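/// A [`Processor`] for speech models. Speech generation consumes a raw prompt rather
/// than chat messages, so `process` always bails.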
pub struct SpeechProcessor;

impl Processor for SpeechProcessor {
    fn process(
        &self,
        _pipeline: &dyn Pipeline,
        _messages: Vec<IndexMap<String, MessageContent>>,
        _add_generation_prompt: bool,
        _add_special_tokens: bool,
        _enable_thinking: Option<bool>,
        _tools: Vec<crate::Tool>,
    ) -> Result<(Vec<u32>, String)> {
        anyhow::bail!(
            "SpeechProcessor::process should not be used. It does not expect chat messages."
        )
    }
    fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
        Arc::new(SpeechInputsProcessor)
    }
    fn get_special_tokens(&self) -> &[&'static str] {
        &[]
    }
    fn template_action(&self) -> MessagesAction {
        MessagesAction::FlattenOnlyText
    }
}

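/// An [`InputsProcessor`] that passes each sequence's initial prompt through verbatim;
/// speech models take raw text and the pipeline has no tokenizer.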
pub struct SpeechInputsProcessor;

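/// The raw prompts forwarded to the speech model, one per sequence in the batch.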
#[derive(Clone)]
pub struct ModelInputs {
    pub(crate) prompts: Vec<String>,
}

impl InputsProcessor for SpeechInputsProcessor {
    fn get_type(&self) -> InputsProcessorType {
        InputsProcessorType::Text
    }

    fn process_inputs(
        &self,
        _tokenizer: Option<Arc<Tokenizer>>,
        input_seqs: &mut [&mut Sequence],
        _is_prompt: bool,
        _is_xlora: bool,
        _device: &Device,
        _no_kv_cache: bool,
        _last_n_context_len: Option<(usize, usize)>,
        _return_raw_logits: bool,
        _other_config: Option<Arc<dyn Any>>,
        _paged_attn_metadata: Option<PagedAttentionMeta>,
        prompt_chunksize: Option<NonZeroUsize>,
        _mapper: Option<&dyn DeviceMapper>,
    ) -> Box<dyn Iterator<Item = Result<InputProcessorOutput>>> {
        let make_value = if prompt_chunksize.is_some() {
            return Box::new(std::iter::once(Err(anyhow::Error::msg(
                "Prompt batching is unsupported for speech models",
            ))));
        } else {
            || {
                let inputs = ModelInputs {
                    prompts: input_seqs
                        .iter()
                        .map(|seq| seq.get_initial_prompt().to_string())
                        .collect(),
                };
                Ok(InputProcessorOutput {
                    inputs: Box::new(inputs),
                    seq_indices: (0..input_seqs.len()).collect::<Vec<_>>(),
                })
            }
        };
        Box::new(std::iter::once(make_value()))
    }
}

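/// A [`Pipeline`] wrapping a Dia speech model: text prompts in, PCM audio out.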
pub struct SpeechPipeline {
    model_id: String,
    model: DiaPipeline,
    metadata: Arc<GeneralMetadata>,
    dummy_cache: EitherCache,
    cfg: SpeechGenerationConfig,
}

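/// A [`Loader`] for speech generation models. `dac_model_id` optionally overrides the
/// audio codec repository, and `cfg` optionally overrides the per-architecture
/// generation defaults.
///
/// A minimal construction sketch; the model id below is a hypothetical placeholder:
///
/// ```ignore
/// let loader = SpeechLoader {
///     model_id: "some-org/some-dia-model".to_string(), // hypothetical HF repo id
///     dac_model_id: None, // fall back to the default DAC repo for the architecture
///     arch: SpeechLoaderType::Dia,
///     cfg: None, // use SpeechGenerationConfig::default(arch)
/// };
/// ```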
pub struct SpeechLoader {
    pub model_id: String,
    pub dac_model_id: Option<String>,
    pub arch: SpeechLoaderType,
    pub cfg: Option<SpeechGenerationConfig>,
}

impl Loader for SpeechLoader {
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let paths: anyhow::Result<Box<dyn ModelPaths>> = {
            let mut weights = Vec::new();

            let config = {
                let api = ApiBuilder::new()
                    .with_progress(!silent)
                    .with_token(get_token(&token_source)?)
                    .build()?;
                let revision = revision.clone().unwrap_or("main".to_string());
                let api = api.repo(Repo::with_revision(
                    self.model_id.to_string(),
                    RepoType::Model,
                    revision.clone(),
                ));
                let model_id = std::path::Path::new(&self.model_id);

                let weight = api_get_file!(api, "model.safetensors", &model_id);
                let config = api_get_file!(api, "config.json", &model_id);
                weights.push(weight);
                config
            };

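            // Fetch the DAC (audio codec) weights next; they are pushed last so the
            // pipeline can split them off the end of `weights` when loading.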
            {
                let api = ApiBuilder::new()
                    .with_progress(!silent)
                    .with_token(get_token(&token_source)?)
                    .build()?;
                let revision = revision.unwrap_or("main".to_string());

                let dac_model = self
                    .dac_model_id
                    .clone()
                    .unwrap_or_else(|| match self.arch {
                        SpeechLoaderType::Dia => "EricB/dac_44khz".to_string(),
                    });

                let api = api.repo(Repo::with_revision(
                    dac_model.clone(),
                    RepoType::Model,
                    revision.clone(),
                ));
                let model_id = std::path::Path::new(&dac_model);

                let weight = api_get_file!(api, "model.safetensors", &model_id);
                weights.push(weight);
            }

            Ok(Box::new(SpeechModelPaths { weights, config }))
        };
        self.load_model_from_path(
            &paths?,
            dtype,
            device,
            silent,
            mapper,
            in_situ_quant,
            paged_attn_config,
        )
    }

    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        _paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let paths = &paths
            .as_ref()
            .as_any()
            .downcast_ref::<SpeechModelPaths>()
            .expect("Path downcast failed.");

        if matches!(mapper, DeviceMapSetting::Map(_)) {
            anyhow::bail!("Device mapping is not supported for speech models.")
        }

        mistralrs_quant::set_immediate_isq(in_situ_quant, vec![Regex::new(".*")?]);

        let cfg: DiaConfig = serde_json::from_str(&std::fs::read_to_string(&paths.config)?)?;

        let mapper = DeviceMapSetting::dummy().into_mapper(usize::MAX, device, None)?;
        let dtype = mapper.get_min_dtype(dtype)?;

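        // All but the last safetensors file belong to the speech model itself; the final
        // file, appended during download, holds the DAC codec weights.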
        let model_weights = paths.weights[..paths.weights.len() - 1].to_vec();
        let vb = from_mmaped_safetensors(
            model_weights,
            Vec::new(),
            Some(dtype),
            device,
            vec![None],
            silent,
            None,
            |_| true,
            Arc::new(|_| DeviceForLoadTensor::Base),
        )?;

        let dac_vb = unsafe {
            VarBuilder::from_mmaped_safetensors(&[paths.weights.last().unwrap()], dtype, device)?
        };

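        // Dia is currently the only supported speech architecture.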
        assert_eq!(self.arch, SpeechLoaderType::Dia);

        let model = DiaPipeline::new(&cfg, vb, dac_vb)?;

        Ok(Arc::new(Mutex::new(SpeechPipeline {
            model_id: self.model_id.clone(),
            model,
            metadata: Arc::new(GeneralMetadata {
                max_seq_len: 1024,
                llg_factory: None,
                is_xlora: false,
                no_prefix_cache: false,
                num_hidden_layers: 1,
                eos_tok: vec![],
                kind: ModelKind::Normal,
                no_kv_cache: true,
                activation_dtype: dtype,
                sliding_window: None,
                cache_config: None,
                cache_engine: None,
                prompt_chunksize: None,
                model_metadata: None,
                modalities: Modalities {
                    input: vec![SupportedModality::Text],
                    output: vec![SupportedModality::Audio],
                },
            }),
            dummy_cache: EitherCache::Full(Cache::new(0, false)),
            cfg: self
                .cfg
                .clone()
                .unwrap_or_else(|| SpeechGenerationConfig::default(self.arch)),
        })))
    }

    fn get_id(&self) -> String {
        self.model_id.clone()
    }

    fn get_kind(&self) -> ModelKind {
        ModelKind::Normal
    }
}

impl PreProcessingMixin for SpeechPipeline {
    fn get_processor(&self) -> Arc<dyn Processor> {
        Arc::new(SpeechProcessor)
    }
    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
        None
    }
    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
        None
    }
}

impl IsqPipelineMixin for SpeechPipeline {
    fn re_isq_model(&mut self, _dtype: IsqType) -> Result<()> {
        anyhow::bail!("Speech models do not support ISQ for now.")
    }
}

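// Speech generation runs without a KV cache, so every cache operation is a no-op and
// `cache` returns a dummy.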
impl CacheManagerMixin for SpeechPipeline {
    fn clone_in_cache(&self, _seqs: &mut [&mut Sequence]) {}
    fn clone_out_cache(&self, _seqs: &mut [&mut Sequence]) {}
    fn set_none_cache(
        &self,
        _seqs: &mut [&mut Sequence],
        _reset_non_granular: bool,
        _modify_draft_cache: bool,
        _load_preallocated_cache: bool,
    ) {
    }
    fn cache(&self) -> &EitherCache {
        &self.dummy_cache
    }
}

impl MetadataMixin for SpeechPipeline {
    fn device(&self) -> Device {
        self.model.device().clone()
    }
    fn get_metadata(&self) -> Arc<GeneralMetadata> {
        self.metadata.clone()
    }
    fn name(&self) -> String {
        self.model_id.clone()
    }
    fn reset_non_granular_state(&self) {}
    fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
        None
    }
    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
        None
    }
}

#[async_trait::async_trait]
impl Pipeline for SpeechPipeline {
    fn forward_inputs(
        &mut self,
        inputs: Box<dyn Any>,
        return_raw_logits: bool,
    ) -> candle_core::Result<ForwardInputsResult> {
        assert!(!return_raw_logits);

        let ModelInputs { prompts } = *inputs.downcast().expect("Downcast failed.");
        let mut pcms = Vec::new();
        let mut rates = Vec::new();
        let mut channels_all = Vec::new();
        for prompt in prompts {
            let SpeechGenerationOutput {
                pcm,
                rate,
                channels,
            } = self.model.generate(&prompt, &self.cfg)?;
            pcms.push(pcm);
            rates.push(rate);
            channels_all.push(channels);
        }

        Ok(ForwardInputsResult::Speech {
            pcms,
            rates,
            channels: channels_all,
        })
    }

    async fn sample_causal_gen(
        &self,
        _seqs: &mut [&mut Sequence],
        _logits: Vec<Tensor>,
        _prefix_cacher: &mut PrefixCacheManagerV2,
        _disable_eos_stop: bool,
        _srng: Arc<std::sync::Mutex<Isaac64Rng>>,
    ) -> Result<(), candle_core::Error> {
        candle_core::bail!("`sample_causal_gen` is incompatible with `SpeechPipeline`");
    }

    fn category(&self) -> ModelCategory {
        ModelCategory::Speech
    }
}

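// AnyMoE does not apply to speech pipelines; this empty impl falls back to the trait's
// default method implementations.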
impl AnyMoePipelineMixin for SpeechPipeline {}