1mod amoe;
2mod auto;
3pub mod chat_template;
4mod diffusion;
5mod ggml;
6mod gguf;
7mod inputs_processor;
8mod isq;
9pub(crate) mod llg;
10mod loaders;
11mod macros;
12mod normal;
13mod paths;
14mod processing;
15mod response;
16mod sampling;
17mod speculative;
18mod speech;
19mod vision;
20
21pub use super::diffusion_models::DiffusionGenerationParams;
22use crate::amoe::{AnyMoeConfig, AnyMoeExpertType, AnyMoeTrainingInputs, AnyMoeTrainingResult};
23use crate::device_map::DeviceMapper;
24use crate::paged_attention::{CacheConfig, CacheEngine, ModelConfigLike};
25use crate::prefix_cacher::PrefixCacheManagerV2;
26pub use amoe::{AnyMoeLoader, AnyMoePipeline};
27pub use auto::{AutoLoader, AutoLoaderBuilder};
28use chat_template::ChatTemplate;
29pub use diffusion::{DiffusionLoader, DiffusionLoaderBuilder};
30pub use ggml::{GGMLLoader, GGMLLoaderBuilder, GGMLSpecificConfig};
31pub use gguf::{GGUFLoader, GGUFLoaderBuilder, GGUFSpecificConfig};
32use image::DynamicImage;
33pub use inputs_processor::InputProcessorOutput;
34pub(crate) use isq::IsqModelLoader;
35pub use isq::{parse_isq_value, IsqModel, IsqOrganization, UQFF_MULTI_FILE_DELIMITER};
36use llguidance::toktrie::TokEnv;
37pub use loaders::{
38 AdapterKind, AutoDeviceMapParams, AutoNormalLoader, AutoVisionLoader, DeepSeekV2Loader,
39 DeepSeekV3Loader, DeviceMappedModelLoader, DiffusionLoaderType, DiffusionModel,
40 DiffusionModelLoader, FluxLoader, GLM4Loader, Gemma2Loader, Gemma3Loader, Gemma3nLoader,
41 GemmaLoader, Idefics2Loader, Idefics3Loader, LLaVALoader, LLaVANextLoader, LlamaLoader, Loader,
42 LocalModelPaths, MiniCpmOLoader, Mistral3Loader, MistralLoader, MixtralLoader, ModelKind,
43 ModelPaths, NormalLoaderType, NormalLoadingMetadata, NormalModel, NormalModelLoader,
44 Phi2Loader, Phi3Loader, Phi3VLoader, Phi3_5MoELoader, Phi4MMLoader, PrettyName,
45 QuantizationKind, Qwen2Loader, Qwen2VLLoader, Qwen2_5VLLoader, Qwen3Loader, Qwen3MoELoader,
46 SmolLm3Loader, Starcoder2Loader, TokenSource, VLlama4Loader, VLlamaLoader, VisionLoaderType,
47 VisionModel, VisionModelLoader,
48};
49use mistralrs_quant::IsqType;
50pub use normal::{NormalLoader, NormalLoaderBuilder, NormalSpecificConfig};
51pub(crate) use paths::{get_chat_template, get_model_paths, get_xlora_paths};
52pub use paths::{AdapterPaths, LoraAdapterPaths};
53pub(crate) use processing::{
54 apply_chat_template, BasicProcessor, MessagesAction, Processor, ProcessorCreator,
55};
56use rand_isaac::Isaac64Rng;
57pub use speculative::{SpeculativeConfig, SpeculativeLoader, SpeculativePipeline};
58pub use speech::{SpeechLoader, SpeechPipeline};
59use std::any::Any;
60use std::collections::HashMap;
61use std::fmt::Debug;
62use std::sync::Arc;
63use std::time::{Duration, Instant};
64use tokenizers::Tokenizer;
65pub use vision::{VisionLoader, VisionLoaderBuilder, VisionSpecificConfig};
66
67use anyhow::Result;
68use candle_core::{DType, Device, IndexOp, Tensor, Var};
69
70use crate::sequence::Sequence;
71
72pub use self::inputs_processor::{
73 text_models_inputs_processor, InputsProcessor, InputsProcessorType,
74};
75use self::text_models_inputs_processor::PagedAttentionMeta;
76pub use crate::kv_cache::{
77 Cache, CacheManager, EitherCache, KvCache, LayerCaches, NormalCache, NormalCacheType,
78};
79
/// A kind of data a model can consume or produce.
///
/// `Debug` is implemented by hand so that each variant renders as an
/// emoji-prefixed, human-readable label instead of the bare variant name.
#[derive(Clone, PartialEq, Eq)]
pub enum SupportedModality {
    Text,
    Audio,
    Vision,
}

