mod amoe;
mod auto;
pub mod chat_template;
mod diffusion;
mod ggml;
mod gguf;
mod inputs_processor;
mod isq;
pub(crate) mod llg;
mod loaders;
mod macros;
mod normal;
mod paths;
mod processing;
mod response;
mod sampling;
mod speculative;
mod speech;
mod vision;

pub use super::diffusion_models::DiffusionGenerationParams;
use crate::amoe::{AnyMoeConfig, AnyMoeExpertType, AnyMoeTrainingInputs, AnyMoeTrainingResult};
use crate::device_map::DeviceMapper;
use crate::paged_attention::{CacheConfig, CacheEngine, ModelConfigLike};
use crate::prefix_cacher::PrefixCacheManagerV2;
pub use amoe::{AnyMoeLoader, AnyMoePipeline};
pub use auto::{AutoLoader, AutoLoaderBuilder};
use chat_template::ChatTemplate;
pub use diffusion::{DiffusionLoader, DiffusionLoaderBuilder};
pub use ggml::{GGMLLoader, GGMLLoaderBuilder, GGMLSpecificConfig};
pub use gguf::{GGUFLoader, GGUFLoaderBuilder, GGUFSpecificConfig};
use image::DynamicImage;
pub use inputs_processor::InputProcessorOutput;
pub(crate) use isq::IsqModelLoader;
pub use isq::{parse_isq_value, IsqModel, IsqOrganization, UQFF_MULTI_FILE_DELIMITER};
use llguidance::toktrie::TokEnv;
pub use loaders::{
    AdapterKind, AutoDeviceMapParams, AutoNormalLoader, AutoVisionLoader, DeepSeekV2Loader,
    DeepSeekV3Loader, DeviceMappedModelLoader, DiffusionLoaderType, DiffusionModel,
    DiffusionModelLoader, FluxLoader, GLM4Loader, Gemma2Loader, Gemma3Loader, GemmaLoader,
    Idefics2Loader, Idefics3Loader, LLaVALoader, LLaVANextLoader, LlamaLoader, Loader,
    LocalModelPaths, MiniCpmOLoader, Mistral3Loader, MistralLoader, MixtralLoader, ModelKind,
    ModelPaths, NormalLoaderType, NormalLoadingMetadata, NormalModel, NormalModelLoader,
    Phi2Loader, Phi3Loader, Phi3VLoader, Phi3_5MoELoader, Phi4MMLoader, PrettyName,
    QuantizationKind, Qwen2Loader, Qwen2VLLoader, Qwen2_5VLLoader, Qwen3Loader, Qwen3MoELoader,
    Starcoder2Loader, TokenSource, VLlama4Loader, VLlamaLoader, VisionLoaderType, VisionModel,
    VisionModelLoader,
};
use mistralrs_quant::IsqType;
pub use normal::{NormalLoader, NormalLoaderBuilder, NormalSpecificConfig};
pub(crate) use paths::{get_chat_template, get_model_paths, get_xlora_paths};
pub use paths::{AdapterPaths, LoraAdapterPaths};
pub(crate) use processing::{
    apply_chat_template, BasicProcessor, MessagesAction, Processor, ProcessorCreator,
};
use rand_isaac::Isaac64Rng;
pub use speculative::{SpeculativeConfig, SpeculativeLoader, SpeculativePipeline};
pub use speech::{SpeechLoader, SpeechPipeline};
use std::any::Any;
use std::collections::HashMap;
use std::fmt::Debug;
use std::num::NonZeroUsize;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokenizers::Tokenizer;
pub use vision::{VisionLoader, VisionLoaderBuilder, VisionSpecificConfig};

use anyhow::Result;
use candle_core::{DType, Device, IndexOp, Tensor, Var};

use crate::sequence::Sequence;

pub use self::inputs_processor::{
    text_models_inputs_processor, InputsProcessor, InputsProcessorType,
};
use self::text_models_inputs_processor::PagedAttentionMeta;
pub use crate::kv_cache::{
    Cache, CacheManager, EitherCache, KvCache, LayerCaches, NormalCache, NormalCacheType,
};

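/// A single input or output modality (text, audio, or vision) that a model can accept or produce.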
#[derive(Clone, PartialEq, Eq)]
pub enum SupportedModality {
    Text,
    Audio,
    Vision,
}

impl Debug for SupportedModality {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Text => write!(f, "📝 Text"),
            Self::Audio => write!(f, "🔊 Audio"),
            Self::Vision => write!(f, "🖼️ Vision"),
        }
    }
}

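/// The input and output modalities supported by a loaded model.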
#[derive(Debug, Clone)]
pub struct Modalities {
    pub input: Vec<SupportedModality>,
    pub output: Vec<SupportedModality>,
}

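/// Metadata shared by all pipelines: sequence limits, cache configuration, EOS tokens,
/// activation dtype, and the modalities the underlying model supports.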
pub struct GeneralMetadata {
    pub max_seq_len: usize,
    pub llg_factory: Option<Arc<llguidance::ParserFactory>>,
    pub no_kv_cache: bool,
    pub no_prefix_cache: bool,
    pub num_hidden_layers: usize,
    pub eos_tok: Vec<u32>,
    pub kind: ModelKind,
    pub is_xlora: bool,
    pub activation_dtype: DType,
    pub sliding_window: Option<usize>,
    pub cache_config: Option<CacheConfig>,
    pub cache_engine: Option<CacheEngine>,
    pub prompt_chunksize: Option<NonZeroUsize>,
    pub model_metadata: Option<Arc<dyn ModelConfigLike + Send + Sync>>,
    pub modalities: Modalities,
}

impl GeneralMetadata {
    pub fn tok_env(&self) -> Option<TokEnv> {
        self.llg_factory.as_ref().map(|f| f.tok_env().clone())
    }
}

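/// Instruction applied to the KV cache around a `step` call: clone caches into the sequences,
/// clone them back out, reset them (optionally reloading the preallocated cache), or do nothing.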
pub enum CacheInstruction {
    In,
    Out,
    Reset {
        load_preallocated_cache: bool,
        reset_non_granular: bool,
    },
    Nothing,
}

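/// Pre-processing hooks for a pipeline: the input `Processor`, the chat template, and any
/// processor-specific configuration.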
pub trait PreProcessingMixin: MetadataMixin {
    fn get_processor(&self) -> Arc<dyn Processor> {
        Arc::new(BasicProcessor)
    }
    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>>;
    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>>;
}

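/// Support for re-applying in-situ quantization (ISQ) to an already loaded model.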
pub trait IsqPipelineMixin {
    fn re_isq_model(&mut self, dtype: IsqType) -> Result<()>;
}

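/// KV-cache management for a pipeline: cloning caches in and out of sequences and resetting
/// them to an empty state.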
pub trait CacheManagerMixin {
    fn clone_in_cache(&self, seqs: &mut [&mut Sequence]);
    fn clone_out_cache(&self, seqs: &mut [&mut Sequence]);
    fn set_none_cache(
        &self,
        seqs: &mut [&mut Sequence],
        reset_non_granular: bool,
        modify_draft_cache: bool,
        load_preallocated_cache: bool,
    );
    fn cache(&self) -> &EitherCache;
    fn do_preallocated_cache(&self) -> bool {
        matches!(self.cache(), EitherCache::Normal(_))
    }
}

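/// Basic accessors every pipeline exposes: device, tokenizer, name, metadata, and device mapper.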
pub trait MetadataMixin {
    fn device(&self) -> Device;
    fn tokenizer(&self) -> Option<Arc<Tokenizer>>;
    fn name(&self) -> String;
    fn reset_non_granular_state(&self);
    fn get_metadata(&self) -> Arc<GeneralMetadata>;
    fn device_mapper(&self) -> Option<&dyn DeviceMapper>;
}

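/// AnyMoE support for a pipeline. The default methods are `unreachable!()`; pipelines that
/// support AnyMoE training override them and return `true` from `amoe_supported`.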
pub trait AnyMoePipelineMixin {
    fn amoe_layer_vars(&self) -> Vec<Vec<Var>> {
        unreachable!()
    }
    fn amoe_finish_training(&mut self, _gate_model_id: Option<String>) -> candle_core::Result<()> {
        unreachable!()
    }
    fn amoe_base_model_trainable_params(&self) -> usize {
        unreachable!()
    }
    fn amoe_supported(&self) -> bool {
        false
    }
    fn amoe_take_cached_gating_outputs(&mut self) -> Vec<Tensor> {
        unreachable!()
    }
    #[allow(clippy::too_many_arguments)]
    fn amoe_create_layers(
        &mut self,
        _model_ids: Vec<String>,
        _token: &TokenSource,
        _revision: Option<String>,
        _match_regex: &str,
        _config: AnyMoeConfig,
        _dtype: DType,
        _dev: &Device,
        (_prefix, _mlp): (String, String),
        _layers: Vec<usize>,
        _expert_type: AnyMoeExpertType,
        _silent: bool,
        _gate_model_id: Option<String>,
    ) -> candle_core::Result<()> {
        unreachable!()
    }
    #[allow(clippy::too_many_arguments)]
    fn amoe_pre_train(
        &self,
        _inputs: AnyMoeTrainingInputs,
        (_prefix, _mlp): (String, String),
        _model_ids: Vec<String>,
        _token: TokenSource,
        _revision: Option<String>,
        _layers: Vec<usize>,
        _silent: bool,
    ) -> Result<Option<AnyMoeTrainingResult>, candle_core::Error> {
        unreachable!()
    }
}

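/// The broad category of a loaded model (text, vision, diffusion, audio, or speech). The
/// `Vision` variant carries a `MultimodalPromptPrefixer` used to add multimodal markers to prompts.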
#[derive(Clone)]
pub enum ModelCategory {
    Text,
    Vision {
        prefixer: Arc<dyn MultimodalPromptPrefixer>,
    },
    Diffusion,
    Audio,
    Speech,
}

impl std::fmt::Debug for ModelCategory {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ModelCategory::Text => write!(f, "ModelCategory::Text"),
            ModelCategory::Vision { .. } => write!(f, "ModelCategory::Vision {{ prefixer: .. }}"),
            ModelCategory::Diffusion => write!(f, "ModelCategory::Diffusion"),
            ModelCategory::Audio => write!(f, "ModelCategory::Audio"),
            ModelCategory::Speech => write!(f, "ModelCategory::Speech"),
        }
    }
}

impl PartialEq for ModelCategory {
    fn eq(&self, other: &Self) -> bool {
        match (self, other) {
            (Self::Text, Self::Text) => true,
            (Self::Vision { .. }, Self::Vision { .. }) => true,
            (Self::Audio, Self::Audio) => true,
            (Self::Speech, Self::Speech) => true,
            (Self::Diffusion, Self::Diffusion) => true,
            (
                Self::Text | Self::Vision { .. } | Self::Diffusion | Self::Audio | Self::Speech,
                _,
            ) => false,
        }
    }
}

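/// Inserts model-specific multimodal markers into a prompt. The default implementations return
/// the prompt unchanged.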
pub trait MultimodalPromptPrefixer: Send + Sync {
    fn prefix_image(&self, _image_indices: Vec<usize>, prompt: &str) -> String {
        prompt.to_string()
    }
    fn prefix_audio(&self, _audio_indexes: Vec<usize>, prompt: &str) -> String {
        prompt.to_string()
    }
}

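/// Per-step cache backend information: either explicit pre/post cache instructions for the
/// default backend, or PagedAttention metadata together with the block swap/copy maps to apply
/// via the cache engine's scheduler ops.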
pub enum CacheBackendMetadata {
    DefaultInstructions {
        pre_op: CacheInstruction,
        post_op: CacheInstruction,
    },
    PagedAttention {
        metadata: PagedAttentionMeta,
        blocks_to_swap_in: HashMap<usize, usize>,
        blocks_to_swap_out: HashMap<usize, usize>,
        blocks_to_copy: HashMap<usize, Vec<usize>>,
    },
}

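/// The output of a single `forward_inputs` call: raw or causal-LM logits, generated images,
/// or generated speech PCM data.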
#[derive(Clone, Debug)]
pub enum ForwardInputsResult {
    RawLogits {
        logits: Tensor,
    },
    CausalGeneration {
        logits: Tensor,
    },
    Image {
        images: Vec<DynamicImage>,
    },
    Speech {
        pcms: Vec<Arc<Vec<f32>>>,
        rates: Vec<usize>,
        channels: Vec<usize>,
    },
}

impl ForwardInputsResult {
    fn index_bs(&self, bs_idx: usize) -> candle_core::Result<Self> {
        match self {
            Self::CausalGeneration { logits } => Ok(Self::CausalGeneration {
                logits: logits.i(bs_idx)?,
            }),
            Self::RawLogits { logits } => Ok(Self::RawLogits {
                logits: logits.i(bs_idx)?,
            }),
            Self::Image { images } => Ok(Self::Image {
                images: vec![images[bs_idx].clone()],
            }),
            Self::Speech {
                pcms,
                rates,
                channels,
            } => Ok(Self::Speech {
                pcms: vec![pcms[bs_idx].clone()],
                rates: vec![rates[bs_idx]],
                channels: vec![channels[bs_idx]],
            }),
        }
    }

    fn to_device(&self, device: &Device) -> candle_core::Result<Self> {
        match self {
            Self::CausalGeneration { logits } => Ok(Self::CausalGeneration {
                logits: logits.to_device(device)?,
            }),
            Self::RawLogits { logits } => Ok(Self::RawLogits {
                logits: logits.to_device(device)?,
            }),
            Self::Image { .. } => Ok(self.clone()),
            Self::Speech { .. } => Ok(self.clone()),
        }
    }
}

#[derive(serde::Serialize, serde::Deserialize)]
pub(crate) struct FileListCache {
    files: Vec<String>,
}

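/// A loaded model together with everything needed to run it: input processing, KV-cache
/// management, metadata, and the `step` driver that runs a forward pass and dispatches the
/// results (sampling, image responses, or speech responses) for a batch of sequences.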
#[async_trait::async_trait]
pub trait Pipeline:
    Send
    + Sync
    + PreProcessingMixin
    + IsqPipelineMixin
    + CacheManagerMixin
    + MetadataMixin
    + AnyMoePipelineMixin
{
    fn forward_inputs(
        &mut self,
        inputs: Box<dyn Any>,
        return_raw_logits: bool,
    ) -> Result<ForwardInputsResult, candle_core::Error>;

    #[allow(clippy::too_many_arguments)]
    async fn step(
        &mut self,
        input_seqs: &mut [&mut Sequence],
        is_prompt: bool,
        return_raw_logits: bool,
        prefix_cacher: &mut PrefixCacheManagerV2,
        disable_eos_stop: bool,
        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
        backend_metadata: CacheBackendMetadata,
    ) -> Result<Duration, candle_core::Error> {
        match backend_metadata {
            CacheBackendMetadata::DefaultInstructions { pre_op, post_op } => {
                let inputs_iter = self.get_processor().inputs_processor().process_inputs(
                    self.tokenizer(),
                    input_seqs,
                    is_prompt,
                    self.get_metadata().is_xlora,
                    &self.device(),
                    self.get_metadata().no_kv_cache,
                    None,
                    return_raw_logits,
                    self.get_input_processor_config(),
                    None,
                    self.get_metadata().prompt_chunksize,
                    self.device_mapper(),
                );

                let mut logits = vec![None; input_seqs.len()];
                let prompt_chunksize = self
                    .get_metadata()
                    .prompt_chunksize
                    .map(NonZeroUsize::get)
                    .unwrap_or(1);
                let len_inputs = input_seqs
                    .iter()
                    .map(|seq| seq.get_toks().len().div_ceil(prompt_chunksize))
                    .max()
                    .unwrap();
                let mut raw_out_logits = vec![vec![None; len_inputs]; input_seqs.len()];

                let mut exec_duration = Duration::ZERO;
                for (i, inputs) in inputs_iter.into_iter().enumerate() {
                    let InputProcessorOutput {
                        inputs,
                        seq_indices,
                    } = inputs.map_err(candle_core::Error::msg)?;
                    if i == 0 {
                        match pre_op {
                            CacheInstruction::In => self.clone_in_cache(input_seqs),
                            CacheInstruction::Nothing => (),
                            CacheInstruction::Reset {
                                load_preallocated_cache,
                                reset_non_granular,
                            } => self.set_none_cache(
                                input_seqs,
                                reset_non_granular,
                                false,
                                load_preallocated_cache,
                            ),
                            _ => unreachable!("Unreachable PRE cache op."),
                        }
                    }

                    let start = Instant::now();
                    let raw_logits = self.forward_inputs(inputs, return_raw_logits)?;
                    let end = Instant::now();
                    exec_duration += end.duration_since(start);

                    for (logit_idx, seq_idx) in seq_indices.into_iter().enumerate() {
                        if let ForwardInputsResult::RawLogits { logits } = &raw_logits {
                            raw_out_logits[seq_idx][i] =
                                Some(logits.i(logit_idx)?.to_device(&Device::Cpu)?);
                        } else {
                            logits[seq_idx] = Some(raw_logits.index_bs(logit_idx)?);
                        }
                    }
                }

                match post_op {
                    CacheInstruction::Out => self.clone_out_cache(input_seqs),
                    CacheInstruction::Nothing => (),
                    CacheInstruction::Reset {
                        load_preallocated_cache,
                        reset_non_granular,
                    } => self.set_none_cache(
                        input_seqs,
                        reset_non_granular,
                        false,
                        load_preallocated_cache,
                    ),
                    _ => unreachable!("Unreachable POST cache op."),
                }

                if raw_out_logits[0][0].is_some() {
                    let start = Instant::now();
                    response::send_raw_responses(
                        input_seqs,
                        raw_out_logits
                            .into_iter()
                            .map(|raw| raw.into_iter().flatten().collect::<Vec<_>>())
                            .collect(),
                    )
                    .await?;
                    let end = Instant::now();
                    exec_duration += end.duration_since(start);

                    return Ok(exec_duration);
                }

                let start = Instant::now();
                let logits_on_cpu = logits.len() > 1;
                let logits = logits
                    .into_iter()
                    .map(|l| {
                        let l = l.expect("Did not get any inputs. This is shocking.");
                        if logits_on_cpu {
                            l.to_device(&Device::Cpu)
                        } else {
                            Ok(l)
                        }
                    })
                    .collect::<candle_core::Result<Vec<_>>>()?;

                match &logits[0] {
                    ForwardInputsResult::RawLogits { .. } => unreachable!(),
                    ForwardInputsResult::CausalGeneration { .. } => {
                        self.sample_causal_gen(
                            input_seqs,
                            logits
                                .into_iter()
                                .map(|r| {
                                    #[allow(irrefutable_let_patterns)]
                                    let ForwardInputsResult::CausalGeneration { logits } = r
                                    else {
                                        unreachable!(
                                            "All results must have same type, `CausalGeneration`"
                                        )
                                    };
                                    logits
                                })
                                .collect::<Vec<_>>(),
                            prefix_cacher,
                            disable_eos_stop,
                            rng,
                        )
                        .await?;
                    }
                    ForwardInputsResult::Image { .. } => {
                        response::send_image_responses(
                            input_seqs,
                            logits
                                .into_iter()
                                .map(|r| {
                                    #[allow(irrefutable_let_patterns)]
                                    let ForwardInputsResult::Image { images } = r
                                    else {
                                        unreachable!("All results must have same type, `Image`")
                                    };
                                    images
                                        .into_iter()
                                        .next()
                                        .expect("Must have at least 1 element.")
                                })
                                .collect::<Vec<_>>(),
                        )
                        .await?;
                    }
                    ForwardInputsResult::Speech { .. } => {
                        let rates = logits
                            .iter()
                            .map(|r| {
                                #[allow(irrefutable_let_patterns)]
                                let ForwardInputsResult::Speech { rates, .. } = r
                                else {
                                    unreachable!("All results must have same type, `Speech`")
                                };
                                assert_eq!(rates.len(), 1, "Each sequence must have 1 PCM output.");
                                *rates.first().unwrap()
                            })
                            .collect::<Vec<_>>();
                        let channels = logits
                            .iter()
                            .map(|r| {
                                #[allow(irrefutable_let_patterns)]
                                let ForwardInputsResult::Speech { channels, .. } = r
                                else {
                                    unreachable!("All results must have same type, `Speech`")
                                };
                                assert_eq!(
                                    channels.len(),
                                    1,
                                    "Each sequence must have 1 PCM output."
                                );
                                *channels.first().unwrap()
                            })
                            .collect::<Vec<_>>();
                        let pcms = logits
                            .into_iter()
                            .map(|r| {
                                #[allow(irrefutable_let_patterns)]
                                let ForwardInputsResult::Speech { pcms, .. } = r
                                else {
                                    unreachable!("All results must have same type, `Speech`")
                                };
                                assert_eq!(pcms.len(), 1, "Each sequence must have 1 PCM output.");
                                pcms.into_iter().nth(0).unwrap()
                            })
                            .collect::<Vec<_>>();
                        response::send_speech_responses(input_seqs, &pcms, &rates, &channels)
                            .await?;
                    }
                }
                let end = Instant::now();
                exec_duration += end.duration_since(start);

                Ok(exec_duration)
            }
            CacheBackendMetadata::PagedAttention {
                metadata,
                blocks_to_copy,
                blocks_to_swap_in,
                blocks_to_swap_out,
            } => {
                self.get_metadata()
                    .cache_engine
                    .as_ref()
                    .expect("PagedAttention must have cache engines.")
                    .execute_scheduler_ops(
                        &blocks_to_swap_in,
                        &blocks_to_swap_out,
                        &blocks_to_copy,
                    )?;

                let inputs_iter = self.get_processor().inputs_processor().process_inputs(
                    self.tokenizer(),
                    input_seqs,
                    is_prompt,
                    self.get_metadata().is_xlora,
                    &self.device(),
                    self.get_metadata().no_kv_cache,
                    None,
                    return_raw_logits,
                    self.get_input_processor_config(),
                    Some(metadata),
                    self.get_metadata().prompt_chunksize,
                    self.device_mapper(),
                );

                let mut logits = vec![None; input_seqs.len()];
                let prompt_chunksize = self
                    .get_metadata()
                    .prompt_chunksize
                    .map(NonZeroUsize::get)
                    .unwrap_or(1);
                let len_inputs = input_seqs
                    .iter()
                    .map(|seq| seq.get_toks().len().div_ceil(prompt_chunksize))
                    .max()
                    .unwrap();
                let mut raw_out_logits = vec![vec![None; len_inputs]; input_seqs.len()];

                let mut exec_duration = Duration::ZERO;
                for (i, inputs) in inputs_iter.into_iter().enumerate() {
                    let InputProcessorOutput {
                        inputs,
                        seq_indices,
                    } = inputs.map_err(candle_core::Error::msg)?;

                    let start = Instant::now();
                    let raw_logits = self.forward_inputs(inputs, return_raw_logits)?;
                    let end = Instant::now();
                    exec_duration += end.duration_since(start);

                    for (logit_idx, seq_idx) in seq_indices.into_iter().enumerate() {
                        if let ForwardInputsResult::RawLogits { logits } = &raw_logits {
                            raw_out_logits[seq_idx][i] =
                                Some(logits.i(logit_idx)?.to_device(&Device::Cpu)?);
                        } else {
                            logits[seq_idx] = Some(raw_logits.index_bs(logit_idx)?);
                        }
                    }
                }

                if raw_out_logits[0][0].is_some() {
                    let start = Instant::now();
                    response::send_raw_responses(
                        input_seqs,
                        raw_out_logits
                            .into_iter()
                            .map(|raw| raw.into_iter().flatten().collect::<Vec<_>>())
                            .collect(),
                    )
                    .await?;
                    let end = Instant::now();
                    exec_duration += end.duration_since(start);

                    return Ok(exec_duration);
                }

                let start = Instant::now();
                let logits_on_cpu = logits.len() > 1;
                let logits = logits
                    .into_iter()
                    .map(|l| {
                        let l = l.expect("Did not get any inputs. This is shocking.");
                        if logits_on_cpu {
                            l.to_device(&Device::Cpu)
                        } else {
                            Ok(l)
                        }
                    })
                    .collect::<candle_core::Result<Vec<_>>>()?;

                match &logits[0] {
                    ForwardInputsResult::RawLogits { .. } => unreachable!(),
                    ForwardInputsResult::CausalGeneration { .. } => {
                        self.sample_causal_gen(
                            input_seqs,
                            logits
                                .into_iter()
                                .map(|r| {
                                    #[allow(irrefutable_let_patterns)]
                                    let ForwardInputsResult::CausalGeneration { logits } = r
                                    else {
                                        unreachable!("All results must have same type")
                                    };
                                    logits
                                })
                                .collect::<Vec<_>>(),
                            prefix_cacher,
                            disable_eos_stop,
                            rng,
                        )
                        .await?;
                    }
                    ForwardInputsResult::Image { .. } => {
                        response::send_image_responses(
                            input_seqs,
                            logits
                                .into_iter()
                                .map(|r| {
                                    #[allow(irrefutable_let_patterns)]
                                    let ForwardInputsResult::Image { images } = r
                                    else {
                                        unreachable!("All results must have same type, `Image`")
                                    };
                                    images
                                        .into_iter()
                                        .next()
                                        .expect("Must have at least 1 element.")
                                })
                                .collect::<Vec<_>>(),
                        )
                        .await?;
                    }
                    ForwardInputsResult::Speech { .. } => {
                        let rates = logits
                            .iter()
                            .map(|r| {
                                #[allow(irrefutable_let_patterns)]
                                let ForwardInputsResult::Speech { rates, .. } = r
                                else {
                                    unreachable!("All results must have same type, `Speech`")
                                };
                                assert_eq!(rates.len(), 1, "Each sequence must have 1 PCM output.");
                                *rates.first().unwrap()
                            })
                            .collect::<Vec<_>>();
                        let channels = logits
                            .iter()
                            .map(|r| {
                                #[allow(irrefutable_let_patterns)]
                                let ForwardInputsResult::Speech { channels, .. } = r
                                else {
                                    unreachable!("All results must have same type, `Speech`")
                                };
                                assert_eq!(
                                    channels.len(),
                                    1,
                                    "Each sequence must have 1 PCM output."
                                );
                                *channels.first().unwrap()
                            })
                            .collect::<Vec<_>>();
                        let pcms = logits
                            .into_iter()
                            .map(|r| {
                                #[allow(irrefutable_let_patterns)]
                                let ForwardInputsResult::Speech { pcms, .. } = r
                                else {
                                    unreachable!("All results must have same type, `Speech`")
                                };
                                assert_eq!(pcms.len(), 1, "Each sequence must have 1 PCM output.");
                                pcms.into_iter().nth(0).unwrap()
                            })
                            .collect::<Vec<_>>();
                        response::send_speech_responses(input_seqs, &pcms, &rates, &channels)
                            .await?;
                    }
                }
                let end = Instant::now();
                exec_duration += end.duration_since(start);

                Ok(exec_duration)
            }
        }
    }

    async fn sample_causal_gen(
        &self,
        seqs: &mut [&mut Sequence],
        logits: Vec<Tensor>,
        prefix_cacher: &mut PrefixCacheManagerV2,
        disable_eos_stop: bool,
        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
    ) -> Result<(), candle_core::Error>;

    fn category(&self) -> ModelCategory;
}

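/// Narrows each batch row of `logits` to its `(start, len)` context window along the sequence
/// dimension and concatenates the results along the batch dimension.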
pub(crate) fn extract_logits(
    logits: &Tensor,
    context_lens: Vec<(usize, usize)>,
) -> candle_core::Result<Tensor> {
    let mut toks = Vec::new();
    for (dim, (start, len)) in logits.chunk(logits.dims()[0], 0)?.iter().zip(context_lens) {
        toks.push(dim.narrow(1, start, len)?);
    }
    Tensor::cat(&toks, 0)
}

#[cfg(test)]
mod tests {
    use crate::MessageContent;
    use either::Either;
    use indexmap::IndexMap;
    use serde_json::Value;

    macro_rules! hashmap {
        (@single $($x:tt)*) => (());
        (@count $($rest:expr),*) => (<[()]>::len(&[$(hashmap!(@single $rest)),*]));

        ($($key:expr => $value:expr,)+) => { hashmap!($($key => $value),+) };
        ($($key:expr => $value:expr),*) => {
            {
                let _cap = hashmap!(@count $($key),*);
                let mut _map = ::indexmap::IndexMap::with_capacity(_cap);
                $(
                    let _ = _map.insert($key, Value::String($value));
                )*
                _map
            }
        };
    }

    #[cfg(test)]
    #[track_caller]
    fn test_with_inputs(
        templates: &[(bool, &str, &str, &str, &str)],
        expected_outputs: &[&str],
        inputs: Vec<IndexMap<String, MessageContent>>,
    ) {
        use crate::pipeline::chat_template::ChatTemplateValue;

        use super::chat_template::apply_chat_template_to;
        let mut failed = Vec::new();
        let n_templates = templates.len();
        for ((has_system, bos, eos, unk, template), expected) in
            templates.iter().zip(expected_outputs)
        {
            let output = match apply_chat_template_to(
                if !has_system {
                    inputs[1..].to_vec()
                } else {
                    inputs.clone()
                },
                true,
                None,
                &ChatTemplateValue(Either::Left(template.to_string())),
                Some(bos.to_string()),
                Some(eos.to_string()),
                Some(unk.to_string()),
                Vec::new(),
            ) {
                Ok(v) => v,
                Err(e) => {
                    failed.push(format!("Failed with {e}."));
                    continue;
                }
            };
            if output != *expected {
                failed.push(format!(
                    "Expected: `{}` \n\nGot: `{}`",
                    expected.replace('\n', "\\n"),
                    output.replace('\n', "\\n")
                ));
            }
        }
        if !failed.is_empty() {
            for (i, line) in failed.iter().enumerate() {
                println!("------------ Template {i} ------------");
                println!("{line}");
            }
            println!("------------------------");
            panic!("{}/{n_templates} chat templates failed.", failed.len());
        }
    }

    #[test]
    fn test_chat_templates() {
        let templates = [
            (true, "<s>", "</s>", "<unk>", "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"),
            (false, "<s>", "</s>", "<unk>", "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"),
            (true, "<s>", "</s>", "<unk>", "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}"),
            (false, "<s>", "</s>", "<unk>", "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"),
            (false, "<bos>", "<eos>", "<unk>", "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}"),
            (true, "<s>", "</s>", "<unk>", "{% for message in messages %}{{message['role'].capitalize()}}{% if message['content'][0]['type'] == 'image' %}{{':'}}{% else %}{{': '}}{% endif %}{% for line in message['content'] %}{% if line['type'] == 'text' %}{{line['text']}}{% elif line['type'] == 'image' %}{{ '<image>' }}{% endif %}{% endfor %}<end_of_utterance>\n{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"),
        ];
        let expected_outputs = [
            "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi there<|im_end|>\n<|im_start|>user\nWho are you<|im_end|>\n<|im_start|>assistant\n I am an assistant <|im_end|>\n<|im_start|>user\nAnother question<|im_end|>\n<|im_start|>assistant\n",
            "<s>[INST] Hello [/INST]Hi there</s> [INST] Who are you [/INST] I am an assistant </s> [INST] Another question [/INST]",
            "<s>[INST] <<SYS>>\nYou are a helpful assistant\n<</SYS>>\n\nHello [/INST] Hi there </s><s>[INST] Who are you [/INST] I am an assistant </s><s>[INST] Another question [/INST]",
            "<s>[INST] Hello [/INST]Hi there</s>[INST] Who are you [/INST] I am an assistant </s>[INST] Another question [/INST]",
            "<bos><start_of_turn>user\nHello<end_of_turn>\n<start_of_turn>model\nHi there<end_of_turn>\n<start_of_turn>user\nWho are you<end_of_turn>\n<start_of_turn>model\nI am an assistant<end_of_turn>\n<start_of_turn>user\nAnother question<end_of_turn>\n<start_of_turn>model\n",
        ];
        let messages = [
            ["system", "You are a helpful assistant"],
            ["user", "Hello"],
            ["assistant", "Hi there"],
            ["user", "Who are you"],
            ["assistant", " I am an assistant "],
            ["user", "Another question"],
        ];
        let mut inputs = Vec::new();
        for [role, content] in messages {
            let mut message: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
                IndexMap::new();
            message.insert("role".to_string(), Either::Left(role.to_string()));
            message.insert("content".to_string(), Either::Left(content.to_string()));
            inputs.push(message);
        }
        test_with_inputs(&templates, &expected_outputs, inputs);
    }

    #[test]
    fn test_image_chat_templates() {
        let templates = [
            (true, "<s>", "</s>", "<unk>", "{% for message in messages %}{{message['role'].capitalize()}}{% if message['content'][0]['type'] == 'image' %}{{':'}}{% else %}{{': '}}{% endif %}{% for line in message['content'] %}{% if line['type'] == 'text' %}{{line['text']}}{% elif line['type'] == 'image' %}{{ '<image>' }}{% endif %}{% endfor %}<end_of_utterance>\n{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"),
        ];
        let expected_outputs = [
            "System: You are a helpful assistant<end_of_utterance>\nUser:<image>Hello, please describe the above.<end_of_utterance>\nAssistant: Hi there<end_of_utterance>\nUser:<image>This is me, who are you<end_of_utterance>\nAssistant: I am an assistant <end_of_utterance>\nUser:<image>Another question, what is this?<end_of_utterance>\nAssistant:",
        ];

        let mut inputs = Vec::new();

        let mut message: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
            IndexMap::new();
        message.insert("role".to_string(), Either::Left("system".to_string()));
        message.insert(
            "content".to_string(),
            Either::Right(vec![hashmap! {
                "type".to_string() => "text".to_string(),
                "text".to_string() => "You are a helpful assistant".to_string()
            }]),
        );
        inputs.push(message);

        let mut message: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
            IndexMap::new();
        message.insert("role".to_string(), Either::Left("user".to_string()));
        message.insert(
            "content".to_string(),
            Either::Right(vec![
                hashmap! {
                    "type".to_string() => "image".to_string()
                },
                hashmap! {
                    "type".to_string() => "text".to_string(),
                    "text".to_string() => "Hello, please describe the above.".to_string()
                },
            ]),
        );
        inputs.push(message);

        let mut message: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
            IndexMap::new();
        message.insert("role".to_string(), Either::Left("assistant".to_string()));
        message.insert(
            "content".to_string(),
            Either::Right(vec![hashmap! {
                "type".to_string() => "text".to_string(),
                "text".to_string() => "Hi there".to_string()
            }]),
        );
        inputs.push(message);

        let mut message: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
            IndexMap::new();
        message.insert("role".to_string(), Either::Left("user".to_string()));
        message.insert(
            "content".to_string(),
            Either::Right(vec![
                hashmap! {
                    "type".to_string() => "image".to_string()
                },
                hashmap! {
                    "type".to_string() => "text".to_string(),
                    "text".to_string() => "This is me, who are you".to_string()
                },
            ]),
        );
        inputs.push(message);

        let mut message: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
            IndexMap::new();
        message.insert("role".to_string(), Either::Left("assistant".to_string()));
        message.insert(
            "content".to_string(),
            Either::Right(vec![hashmap! {
                "type".to_string() => "text".to_string(),
                "text".to_string() => " I am an assistant ".to_string()
            }]),
        );
        inputs.push(message);

        let mut message: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
            IndexMap::new();
        message.insert("role".to_string(), Either::Left("user".to_string()));
        message.insert(
            "content".to_string(),
            Either::Right(vec![
                hashmap! {
                    "type".to_string() => "image".to_string()
                },
                hashmap! {
                    "type".to_string() => "text".to_string(),
                    "text".to_string() => "Another question, what is this?".to_string()
                },
            ]),
        );
        inputs.push(message);

        test_with_inputs(&templates, &expected_outputs, inputs);
    }
}