impl Debug for SupportedModality {
    /// Writes the emoji-prefixed label for this modality.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Pick the label first, then emit it with a single write.
        let label = match self {
            Self::Text => "📝 Text",
            Self::Audio => "🔊 Audio",
            Self::Vision => "🖼️ Vision",
        };
        f.write_str(label)
    }
}
96
/// The sets of modalities associated with a model, split by direction.
#[derive(Debug, Clone)]
pub struct Modalities {
    /// Modalities listed on the input side.
    pub input: Vec<SupportedModality>,
    /// Modalities listed on the output side.
    pub output: Vec<SupportedModality>,
}
102
/// Model-wide metadata exposed by a pipeline through `MetadataMixin::get_metadata`.
///
/// Field notes below only describe how the fields are used within this module;
/// NOTE(review): fields without a note are populated and consumed elsewhere.
pub struct GeneralMetadata {
    pub max_seq_len: usize,
    /// Optional llguidance parser factory; `tok_env` derives its result from it.
    pub llg_factory: Option<Arc<llguidance::ParserFactory>>,
    /// Forwarded to the inputs processor by `Pipeline::step`.
    pub no_kv_cache: bool,
    pub no_prefix_cache: bool,
    pub num_hidden_layers: usize,
    pub eos_tok: Vec<u32>,
    pub kind: ModelKind,
    /// Forwarded to the inputs processor by `Pipeline::step`.
    pub is_xlora: bool,
    pub activation_dtype: DType,
    pub sliding_window: Option<usize>,
    pub cache_config: Option<CacheConfig>,
    /// Must be `Some` when `Pipeline::step` runs with the PagedAttention backend;
    /// that path calls `expect` on it before executing scheduler operations.
    pub cache_engine: Option<CacheEngine>,
    pub model_metadata: Option<Arc<dyn ModelConfigLike + Send + Sync>>,
    pub modalities: Modalities,
}
122
123impl GeneralMetadata {
124 pub fn tok_env(&self) -> Option<TokEnv> {
125 self.llg_factory.as_ref().map(|f| f.tok_env().clone())
126 }
127}
128
/// A cache operation requested around a forward pass.
///
/// `Pipeline::step` accepts `In`, `Reset` and `Nothing` as the pre-forward
/// operation and `Out`, `Reset` and `Nothing` as the post-forward operation;
/// any other combination hits `unreachable!`.
pub enum CacheInstruction {
    /// Pre-forward: calls `CacheManagerMixin::clone_in_cache`.
    In,
    /// Post-forward: calls `CacheManagerMixin::clone_out_cache`.
    Out,
    /// Calls `CacheManagerMixin::set_none_cache`, forwarding both flags.
    Reset {
        load_preallocated_cache: bool,
        reset_non_granular: bool,
    },
    /// No cache operation is performed.
    Nothing,
}
139
/// Access to the request pre-processing components of a pipeline.
pub trait PreProcessingMixin: MetadataMixin {
    /// Returns the processor for this pipeline; defaults to a fresh `BasicProcessor`.
    fn get_processor(&self) -> Arc<dyn Processor> {
        Arc::new(BasicProcessor)
    }
    /// Returns the chat template, if the pipeline has one.
    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>>;
    /// Returns an opaque, optional configuration object that `Pipeline::step`
    /// hands to the inputs processor unchanged.
    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>>;
}
148
/// Re-quantization hook for pipelines.
pub trait IsqPipelineMixin {
    /// Re-applies ISQ to the model using the given quantization type.
    /// NOTE(review): exact semantics live in the implementors — confirm there.
    fn re_isq_model(&mut self, dtype: IsqType) -> Result<()>;
}
152
/// Cache management operations that `Pipeline::step` drives through
/// `CacheInstruction` values.
pub trait CacheManagerMixin {
    /// Invoked by `Pipeline::step` for `CacheInstruction::In`, before the forward pass.
    fn clone_in_cache(&self, seqs: &mut [&mut Sequence]);
    /// Invoked by `Pipeline::step` for `CacheInstruction::Out`, after the forward pass.
    fn clone_out_cache(&self, seqs: &mut [&mut Sequence]);
    /// Invoked by `Pipeline::step` for `CacheInstruction::Reset`; `step` always
    /// passes `modify_draft_cache = false`.
    fn set_none_cache(
        &self,
        seqs: &mut [&mut Sequence],
        reset_non_granular: bool,
        modify_draft_cache: bool,
        load_preallocated_cache: bool,
    );
    /// Returns the cache owned by this pipeline.
    fn cache(&self) -> &EitherCache;
    /// Returns `true` exactly when `cache()` is the `EitherCache::Normal` variant.
    fn do_preallocated_cache(&self) -> bool {
        matches!(self.cache(), EitherCache::Normal(_))
    }
}
175
/// Read access to the identifying and configuration data of a pipeline.
pub trait MetadataMixin {
    /// Device passed to the inputs processor by `Pipeline::step`.
    fn device(&self) -> Device;
    /// Tokenizer, if the pipeline has one.
    fn tokenizer(&self) -> Option<Arc<Tokenizer>>;
    /// Name of the pipeline/model.
    fn name(&self) -> String;
    /// Resets non-granular state.
    /// NOTE(review): what that state is depends on the implementor — confirm there.
    fn reset_non_granular_state(&self);
    /// Shared model-wide metadata.
    fn get_metadata(&self) -> Arc<GeneralMetadata>;
    /// Device mapper, if any; forwarded to the inputs processor by `Pipeline::step`.
    fn device_mapper(&self) -> Option<&dyn DeviceMapper>;
}
185
/// AnyMoE hooks for pipelines.
///
/// Every method has a default body. All defaults except `amoe_supported`
/// panic via `unreachable!()`, and `amoe_supported` returns `false`, so the
/// defaults describe a pipeline without AnyMoE support.
pub trait AnyMoePipelineMixin {
    /// Default panics with `unreachable!()`.
    fn amoe_layer_vars(&self) -> Vec<Vec<Var>> {
        unreachable!()
    }
    /// Default panics with `unreachable!()`.
    fn amoe_finish_training(&mut self, _gate_model_id: Option<String>) -> candle_core::Result<()> {
        unreachable!()
    }
    /// Default panics with `unreachable!()`.
    fn amoe_base_model_trainable_params(&self) -> usize {
        unreachable!()
    }
    /// Whether this pipeline supports AnyMoE; defaults to `false`.
    fn amoe_supported(&self) -> bool {
        false
    }
    /// Default panics with `unreachable!()`.
    fn amoe_take_cached_gating_outputs(&mut self) -> Vec<Tensor> {
        unreachable!()
    }
    /// Default panics with `unreachable!()`.
    #[allow(clippy::too_many_arguments)]
    fn amoe_create_layers(
        &mut self,
        _model_ids: Vec<String>,
        _token: &TokenSource,
        _revision: Option<String>,
        _match_regex: &str,
        _config: AnyMoeConfig,
        _dtype: DType,
        _dev: &Device,
        (_prefix, _mlp): (String, String),
        _layers: Vec<usize>,
        _expert_type: AnyMoeExpertType,
        _silent: bool,
        _gate_model_id: Option<String>,
    ) -> candle_core::Result<()> {
        unreachable!()
    }
    /// Default panics with `unreachable!()`.
    #[allow(clippy::too_many_arguments)]
    fn amoe_pre_train(
        &self,
        _inputs: AnyMoeTrainingInputs,
        (_prefix, _mlp): (String, String),
        _model_ids: Vec<String>,
        _token: TokenSource,
        _revision: Option<String>,
        _layers: Vec<usize>,
        _silent: bool,
    ) -> Result<Option<AnyMoeTrainingResult>, candle_core::Error> {
        unreachable!()
    }
}
239
/// The category of model a pipeline serves, as reported by `Pipeline::category`.
#[derive(Clone)]
pub enum ModelCategory {
    Text,
    /// Carries the prompt prefixer used for multimodal prompts.
    Vision {
        prefixer: Arc<dyn MultimodalPromptPrefixer>,
    },
    Diffusion,
    Audio,
    Speech,
}
252
253impl std::fmt::Debug for ModelCategory {
254 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
255 match self {
256 ModelCategory::Text => write!(f, "ModelCategory::Text"),
257 ModelCategory::Vision { .. } => write!(f, "ModelCategory::Vision {{ prefixer: .. }}"),
258 ModelCategory::Diffusion => write!(f, "ModelCategory::Diffusion"),
259 ModelCategory::Audio => write!(f, "ModelCategory::Audio"),
260 ModelCategory::Speech => write!(f, "ModelCategory::Speech"),
261 }
262 }
263}
264
265impl PartialEq for ModelCategory {
266 fn eq(&self, other: &Self) -> bool {
267 match (self, other) {
268 (Self::Text, Self::Text) => true,
269 (Self::Vision { .. }, Self::Vision { .. }) => true,
270 (Self::Audio, Self::Audio) => true,
271 (Self::Speech, Self::Speech) => true,
272 (Self::Diffusion, Self::Diffusion) => true,
273 (
274 Self::Text | Self::Vision { .. } | Self::Diffusion | Self::Audio | Self::Speech,
275 _,
276 ) => false,
277 }
278 }
279}
280
/// Hook for rewriting a prompt when multimodal inputs accompany it.
///
/// Both default implementations ignore the indices and return the prompt
/// unchanged as an owned `String`.
pub trait MultimodalPromptPrefixer: Send + Sync {
    /// Returns the prompt to use given the image indices; default is a pass-through.
    fn prefix_image(&self, _image_indices: Vec<usize>, prompt: &str) -> String {
        String::from(prompt)
    }
    /// Returns the prompt to use given the audio indices; default is a pass-through.
    fn prefix_audio(&self, _audio_indexes: Vec<usize>, prompt: &str) -> String {
        String::from(prompt)
    }
}
292
/// Selects which cache backend path `Pipeline::step` takes and carries the data
/// that path needs.
pub enum CacheBackendMetadata {
    /// Cache operations to run before (`pre_op`) and after (`post_op`) the forward pass.
    DefaultInstructions {
        pre_op: CacheInstruction,
        post_op: CacheInstruction,
    },
    /// PagedAttention path: `metadata` is handed to the inputs processor, and
    /// the three block maps are passed to the cache engine's
    /// `execute_scheduler_ops` before inputs are processed.
    PagedAttention {
        metadata: PagedAttentionMeta,
        blocks_to_swap_in: HashMap<usize, usize>,
        blocks_to_swap_out: HashMap<usize, usize>,
        blocks_to_copy: HashMap<usize, Vec<usize>>,
    },
}
305
/// The output of `Pipeline::forward_inputs` for a whole batch.
#[derive(Clone, Debug)]
pub enum ForwardInputsResult {
    /// Logits that `Pipeline::step` sends back as raw responses without sampling.
    RawLogits {
        logits: Tensor,
    },
    /// Logits that `Pipeline::step` passes to `sample_causal_gen`.
    CausalGeneration {
        logits: Tensor,
    },
    /// Generated images, indexed by batch position.
    Image {
        images: Vec<DynamicImage>,
    },
    /// Generated audio; the three vectors are indexed by the same batch position.
    Speech {
        pcms: Vec<Arc<Vec<f32>>>,
        rates: Vec<usize>,
        channels: Vec<usize>,
    },
}
323
324impl ForwardInputsResult {
325 fn index_bs(&self, bs_idx: usize) -> candle_core::Result<Self> {
326 match self {
327 Self::CausalGeneration { logits } => Ok(Self::CausalGeneration {
328 logits: logits.i(bs_idx)?,
329 }),
330 Self::RawLogits { logits } => Ok(Self::RawLogits {
331 logits: logits.i(bs_idx)?,
332 }),
333 Self::Image { images } => Ok(Self::Image {
334 images: vec![images[bs_idx].clone()],
335 }),
336 Self::Speech {
337 pcms,
338 rates,
339 channels,
340 } => Ok(Self::Speech {
341 pcms: vec![pcms[bs_idx].clone()],
342 rates: vec![rates[bs_idx]],
343 channels: vec![channels[bs_idx]],
344 }),
345 }
346 }
347
348 fn to_device(&self, device: &Device) -> candle_core::Result<Self> {
349 match self {
350 Self::CausalGeneration { logits } => Ok(Self::CausalGeneration {
351 logits: logits.to_device(device)?,
352 }),
353 Self::RawLogits { logits } => Ok(Self::RawLogits {
354 logits: logits.to_device(device)?,
355 }),
356 Self::Image { .. } => Ok(self.clone()),
357 Self::Speech { .. } => Ok(self.clone()),
358 }
359 }
360}
361
/// Serializable list of file names.
/// NOTE(review): presumably a cached repository file listing — confirm against the users of this type.
#[derive(serde::Serialize, serde::Deserialize)]
pub(crate) struct FileListCache {
    files: Vec<String>,
}
366
/// The core pipeline abstraction: runs a forward pass over a batch of
/// sequences and dispatches the results (sampling, images, speech or raw logits).
///
/// Implementors supply `forward_inputs`, `sample_causal_gen` and `category`;
/// `step` is provided and orchestrates input processing, cache handling,
/// the forward pass and response delivery.
#[async_trait::async_trait]
pub trait Pipeline:
    Send
    + Sync
    + PreProcessingMixin
    + IsqPipelineMixin
    + CacheManagerMixin
    + MetadataMixin
    + AnyMoePipelineMixin
{
    /// Runs the model on pre-processed, type-erased `inputs`.
    ///
    /// `return_raw_logits` is passed through from `step`; `step` treats a
    /// `ForwardInputsResult::RawLogits` return as the raw-logits response path.
    fn forward_inputs(
        &mut self,
        inputs: Box<dyn Any>,
        return_raw_logits: bool,
    ) -> Result<ForwardInputsResult, candle_core::Error>;

    /// Processes `input_seqs` through one forward pass and delivers the results.
    ///
    /// Returns the accumulated duration of the forward pass plus the time spent
    /// sampling / sending responses; input processing and cache operations are
    /// not included in the measurement.
    ///
    /// # Errors
    /// Propagates errors from input processing, the forward pass, tensor
    /// indexing/device transfers, scheduler operations and response delivery.
    ///
    /// # Panics
    /// Panics on an unexpected pre/post cache instruction, if the
    /// PagedAttention path is taken without a cache engine, or if a sequence
    /// received no output.
    /// NOTE(review): both branches index element 0 of per-sequence vectors, so
    /// an empty `input_seqs` would panic — confirm callers never pass one.
    #[allow(clippy::too_many_arguments)]
    async fn step(
        &mut self,
        input_seqs: &mut [&mut Sequence],
        is_prompt: bool,
        return_raw_logits: bool,
        prefix_cacher: &mut PrefixCacheManagerV2,
        disable_eos_stop: bool,
        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
        backend_metadata: CacheBackendMetadata,
    ) -> Result<Duration, candle_core::Error> {
        match backend_metadata {
            CacheBackendMetadata::DefaultInstructions { pre_op, post_op } => {
                // A single processed-input item; no PagedAttention metadata on this path.
                let inputs_iter =
                    std::iter::once(self.get_processor().inputs_processor().process_inputs(
                        self.tokenizer(),
                        input_seqs,
                        is_prompt,
                        self.get_metadata().is_xlora,
                        &self.device(),
                        self.get_metadata().no_kv_cache,
                        None,
                        return_raw_logits,
                        self.get_input_processor_config(),
                        None,
                        self.device_mapper(),
                    ));

                // One slot per sequence, filled in via `seq_indices` below.
                let mut logits = vec![None; input_seqs.len()];
                let len_inputs = 1;
                let mut raw_out_logits = vec![vec![None; len_inputs]; input_seqs.len()];

                let mut exec_duration = Duration::ZERO;
                for (i, inputs) in inputs_iter.into_iter().enumerate() {
                    let InputProcessorOutput {
                        inputs,
                        seq_indices,
                    } = inputs.map_err(candle_core::Error::msg)?;
                    // The pre-forward cache op runs once, only after input
                    // processing has succeeded.
                    if i == 0 {
                        match pre_op {
                            CacheInstruction::In => self.clone_in_cache(input_seqs),
                            CacheInstruction::Nothing => (),
                            CacheInstruction::Reset {
                                load_preallocated_cache,
                                reset_non_granular,
                            } => self.set_none_cache(
                                input_seqs,
                                reset_non_granular,
                                false,
                                load_preallocated_cache,
                            ),
                            _ => unreachable!("Unreachable PRE cache op."),
                        }
                    }

                    // Only the forward pass itself is timed here.
                    let start = Instant::now();
                    let raw_logits = self.forward_inputs(inputs, return_raw_logits)?;
                    let end = Instant::now();
                    exec_duration += end.duration_since(start);

                    // Scatter per-batch-row results back to their sequence slots.
                    for (logit_idx, seq_idx) in seq_indices.into_iter().enumerate() {
                        if let ForwardInputsResult::RawLogits { logits } = &raw_logits {
                            raw_out_logits[seq_idx][i] =
                                Some(logits.i(logit_idx)?.to_device(&Device::Cpu)?);
                        } else {
                            logits[seq_idx] = Some(raw_logits.index_bs(logit_idx)?);
                        }
                    }
                }

                match post_op {
                    CacheInstruction::Out => self.clone_out_cache(input_seqs),
                    CacheInstruction::Nothing => (),
                    CacheInstruction::Reset {
                        load_preallocated_cache,
                        reset_non_granular,
                    } => self.set_none_cache(
                        input_seqs,
                        reset_non_granular,
                        false,
                        load_preallocated_cache,
                    ),
                    _ => unreachable!("Unreachable POST cache op."),
                }

                // Raw-logits path: the first slot is used as the indicator that
                // the forward pass returned `RawLogits`; respond and return early.
                if raw_out_logits[0][0].is_some() {
                    let start = Instant::now();
                    response::send_raw_responses(
                        input_seqs,
                        raw_out_logits
                            .into_iter()
                            .map(|raw| raw.into_iter().flatten().collect::<Vec<_>>())
                            .collect(),
                    )
                    .await?;
                    let end = Instant::now();
                    exec_duration += end.duration_since(start);

                    return Ok(exec_duration);
                }

                let start = Instant::now();
                // Results are moved to the CPU only when there is more than one sequence.
                let logits_on_cpu = logits.len() > 1;
                let logits = logits
                    .into_iter()
                    .map(|l| {
                        let l = l.expect("Did not get any inputs. This is shocking.");
                        if logits_on_cpu {
                            l.to_device(&Device::Cpu)
                        } else {
                            Ok(l)
                        }
                    })
                    .collect::<candle_core::Result<Vec<_>>>()?;

                // The first result's variant decides how the whole batch is handled;
                // every other result is expected to be the same variant.
                match &logits[0] {
                    ForwardInputsResult::RawLogits { .. } => unreachable!(),
                    ForwardInputsResult::CausalGeneration { .. } => {
                        self.sample_causal_gen(
                            input_seqs,
                            logits
                                .into_iter()
                                .map(|r| {
                                    #[allow(irrefutable_let_patterns)]
                                    let ForwardInputsResult::CausalGeneration { logits } = r
                                    else {
                                        unreachable!(
                                            "All results must have same type, `CausalGeneration`"
                                        )
                                    };
                                    logits
                                })
                                .collect::<Vec<_>>(),
                            prefix_cacher,
                            disable_eos_stop,
                            rng,
                        )
                        .await?;
                    }
                    ForwardInputsResult::Image { .. } => {
                        response::send_image_responses(
                            input_seqs,
                            logits
                                .into_iter()
                                .map(|r| {
                                    #[allow(irrefutable_let_patterns)]
                                    let ForwardInputsResult::Image { images } = r
                                    else {
                                        unreachable!("All results must have same type, `Image`")
                                    };
                                    // `index_bs` produced a one-element vector per sequence.
                                    images
                                        .into_iter()
                                        .next()
                                        .expect("Must have at least 1 element.")
                                })
                                .collect::<Vec<_>>(),
                        )
                        .await?;
                    }
                    ForwardInputsResult::Speech { .. } => {
                        // Rates and channels are read by reference first so that
                        // the PCM buffers can be moved out last.
                        let rates = logits
                            .iter()
                            .map(|r| {
                                #[allow(irrefutable_let_patterns)]
                                let ForwardInputsResult::Speech { rates, .. } = r
                                else {
                                    unreachable!("All results must have same type, `Speech`")
                                };
                                assert_eq!(rates.len(), 1, "Each sequence must have 1 PCM output.");
                                *rates.first().unwrap()
                            })
                            .collect::<Vec<_>>();
                        let channels = logits
                            .iter()
                            .map(|r| {
                                #[allow(irrefutable_let_patterns)]
                                let ForwardInputsResult::Speech { channels, .. } = r
                                else {
                                    unreachable!("All results must have same type, `Speech`")
                                };
                                assert_eq!(
                                    channels.len(),
                                    1,
                                    "Each sequence must have 1 PCM output."
                                );
                                *channels.first().unwrap()
                            })
                            .collect::<Vec<_>>();
                        let pcms = logits
                            .into_iter()
                            .map(|r| {
                                #[allow(irrefutable_let_patterns)]
                                let ForwardInputsResult::Speech { pcms, .. } = r
                                else {
                                    unreachable!("All results must have same type, `Speech`")
                                };
                                assert_eq!(pcms.len(), 1, "Each sequence must have 1 PCM output.");
                                pcms.into_iter().nth(0).unwrap()
                            })
                            .collect::<Vec<_>>();
                        response::send_speech_responses(input_seqs, &pcms, &rates, &channels)
                            .await?;
                    }
                }
                let end = Instant::now();
                exec_duration += end.duration_since(start);

                Ok(exec_duration)
            }
            CacheBackendMetadata::PagedAttention {
                metadata,
                blocks_to_copy,
                blocks_to_swap_in,
                blocks_to_swap_out,
            } => {
                // Apply the scheduler's block operations before anything else.
                self.get_metadata()
                    .cache_engine
                    .as_ref()
                    .expect("PagedAttention must have cache engines.")
                    .execute_scheduler_ops(
                        &blocks_to_swap_in,
                        &blocks_to_swap_out,
                        &blocks_to_copy,
                    )?;

                // Same as the default path, except the PagedAttention metadata is supplied.
                let inputs_iter =
                    std::iter::once(self.get_processor().inputs_processor().process_inputs(
                        self.tokenizer(),
                        input_seqs,
                        is_prompt,
                        self.get_metadata().is_xlora,
                        &self.device(),
                        self.get_metadata().no_kv_cache,
                        None,
                        return_raw_logits,
                        self.get_input_processor_config(),
                        Some(metadata),
                        self.device_mapper(),
                    ));

                // One slot per sequence, filled in via `seq_indices` below.
                let mut logits = vec![None; input_seqs.len()];
                let len_inputs = 1;
                let mut raw_out_logits = vec![vec![None; len_inputs]; input_seqs.len()];

                let mut exec_duration = Duration::ZERO;
                // No pre/post cache instructions are applied on this path.
                for (i, inputs) in inputs_iter.into_iter().enumerate() {
                    let InputProcessorOutput {
                        inputs,
                        seq_indices,
                    } = inputs.map_err(candle_core::Error::msg)?;

                    // Only the forward pass itself is timed here.
                    let start = Instant::now();
                    let raw_logits = self.forward_inputs(inputs, return_raw_logits)?;
                    let end = Instant::now();
                    exec_duration += end.duration_since(start);

                    // Scatter per-batch-row results back to their sequence slots.
                    for (logit_idx, seq_idx) in seq_indices.into_iter().enumerate() {
                        if let ForwardInputsResult::RawLogits { logits } = &raw_logits {
                            raw_out_logits[seq_idx][i] =
                                Some(logits.i(logit_idx)?.to_device(&Device::Cpu)?);
                        } else {
                            logits[seq_idx] = Some(raw_logits.index_bs(logit_idx)?);
                        }
                    }
                }

                // Raw-logits path: respond and return early.
                if raw_out_logits[0][0].is_some() {
                    let start = Instant::now();
                    response::send_raw_responses(
                        input_seqs,
                        raw_out_logits
                            .into_iter()
                            .map(|raw| raw.into_iter().flatten().collect::<Vec<_>>())
                            .collect(),
                    )
                    .await?;
                    let end = Instant::now();
                    exec_duration += end.duration_since(start);

                    return Ok(exec_duration);
                }

                let start = Instant::now();
                // Results are moved to the CPU only when there is more than one sequence.
                let logits_on_cpu = logits.len() > 1;
                let logits = logits
                    .into_iter()
                    .map(|l| {
                        let l = l.expect("Did not get any inputs. This is shocking.");
                        if logits_on_cpu {
                            l.to_device(&Device::Cpu)
                        } else {
                            Ok(l)
                        }
                    })
                    .collect::<candle_core::Result<Vec<_>>>()?;

                // The first result's variant decides how the whole batch is handled.
                match &logits[0] {
                    ForwardInputsResult::RawLogits { .. } => unreachable!(),
                    ForwardInputsResult::CausalGeneration { .. } => {
                        self.sample_causal_gen(
                            input_seqs,
                            logits
                                .into_iter()
                                .map(|r| {
                                    #[allow(irrefutable_let_patterns)]
                                    let ForwardInputsResult::CausalGeneration { logits } = r
                                    else {
                                        unreachable!("All results must have same type")
                                    };
                                    logits
                                })
                                .collect::<Vec<_>>(),
                            prefix_cacher,
                            disable_eos_stop,
                            rng,
                        )
                        .await?;
                    }
                    ForwardInputsResult::Image { .. } => {
                        response::send_image_responses(
                            input_seqs,
                            logits
                                .into_iter()
                                .map(|r| {
                                    #[allow(irrefutable_let_patterns)]
                                    let ForwardInputsResult::Image { images } = r
                                    else {
                                        unreachable!("All results must have same type, `Image`")
                                    };
                                    // `index_bs` produced a one-element vector per sequence.
                                    images
                                        .into_iter()
                                        .next()
                                        .expect("Must have at least 1 element.")
                                })
                                .collect::<Vec<_>>(),
                        )
                        .await?;
                    }
                    ForwardInputsResult::Speech { .. } => {
                        // Rates and channels are read by reference first so that
                        // the PCM buffers can be moved out last.
                        let rates = logits
                            .iter()
                            .map(|r| {
                                #[allow(irrefutable_let_patterns)]
                                let ForwardInputsResult::Speech { rates, .. } = r
                                else {
                                    unreachable!("All results must have same type, `Speech`")
                                };
                                assert_eq!(rates.len(), 1, "Each sequence must have 1 PCM output.");
                                *rates.first().unwrap()
                            })
                            .collect::<Vec<_>>();
                        let channels = logits
                            .iter()
                            .map(|r| {
                                #[allow(irrefutable_let_patterns)]
                                let ForwardInputsResult::Speech { channels, .. } = r
                                else {
                                    unreachable!("All results must have same type, `Speech`")
                                };
                                assert_eq!(
                                    channels.len(),
                                    1,
                                    "Each sequence must have 1 PCM output."
                                );
                                *channels.first().unwrap()
                            })
                            .collect::<Vec<_>>();
                        let pcms = logits
                            .into_iter()
                            .map(|r| {
                                #[allow(irrefutable_let_patterns)]
                                let ForwardInputsResult::Speech { pcms, .. } = r
                                else {
                                    unreachable!("All results must have same type, `Speech`")
                                };
                                assert_eq!(pcms.len(), 1, "Each sequence must have 1 PCM output.");
                                pcms.into_iter().nth(0).unwrap()
                            })
                            .collect::<Vec<_>>();
                        response::send_speech_responses(input_seqs, &pcms, &rates, &channels)
                            .await?;
                    }
                }
                let end = Instant::now();
                exec_duration += end.duration_since(start);

                Ok(exec_duration)
            }
        }
    }

    /// Samples from per-sequence `logits` for causal generation.
    ///
    /// Called by `step` with one tensor per sequence, in the same order as `seqs`.
    async fn sample_causal_gen(
        &self,
        seqs: &mut [&mut Sequence],
        logits: Vec<Tensor>,
        prefix_cacher: &mut PrefixCacheManagerV2,
        disable_eos_stop: bool,
        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
    ) -> Result<(), candle_core::Error>;

    /// Returns the category of model served by this pipeline.
    fn category(&self) -> ModelCategory;
}
787
788pub(crate) fn extract_logits(
789 logits: &Tensor,
790 context_lens: Vec<(usize, usize)>,
791) -> candle_core::Result<Tensor> {
792 let mut toks = Vec::new();
793 for (dim, (start, len)) in logits.chunk(logits.dims()[0], 0)?.iter().zip(context_lens) {
794 toks.push(dim.narrow(1, start, len)?);
795 }
796 Tensor::cat(&toks, 0)
797}
798
#[cfg(test)]
mod tests {
    use crate::MessageContent;
    use either::Either;
    use indexmap::IndexMap;
    use serde_json::Value;

    /// Builds an `IndexMap<_, Value>` from `key => string` pairs, preserving
    /// insertion order and wrapping each value in `Value::String`.
    macro_rules! hashmap {
        (@single $($x:tt)*) => (());
        (@count $($rest:expr),*) => (<[()]>::len(&[$(hashmap!(@single $rest)),*]));

        ($($key:expr => $value:expr,)+) => { hashmap!($($key => $value),+) };
        ($($key:expr => $value:expr),*) => {
            {
                let _cap = hashmap!(@count $($key),*);
                let mut _map = ::indexmap::IndexMap::with_capacity(_cap);
                $(
                    let _ = _map.insert($key, Value::String($value));
                )*
                _map
            }
        };
    }

    /// Renders `inputs` with each template and compares against the expected
    /// output at the same position.
    ///
    /// Each template tuple is `(has_system, bos, eos, unk, template)`. When
    /// `has_system` is false the first message is dropped before rendering.
    /// Templates and expected outputs are paired with `zip`, so templates
    /// beyond the number of expected outputs are not checked.
    ///
    /// Every failure is collected and printed under the index of the template
    /// that produced it, then the function panics with a summary.
    #[track_caller]
    fn test_with_inputs(
        templates: &[(bool, &str, &str, &str, &str)],
        expected_outputs: &[&str],
        inputs: Vec<IndexMap<String, MessageContent>>,
    ) {
        use crate::pipeline::chat_template::ChatTemplateValue;

        use super::chat_template::apply_chat_template_to;
        // (template index, failure description)
        let mut failed: Vec<(usize, String)> = Vec::new();
        // Only templates that have a matching expected output are actually checked.
        let n_checked = templates.len().min(expected_outputs.len());
        for (idx, ((has_system, bos, eos, unk, template), expected)) in
            templates.iter().zip(expected_outputs).enumerate()
        {
            let output = match apply_chat_template_to(
                if !has_system {
                    inputs[1..].to_vec()
                } else {
                    inputs.clone()
                },
                true,
                None,
                &ChatTemplateValue(Either::Left(template.to_string())),
                Some(bos.to_string()),
                Some(eos.to_string()),
                Some(unk.to_string()),
                Vec::new(),
            ) {
                Ok(v) => v,
                Err(e) => {
                    failed.push((idx, format!("Failed with {e}.")));
                    continue;
                }
            };
            if output != *expected {
                failed.push((
                    idx,
                    format!(
                        "Expected: `{}` \n\nGot: `{}`",
                        expected.replace('\n', "\\n"),
                        output.replace('\n', "\\n")
                    ),
                ));
            }
        }
        if !failed.is_empty() {
            for (idx, line) in &failed {
                println!("------------ Template {idx} ------------");
                println!("{line}");
            }
            println!("------------------------");
            panic!("{}/{n_checked} chat templates failed.", failed.len());
        }
    }

    /// Checks text-only chat templates against a fixed six-message conversation.
    #[test]
    fn test_chat_templates() {
        // NOTE: there are six templates but five expected outputs; the last
        // template (list-style content) is exercised by `test_image_chat_templates`.
        let templates = [
            (true, "<s>", "</s>", "<unk>", "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"),
            (false, "<s>", "</s>", "<unk>", "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"),
            (true, "<s>", "</s>", "<unk>", "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}"),
            (false, "<s>", "</s>", "<unk>", "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"),
            (false, "<bos>", "<eos>", "<unk>", "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}"),
            (true, "<s>", "</s>", "<unk>", "{% for message in messages %}{{message['role'].capitalize()}}{% if message['content'][0]['type'] == 'image' %}{{':'}}{% else %}{{': '}}{% endif %}{% for line in message['content'] %}{% if line['type'] == 'text' %}{{line['text']}}{% elif line['type'] == 'image' %}{{ '<image>' }}{% endif %}{% endfor %}<end_of_utterance>\n{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"),
        ];
        let expected_outputs = [
            "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi there<|im_end|>\n<|im_start|>user\nWho are you<|im_end|>\n<|im_start|>assistant\n I am an assistant <|im_end|>\n<|im_start|>user\nAnother question<|im_end|>\n<|im_start|>assistant\n",
            "<s>[INST] Hello [/INST]Hi there</s> [INST] Who are you [/INST] I am an assistant </s> [INST] Another question [/INST]",
            "<s>[INST] <<SYS>>\nYou are a helpful assistant\n<</SYS>>\n\nHello [/INST] Hi there </s><s>[INST] Who are you [/INST] I am an assistant </s><s>[INST] Another question [/INST]",
            "<s>[INST] Hello [/INST]Hi there</s>[INST] Who are you [/INST] I am an assistant </s>[INST] Another question [/INST]",
            "<bos><start_of_turn>user\nHello<end_of_turn>\n<start_of_turn>model\nHi there<end_of_turn>\n<start_of_turn>user\nWho are you<end_of_turn>\n<start_of_turn>model\nI am an assistant<end_of_turn>\n<start_of_turn>user\nAnother question<end_of_turn>\n<start_of_turn>model\n",
        ];
        let messages = [
            ["system", "You are a helpful assistant"],
            ["user", "Hello"],
            ["assistant", "Hi there"],
            ["user", "Who are you"],
            ["assistant", " I am an assistant "],
            ["user", "Another question"],
        ];
        let mut inputs = Vec::new();
        for [role, content] in messages {
            let mut message: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
                IndexMap::new();
            message.insert("role".to_string(), Either::Left(role.to_string()));
            message.insert("content".to_string(), Either::Left(content.to_string()));
            inputs.push(message);
        }
        test_with_inputs(&templates, &expected_outputs, inputs);
    }

    /// Checks a template whose message content is a list of typed parts
    /// (`image` / `text`) rather than a plain string.
    #[test]
    fn test_image_chat_templates() {
        let templates = [
            (true, "<s>", "</s>", "<unk>", "{% for message in messages %}{{message['role'].capitalize()}}{% if message['content'][0]['type'] == 'image' %}{{':'}}{% else %}{{': '}}{% endif %}{% for line in message['content'] %}{% if line['type'] == 'text' %}{{line['text']}}{% elif line['type'] == 'image' %}{{ '<image>' }}{% endif %}{% endfor %}<end_of_utterance>\n{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"),
        ];
        let expected_outputs = [
            "System: You are a helpful assistant<end_of_utterance>\nUser:<image>Hello, please describe the above.<end_of_utterance>\nAssistant: Hi there<end_of_utterance>\nUser:<image>This is me, who are you<end_of_utterance>\nAssistant: I am an assistant <end_of_utterance>\nUser:<image>Another question, what is this?<end_of_utterance>\nAssistant:",
        ];

        let mut inputs = Vec::new();

        let mut message: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
            IndexMap::new();
        message.insert("role".to_string(), Either::Left("system".to_string()));
        message.insert(
            "content".to_string(),
            Either::Right(vec![hashmap! {
                "type".to_string() => "text".to_string(),
                "text".to_string() => "You are a helpful assistant".to_string()
            }]),
        );
        inputs.push(message);

        let mut message: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
            IndexMap::new();
        message.insert("role".to_string(), Either::Left("user".to_string()));
        message.insert(
            "content".to_string(),
            Either::Right(vec![
                hashmap! {
                    "type".to_string() => "image".to_string()
                },
                hashmap! {
                    "type".to_string() => "text".to_string(),
                    "text".to_string() => "Hello, please describe the above.".to_string()
                },
            ]),
        );
        inputs.push(message);

        let mut message: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
            IndexMap::new();
        message.insert("role".to_string(), Either::Left("assistant".to_string()));
        message.insert(
            "content".to_string(),
            Either::Right(vec![hashmap! {
                "type".to_string() => "text".to_string(),
                "text".to_string() => "Hi there".to_string()
            }]),
        );
        inputs.push(message);

        let mut message: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
            IndexMap::new();
        message.insert("role".to_string(), Either::Left("user".to_string()));
        message.insert(
            "content".to_string(),
            Either::Right(vec![
                hashmap! {
                    "type".to_string() => "image".to_string()
                },
                hashmap! {
                    "type".to_string() => "text".to_string(),
                    "text".to_string() => "This is me, who are you".to_string()
                },
            ]),
        );
        inputs.push(message);

        let mut message: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
            IndexMap::new();
        message.insert("role".to_string(), Either::Left("assistant".to_string()));
        message.insert(
            "content".to_string(),
            Either::Right(vec![hashmap! {
                "type".to_string() => "text".to_string(),
                "text".to_string() => " I am an assistant ".to_string()
            }]),
        );
        inputs.push(message);

        let mut message: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
            IndexMap::new();
        message.insert("role".to_string(), Either::Left("user".to_string()));
        message.insert(
            "content".to_string(),
            Either::Right(vec![
                hashmap! {
                    "type".to_string() => "image".to_string()
                },
                hashmap! {
                    "type".to_string() => "text".to_string(),
                    "text".to_string() => "Another question, what is this?".to_string()
                },
            ]),
        );
        inputs.push(message);

        test_with_inputs(&templates, &expected_outputs, inputs);
    }
}