1mod amoe;
2mod auto;
3pub mod chat_template;
4mod diffusion;
5mod embedding;
6mod ggml;
7mod gguf;
8mod inputs_processor;
9mod isq;
10pub(crate) mod llg;
11mod loaders;
12mod macros;
13mod normal;
14mod paths;
15mod processing;
16mod response;
17mod sampling;
18mod speculative;
19mod speech;
20mod vision;
21
22pub use super::diffusion_models::DiffusionGenerationParams;
23use crate::amoe::{AnyMoeConfig, AnyMoeExpertType, AnyMoeTrainingInputs, AnyMoeTrainingResult};
24use crate::device_map::DeviceMapper;
25use crate::paged_attention::{CacheConfig, CacheEngine, ModelConfigLike};
26use crate::prefix_cacher::PrefixCacheManagerV2;
27pub use amoe::{AnyMoeLoader, AnyMoePipeline};
28pub use auto::{AutoLoader, AutoLoaderBuilder};
29use chat_template::ChatTemplate;
30pub use diffusion::{DiffusionLoader, DiffusionLoaderBuilder};
31pub use embedding::{EmbeddingLoader, EmbeddingLoaderBuilder, EmbeddingSpecificConfig};
32pub use ggml::{GGMLLoader, GGMLLoaderBuilder, GGMLSpecificConfig};
33pub use gguf::{GGUFLoader, GGUFLoaderBuilder, GGUFSpecificConfig};
34use image::DynamicImage;
35pub use inputs_processor::InputProcessorOutput;
36pub(crate) use isq::IsqModelLoader;
37pub use isq::{parse_isq_value, IsqModel, IsqOrganization, UQFF_MULTI_FILE_DELIMITER};
38use llguidance::toktrie::TokEnv;
39pub use loaders::{
40 AdapterKind, AutoDeviceMapParams, AutoEmbeddingLoader, AutoNormalLoader, AutoVisionLoader,
41 DeepSeekV2Loader, DeepSeekV3Loader, DeviceMappedModelLoader, DiffusionLoaderType,
42 DiffusionModel, DiffusionModelLoader, EmbeddingGemmaLoader, EmbeddingLoaderType,
43 EmbeddingModel, EmbeddingModelLoader, EmbeddingModelPaths, EmbeddingModule,
44 EmbeddingModulePaths, EmbeddingModuleType, FluxLoader, GLM4Loader, Gemma2Loader, Gemma3Loader,
45 Gemma3nLoader, GemmaLoader, Idefics2Loader, Idefics3Loader, LLaVALoader, LLaVANextLoader,
46 LlamaLoader, Loader, LocalModelPaths, MiniCpmOLoader, Mistral3Loader, MistralLoader,
47 MixtralLoader, ModelKind, ModelPaths, NormalLoaderType, NormalLoadingMetadata, NormalModel,
48 NormalModelLoader, Phi2Loader, Phi3Loader, Phi3VLoader, Phi3_5MoELoader, Phi4MMLoader,
49 PrettyName, QuantizationKind, Qwen2Loader, Qwen2VLLoader, Qwen2_5VLLoader,
50 Qwen3EmbeddingLoader, Qwen3Loader, Qwen3MoELoader, Qwen3VLLoader, SmolLm3Loader,
51 Starcoder2Loader, TokenSource, VLlama4Loader, VLlamaLoader, VisionLoaderType, VisionModel,
52 VisionModelLoader,
53};
54use mistralrs_quant::IsqType;
55pub use normal::{NormalLoader, NormalLoaderBuilder, NormalSpecificConfig};
56pub(crate) use paths::{get_chat_template, get_model_paths, get_xlora_paths};
57pub use paths::{AdapterPaths, LoraAdapterPaths};
58pub(crate) use processing::{
59 apply_chat_template, BasicProcessor, MessagesAction, Processor, ProcessorCreator,
60};
61use rand_isaac::Isaac64Rng;
62pub use speculative::{SpeculativeConfig, SpeculativeLoader, SpeculativePipeline};
63pub use speech::{SpeechLoader, SpeechPipeline};
64use std::any::Any;
65use std::collections::HashMap;
66use std::fmt::Debug;
67use std::sync::Arc;
68use std::time::{Duration, Instant};
69use tokenizers::Tokenizer;
70pub use vision::{VisionLoader, VisionLoaderBuilder, VisionSpecificConfig};
71
72use anyhow::Result;
73use candle_core::{DType, Device, IndexOp, Tensor, Var};
74
75use crate::sequence::Sequence;
76
77pub use self::inputs_processor::{
78 text_models_inputs_processor, InputsProcessor, InputsProcessorType,
79};
80use self::text_models_inputs_processor::PagedAttentionMeta;
81pub use crate::kv_cache::{
82 Cache, CacheManager, EitherCache, KvCache, LayerCaches, NormalCache, NormalCacheType,
83};
84
/// A kind of data that a model can take as input or produce as output.
#[derive(Clone, PartialEq, Eq)]
pub enum SupportedModality {
    Text,
    Audio,
    Vision,
    Embedding,
}

impl Debug for SupportedModality {
    /// Renders the modality as an emoji-prefixed, human-readable label.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Text => "📝 Text",
            Self::Audio => "🔊 Audio",
            Self::Vision => "🖼️ Vision",
            Self::Embedding => "🔢 Embedding",
        };
        f.write_str(label)
    }
}

/// The modalities a model accepts (`input`) and the ones it emits (`output`).
#[derive(Debug, Clone)]
pub struct Modalities {
    pub input: Vec<SupportedModality>,
    pub output: Vec<SupportedModality>,
}
109
/// Model-wide metadata shared by a pipeline (returned by `MetadataMixin::get_metadata`).
pub struct GeneralMetadata {
    /// Maximum sequence length supported by the model (presumably in tokens — confirm).
    pub max_seq_len: usize,
    /// Optional llguidance parser factory; `tok_env` derives the token environment from it.
    pub llg_factory: Option<Arc<llguidance::ParserFactory>>,
    /// Forwarded to the inputs processor by `Pipeline::step`.
    pub no_kv_cache: bool,
    /// NOTE(review): presumably disables prefix caching for this model — confirm at use sites.
    pub no_prefix_cache: bool,
    /// Number of hidden (transformer) layers in the model.
    pub num_hidden_layers: usize,
    /// Token ids presumably treated as end-of-sequence markers — confirm with the sampler.
    pub eos_tok: Vec<u32>,
    /// How the model was loaded (see `ModelKind`).
    pub kind: ModelKind,
    /// Whether this is an X-LoRA model; forwarded to the inputs processor by `Pipeline::step`.
    pub is_xlora: bool,
    /// Dtype used for activations.
    pub activation_dtype: DType,
    /// Optional sliding-window size (assumed to be an attention window in tokens — TODO confirm).
    pub sliding_window: Option<usize>,
    /// PagedAttention cache configuration, if PagedAttention is in use (assumption — confirm).
    pub cache_config: Option<CacheConfig>,
    /// PagedAttention cache engine; `Pipeline::step` `expect`s this to be `Some` when given
    /// `CacheBackendMetadata::PagedAttention`.
    pub cache_engine: Option<CacheEngine>,
    /// Optional model configuration view shared across threads.
    pub model_metadata: Option<Arc<dyn ModelConfigLike + Send + Sync>>,
    /// Input/output modalities supported by the model.
    pub modalities: Modalities,
}
129
130impl GeneralMetadata {
131 pub fn tok_env(&self) -> Option<TokEnv> {
132 self.llg_factory.as_ref().map(|f| f.tok_env().clone())
133 }
134}
135
/// A cache operation to perform around a forward pass in `Pipeline::step`.
///
/// `step` accepts `In` only as a pre-op and `Out` only as a post-op; using them the other
/// way round hits an `unreachable!`.
pub enum CacheInstruction {
    /// Before the forward pass: calls `CacheManagerMixin::clone_in_cache`.
    In,
    /// After the forward pass: calls `CacheManagerMixin::clone_out_cache`.
    Out,
    /// Calls `CacheManagerMixin::set_none_cache` with these flags (and
    /// `modify_draft_cache = false`).
    Reset {
        load_preallocated_cache: bool,
        reset_non_granular: bool,
    },
    /// Leave the cache untouched.
    Nothing,
}
146
/// Pipeline hooks used to turn requests into model inputs.
pub trait PreProcessingMixin: MetadataMixin {
    /// Returns the processor used to build model inputs; defaults to `BasicProcessor`.
    fn get_processor(&self) -> Arc<dyn Processor> {
        Arc::new(BasicProcessor)
    }
    /// Returns the chat template, if the model has one.
    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>>;
    /// Returns an optional, type-erased configuration that `Pipeline::step` passes to the
    /// inputs processor.
    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>>;
}
155
/// Pipelines that support in-situ quantization (ISQ).
pub trait IsqPipelineMixin {
    /// Re-quantizes the loaded model to `dtype`.
    fn re_isq_model(&mut self, dtype: IsqType) -> Result<()>;
}
159
/// KV-cache management hooks invoked by `Pipeline::step` around each forward pass.
pub trait CacheManagerMixin {
    /// Handles `CacheInstruction::In` (run before the forward pass) for `seqs`.
    fn clone_in_cache(&self, seqs: &mut [&mut Sequence]);
    /// Handles `CacheInstruction::Out` (run after the forward pass) for `seqs`.
    fn clone_out_cache(&self, seqs: &mut [&mut Sequence]);
    /// Handles `CacheInstruction::Reset`. `Pipeline::step` always passes
    /// `modify_draft_cache = false`.
    fn set_none_cache(
        &self,
        seqs: &mut [&mut Sequence],
        reset_non_granular: bool,
        modify_draft_cache: bool,
        load_preallocated_cache: bool,
    );
    /// Returns the pipeline's cache.
    fn cache(&self) -> &EitherCache;
    /// Returns `true` when the cache is an `EitherCache::Normal`.
    fn do_preallocated_cache(&self) -> bool {
        matches!(self.cache(), EitherCache::Normal(_))
    }
}
182
/// Accessors for a pipeline's device, tokenizer and metadata.
pub trait MetadataMixin {
    /// Device the model runs on; `Pipeline::step` hands it to the inputs processor.
    fn device(&self) -> Device;
    /// Tokenizer, if the pipeline has one.
    fn tokenizer(&self) -> Option<Arc<Tokenizer>>;
    /// Pipeline/model name.
    fn name(&self) -> String;
    /// NOTE(review): resets model state that is not tracked per sequence — confirm semantics
    /// in implementors.
    fn reset_non_granular_state(&self);
    /// Shared model-wide metadata.
    fn get_metadata(&self) -> Arc<GeneralMetadata>;
    /// Device mapper, if the model is split across devices.
    fn device_mapper(&self) -> Option<&dyn DeviceMapper>;
}
192
/// Optional AnyMoE hooks for a pipeline.
///
/// `amoe_supported` defaults to `false` and every other method defaults to
/// `unreachable!()`, so pipelines that support AnyMoE must override them.
/// NOTE(review): presumably callers check `amoe_supported()` first — confirm at call sites.
pub trait AnyMoePipelineMixin {
    /// Trainable variables of the AnyMoE layers (assumed grouped per layer — TODO confirm).
    /// Default: panics via `unreachable!()`.
    fn amoe_layer_vars(&self) -> Vec<Vec<Var>> {
        unreachable!()
    }
    /// Finalizes AnyMoE training, optionally using `_gate_model_id`.
    /// Default: panics via `unreachable!()`.
    fn amoe_finish_training(&mut self, _gate_model_id: Option<String>) -> candle_core::Result<()> {
        unreachable!()
    }
    /// Number of trainable parameters of the base model. Default: panics via `unreachable!()`.
    fn amoe_base_model_trainable_params(&self) -> usize {
        unreachable!()
    }
    /// Whether this pipeline implements the AnyMoE hooks. Default: `false`.
    fn amoe_supported(&self) -> bool {
        false
    }
    /// Takes (moves out) the cached gating outputs. Default: panics via `unreachable!()`.
    fn amoe_take_cached_gating_outputs(&mut self) -> Vec<Tensor> {
        unreachable!()
    }
    /// Creates AnyMoE layers from the given expert model ids.
    /// Default: panics via `unreachable!()`.
    #[allow(clippy::too_many_arguments)]
    fn amoe_create_layers(
        &mut self,
        _model_ids: Vec<String>,
        _token: &TokenSource,
        _revision: Option<String>,
        _match_regex: &str,
        _config: AnyMoeConfig,
        _dtype: DType,
        _dev: &Device,
        (_prefix, _mlp): (String, String),
        _layers: Vec<usize>,
        _expert_type: AnyMoeExpertType,
        _silent: bool,
        _gate_model_id: Option<String>,
    ) -> candle_core::Result<()> {
        unreachable!()
    }
    /// Runs AnyMoE pre-training on `_inputs`; returns an optional training result.
    /// Default: panics via `unreachable!()`.
    #[allow(clippy::too_many_arguments)]
    fn amoe_pre_train(
        &self,
        _inputs: AnyMoeTrainingInputs,
        (_prefix, _mlp): (String, String),
        _model_ids: Vec<String>,
        _token: TokenSource,
        _revision: Option<String>,
        _layers: Vec<usize>,
        _silent: bool,
    ) -> Result<Option<AnyMoeTrainingResult>, candle_core::Error> {
        unreachable!()
    }
}
246
/// Broad category of a loaded model.
///
/// The `Vision` variant carries the prompt prefixer used for multimodal inputs.
#[derive(Clone)]
pub enum ModelCategory {
    Text,
    Vision {
        prefixer: Arc<dyn MultimodalPromptPrefixer>,
    },
    Diffusion,
    Audio,
    Speech,
    Embedding,
}

impl std::fmt::Debug for ModelCategory {
    /// Prints `ModelCategory::<Variant>`; the `Vision` prefixer is elided as `..`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let variant = match self {
            Self::Text => "Text",
            Self::Vision { .. } => "Vision { prefixer: .. }",
            Self::Diffusion => "Diffusion",
            Self::Audio => "Audio",
            Self::Speech => "Speech",
            Self::Embedding => "Embedding",
        };
        write!(f, "ModelCategory::{variant}")
    }
}

impl PartialEq for ModelCategory {
    /// Two categories are equal when they are the same variant; the `Vision` prefixer is
    /// not compared.
    fn eq(&self, other: &Self) -> bool {
        std::mem::discriminant(self) == std::mem::discriminant(other)
    }
}

/// Inserts modality placeholders (e.g. image/audio markers) into a prompt.
///
/// Both default implementations ignore the indices and return the prompt unchanged.
pub trait MultimodalPromptPrefixer: Send + Sync {
    /// Prefixes `prompt` for the images at `_image_indices`. Default: returns `prompt` as-is.
    fn prefix_image(&self, _image_indices: Vec<usize>, prompt: &str) -> String {
        prompt.to_owned()
    }
    /// Prefixes `prompt` for the audio clips at `_audio_indexes`. Default: returns `prompt`
    /// as-is.
    fn prefix_audio(&self, _audio_indexes: Vec<usize>, prompt: &str) -> String {
        prompt.to_owned()
    }
}
307
/// Per-step cache-backend information passed to `Pipeline::step`.
pub enum CacheBackendMetadata {
    /// Default KV cache: `pre_op` runs before the forward pass, `post_op` after it.
    DefaultInstructions {
        pre_op: CacheInstruction,
        post_op: CacheInstruction,
    },
    /// PagedAttention: `blocks_to_copy` is handed to the cache engine's
    /// `execute_scheduler_ops` before inputs are built (presumably source block -> destination
    /// blocks — confirm), and `metadata` is forwarded to the inputs processor.
    PagedAttention {
        metadata: PagedAttentionMeta,
        blocks_to_copy: HashMap<usize, Vec<usize>>,
    },
}
318
319#[derive(Clone, Debug)]
320pub enum ForwardInputsResult {
321 RawLogits {
322 logits: Tensor,
323 },
324 Embeddings {
325 embeddings: Tensor,
326 },
327 CausalGeneration {
328 logits: Tensor,
329 },
330 Image {
331 images: Vec<DynamicImage>,
332 },
333 Speech {
334 pcms: Vec<Arc<Vec<f32>>>,
335 rates: Vec<usize>,
336 channels: Vec<usize>,
337 },
338}
339
340impl ForwardInputsResult {
341 fn index_bs(&self, bs_idx: usize) -> candle_core::Result<Self> {
342 match self {
343 Self::CausalGeneration { logits } => Ok(Self::CausalGeneration {
344 logits: logits.i(bs_idx)?,
345 }),
346 Self::Embeddings { embeddings } => Ok(Self::Embeddings {
347 embeddings: embeddings.i(bs_idx)?,
348 }),
349 Self::RawLogits { logits } => Ok(Self::RawLogits {
350 logits: logits.i(bs_idx)?,
351 }),
352 Self::Image { images } => Ok(Self::Image {
353 images: vec![images[bs_idx].clone()],
354 }),
355 Self::Speech {
356 pcms,
357 rates,
358 channels,
359 } => Ok(Self::Speech {
360 pcms: vec![pcms[bs_idx].clone()],
361 rates: vec![rates[bs_idx]],
362 channels: vec![channels[bs_idx]],
363 }),
364 }
365 }
366
367 fn to_device(&self, device: &Device) -> candle_core::Result<Self> {
368 match self {
369 Self::CausalGeneration { logits } => Ok(Self::CausalGeneration {
370 logits: logits.to_device(device)?,
371 }),
372 Self::RawLogits { logits } => Ok(Self::RawLogits {
373 logits: logits.to_device(device)?,
374 }),
375 Self::Embeddings { embeddings } => Ok(Self::Embeddings {
376 embeddings: embeddings.to_device(device)?,
377 }),
378 Self::Image { .. } => Ok(self.clone()),
379 Self::Speech { .. } => Ok(self.clone()),
380 }
381 }
382}
383
/// Serde-(de)serializable list of file names.
/// NOTE(review): presumably caches a repository's file listing to avoid re-fetching it —
/// confirm at use sites.
#[derive(serde::Serialize, serde::Deserialize)]
pub(crate) struct FileListCache {
    files: Vec<String>,
}
388
389#[async_trait::async_trait]
390pub trait Pipeline:
391 Send
392 + Sync
393 + PreProcessingMixin
394 + IsqPipelineMixin
395 + CacheManagerMixin
396 + MetadataMixin
397 + AnyMoePipelineMixin
398{
399 fn forward_inputs(
400 &mut self,
401 inputs: Box<dyn Any>,
402 return_raw_logits: bool,
403 ) -> Result<ForwardInputsResult, candle_core::Error>;
404
405 #[allow(clippy::too_many_arguments)]
407 async fn step(
408 &mut self,
409 input_seqs: &mut [&mut Sequence],
410 is_prompt: bool,
411 return_raw_logits: bool,
412 prefix_cacher: &mut PrefixCacheManagerV2,
413 disable_eos_stop: bool,
414 rng: Arc<std::sync::Mutex<Isaac64Rng>>,
415 backend_metadata: CacheBackendMetadata,
416 ) -> Result<Duration, candle_core::Error> {
417 match backend_metadata {
418 CacheBackendMetadata::DefaultInstructions { pre_op, post_op } => {
419 let inputs_iter =
420 std::iter::once(self.get_processor().inputs_processor().process_inputs(
421 self.tokenizer(),
422 input_seqs,
423 is_prompt,
424 self.get_metadata().is_xlora,
425 &self.device(),
426 self.get_metadata().no_kv_cache,
427 None,
428 return_raw_logits,
429 self.get_input_processor_config(),
430 None,
431 self.device_mapper(),
432 ));
433
434 let mut logits = vec![None; input_seqs.len()];
435 let len_inputs = 1;
436 let mut raw_out_logits = vec![vec![None; len_inputs]; input_seqs.len()];
437 let mut embedding_logits = vec![None; input_seqs.len()];
438
439 let mut exec_duration = Duration::ZERO;
440 for (i, inputs) in inputs_iter.into_iter().enumerate() {
441 let InputProcessorOutput {
442 inputs,
443 seq_indices,
444 } = inputs.map_err(candle_core::Error::msg)?;
445 if i == 0 {
446 match pre_op {
447 CacheInstruction::In => self.clone_in_cache(input_seqs),
448 CacheInstruction::Nothing => (),
449 CacheInstruction::Reset {
450 load_preallocated_cache,
451 reset_non_granular,
452 } => self.set_none_cache(
453 input_seqs,
454 reset_non_granular,
455 false,
456 load_preallocated_cache,
457 ),
458 _ => unreachable!("Unreachable PRE cache op."),
459 }
460 }
461
462 let start = Instant::now();
463 let raw_logits = self.forward_inputs(inputs, return_raw_logits)?;
464 let end = Instant::now();
465 exec_duration += end.duration_since(start);
466
467 for (logit_idx, seq_idx) in seq_indices.into_iter().enumerate() {
468 if let ForwardInputsResult::RawLogits { logits } = &raw_logits {
469 raw_out_logits[seq_idx][i] =
470 Some(logits.i(logit_idx)?.to_device(&Device::Cpu)?);
471 } else if let ForwardInputsResult::Embeddings { embeddings } = &raw_logits {
472 embedding_logits[seq_idx] =
473 Some(embeddings.i(logit_idx)?.to_device(&Device::Cpu)?);
474 } else {
475 logits[seq_idx] = Some(raw_logits.index_bs(logit_idx)?);
476 }
477 }
478 }
479
480 match post_op {
481 CacheInstruction::Out => self.clone_out_cache(input_seqs),
482 CacheInstruction::Nothing => (),
483 CacheInstruction::Reset {
484 load_preallocated_cache,
485 reset_non_granular,
486 } => self.set_none_cache(
487 input_seqs,
488 reset_non_granular,
489 false,
490 load_preallocated_cache,
491 ),
492 _ => unreachable!("Unreachable POST cache op."),
493 }
494
495 if raw_out_logits[0][0].is_some() {
496 let start = Instant::now();
497 response::send_raw_responses(
498 input_seqs,
499 raw_out_logits
500 .into_iter()
501 .map(|raw| raw.into_iter().flatten().collect::<Vec<_>>())
502 .collect(),
503 )
504 .await?;
505 let end = Instant::now();
506 exec_duration += end.duration_since(start);
507
508 return Ok(exec_duration);
509 }
510 if embedding_logits[0].is_some() {
511 let start = Instant::now();
512 response::send_embedding_responses(
513 input_seqs,
514 embedding_logits
515 .into_iter()
516 .map(|raw| {
517 raw.unwrap()
518 .to_dtype(DType::F32)
519 .unwrap()
520 .to_vec1::<f32>()
521 .unwrap()
522 })
523 .collect(),
524 )
525 .await?;
526 let end = Instant::now();
527 exec_duration += end.duration_since(start);
528
529 return Ok(exec_duration);
530 }
531
532 let start = Instant::now();
533 let logits_on_cpu = logits.len() > 1;
534 let logits = logits
535 .into_iter()
536 .map(|l| {
537 let l = l.expect("Did not get any inputs. This is shocking.");
538 if logits_on_cpu {
539 l.to_device(&Device::Cpu)
540 } else {
541 Ok(l)
542 }
543 })
544 .collect::<candle_core::Result<Vec<_>>>()?;
545
546 match &logits[0] {
547 ForwardInputsResult::RawLogits { .. }
548 | ForwardInputsResult::Embeddings { .. } => unreachable!(),
549 ForwardInputsResult::CausalGeneration { .. } => {
550 self.sample_causal_gen(
551 input_seqs,
552 logits
553 .into_iter()
554 .map(|r| {
555 #[allow(irrefutable_let_patterns)]
556 let ForwardInputsResult::CausalGeneration { logits } = r
557 else {
558 unreachable!(
559 "All results must have same type, `CausalGeneration`"
560 )
561 };
562 logits
563 })
564 .collect::<Vec<_>>(),
565 prefix_cacher,
566 disable_eos_stop,
567 rng,
568 )
569 .await?;
570 }
571 ForwardInputsResult::Image { .. } => {
572 response::send_image_responses(
573 input_seqs,
574 logits
575 .into_iter()
576 .map(|r| {
577 #[allow(irrefutable_let_patterns)]
578 let ForwardInputsResult::Image { images } = r
579 else {
580 unreachable!("All results must have same type, `Image`")
581 };
582 images
583 .into_iter()
584 .next()
585 .expect("Must have at least 1 element.")
586 })
587 .collect::<Vec<_>>(),
588 )
589 .await?;
590 }
591 ForwardInputsResult::Speech { .. } => {
592 let rates = logits
593 .iter()
594 .map(|r| {
595 #[allow(irrefutable_let_patterns)]
596 let ForwardInputsResult::Speech { rates, .. } = r
597 else {
598 unreachable!("All results must have same type, `Speech`")
599 };
600 assert_eq!(rates.len(), 1, "Each sequence must have 1 PCM output.");
601 *rates.first().unwrap()
602 })
603 .collect::<Vec<_>>();
604 let channels = logits
605 .iter()
606 .map(|r| {
607 #[allow(irrefutable_let_patterns)]
608 let ForwardInputsResult::Speech { channels, .. } = r
609 else {
610 unreachable!("All results must have same type, `Speech`")
611 };
612 assert_eq!(
613 channels.len(),
614 1,
615 "Each sequence must have 1 PCM output."
616 );
617 *channels.first().unwrap()
618 })
619 .collect::<Vec<_>>();
620 let pcms = logits
621 .into_iter()
622 .map(|r| {
623 #[allow(irrefutable_let_patterns)]
624 let ForwardInputsResult::Speech { pcms, .. } = r
625 else {
626 unreachable!("All results must have same type, `Speech`")
627 };
628 assert_eq!(pcms.len(), 1, "Each sequence must have 1 PCM output.");
629 pcms.into_iter().nth(0).unwrap()
630 })
631 .collect::<Vec<_>>();
632 response::send_speech_responses(input_seqs, &pcms, &rates, &channels)
633 .await?;
634 }
635 }
636 let end = Instant::now();
637 exec_duration += end.duration_since(start);
638
639 Ok(exec_duration)
640 }
641 CacheBackendMetadata::PagedAttention {
642 metadata,
643 blocks_to_copy,
644 } => {
645 self.get_metadata()
647 .cache_engine
648 .as_ref()
649 .expect("PagedAttention must have cache engines.")
650 .execute_scheduler_ops(&blocks_to_copy)?;
651
652 let inputs_iter =
653 std::iter::once(self.get_processor().inputs_processor().process_inputs(
654 self.tokenizer(),
655 input_seqs,
656 is_prompt,
657 self.get_metadata().is_xlora,
658 &self.device(),
659 self.get_metadata().no_kv_cache,
660 None,
661 return_raw_logits,
662 self.get_input_processor_config(),
663 Some(metadata),
664 self.device_mapper(),
665 ));
666
667 let mut logits = vec![None; input_seqs.len()];
668 let len_inputs = 1;
669 let mut raw_out_logits = vec![vec![None; len_inputs]; input_seqs.len()];
670 let mut embedding_logits = vec![None; input_seqs.len()];
671
672 let mut exec_duration = Duration::ZERO;
673 for (i, inputs) in inputs_iter.into_iter().enumerate() {
674 let InputProcessorOutput {
675 inputs,
676 seq_indices,
677 } = inputs.map_err(candle_core::Error::msg)?;
678
679 let start = Instant::now();
680 let raw_logits = self.forward_inputs(inputs, return_raw_logits)?;
681 let end = Instant::now();
682 exec_duration += end.duration_since(start);
683
684 for (logit_idx, seq_idx) in seq_indices.into_iter().enumerate() {
685 if let ForwardInputsResult::RawLogits { logits } = &raw_logits {
686 raw_out_logits[seq_idx][i] =
687 Some(logits.i(logit_idx)?.to_device(&Device::Cpu)?);
688 } else if let ForwardInputsResult::Embeddings { embeddings } = &raw_logits {
689 embedding_logits[seq_idx] =
690 Some(embeddings.i(logit_idx)?.to_device(&Device::Cpu)?);
691 } else {
692 logits[seq_idx] = Some(raw_logits.index_bs(logit_idx)?);
693 }
694 }
695 }
696
697 if raw_out_logits[0][0].is_some() {
698 let start = Instant::now();
699 response::send_raw_responses(
700 input_seqs,
701 raw_out_logits
702 .into_iter()
703 .map(|raw| raw.into_iter().flatten().collect::<Vec<_>>())
704 .collect(),
705 )
706 .await?;
707 let end = Instant::now();
708 exec_duration += end.duration_since(start);
709
710 return Ok(exec_duration);
711 }
712 if embedding_logits[0].is_some() {
713 let start = Instant::now();
714 response::send_embedding_responses(
715 input_seqs,
716 embedding_logits
717 .into_iter()
718 .map(|raw| {
719 raw.unwrap()
720 .to_dtype(DType::F32)
721 .unwrap()
722 .to_vec1::<f32>()
723 .unwrap()
724 })
725 .collect(),
726 )
727 .await?;
728 let end = Instant::now();
729 exec_duration += end.duration_since(start);
730
731 return Ok(exec_duration);
732 }
733
734 let start = Instant::now();
735 let logits_on_cpu = logits.len() > 1;
736 let logits = logits
737 .into_iter()
738 .map(|l| {
739 let l = l.expect("Did not get any inputs. This is shocking.");
740 if logits_on_cpu {
741 l.to_device(&Device::Cpu)
742 } else {
743 Ok(l)
744 }
745 })
746 .collect::<candle_core::Result<Vec<_>>>()?;
747
748 match &logits[0] {
749 ForwardInputsResult::RawLogits { .. }
750 | ForwardInputsResult::Embeddings { .. } => unreachable!(),
751 ForwardInputsResult::CausalGeneration { .. } => {
752 self.sample_causal_gen(
753 input_seqs,
754 logits
755 .into_iter()
756 .map(|r| {
757 #[allow(irrefutable_let_patterns)]
758 let ForwardInputsResult::CausalGeneration { logits } = r
759 else {
760 unreachable!("All results must have same type")
761 };
762 logits
763 })
764 .collect::<Vec<_>>(),
765 prefix_cacher,
766 disable_eos_stop,
767 rng,
768 )
769 .await?;
770 }
771 ForwardInputsResult::Image { .. } => {
772 response::send_image_responses(
773 input_seqs,
774 logits
775 .into_iter()
776 .map(|r| {
777 #[allow(irrefutable_let_patterns)]
778 let ForwardInputsResult::Image { images } = r
779 else {
780 unreachable!("All results must have same type, `Image`")
781 };
782 images
783 .into_iter()
784 .next()
785 .expect("Must have at least 1 element.")
786 })
787 .collect::<Vec<_>>(),
788 )
789 .await?;
790 }
791 ForwardInputsResult::Speech { .. } => {
792 let rates = logits
793 .iter()
794 .map(|r| {
795 #[allow(irrefutable_let_patterns)]
796 let ForwardInputsResult::Speech { rates, .. } = r
797 else {
798 unreachable!("All results must have same type, `Speech`")
799 };
800 assert_eq!(rates.len(), 1, "Each sequence must have 1 PCM output.");
801 *rates.first().unwrap()
802 })
803 .collect::<Vec<_>>();
804 let channels = logits
805 .iter()
806 .map(|r| {
807 #[allow(irrefutable_let_patterns)]
808 let ForwardInputsResult::Speech { channels, .. } = r
809 else {
810 unreachable!("All results must have same type, `Speech`")
811 };
812 assert_eq!(
813 channels.len(),
814 1,
815 "Each sequence must have 1 PCM output."
816 );
817 *channels.first().unwrap()
818 })
819 .collect::<Vec<_>>();
820 let pcms = logits
821 .into_iter()
822 .map(|r| {
823 #[allow(irrefutable_let_patterns)]
824 let ForwardInputsResult::Speech { pcms, .. } = r
825 else {
826 unreachable!("All results must have same type, `Speech`")
827 };
828 assert_eq!(pcms.len(), 1, "Each sequence must have 1 PCM output.");
829 pcms.into_iter().nth(0).unwrap()
830 })
831 .collect::<Vec<_>>();
832 response::send_speech_responses(input_seqs, &pcms, &rates, &channels)
833 .await?;
834 }
835 }
836 let end = Instant::now();
837 exec_duration += end.duration_since(start);
838
839 Ok(exec_duration)
840 }
841 }
842 }
843
844 async fn sample_causal_gen(
845 &self,
846 seqs: &mut [&mut Sequence],
847 logits: Vec<Tensor>,
848 prefix_cacher: &mut PrefixCacheManagerV2,
849 disable_eos_stop: bool,
850 rng: Arc<std::sync::Mutex<Isaac64Rng>>,
851 ) -> Result<(), candle_core::Error>;
852
853 fn category(&self) -> ModelCategory;
854}
855
856pub(crate) fn extract_logits(
857 logits: &Tensor,
858 context_lens: Vec<(usize, usize)>,
859) -> candle_core::Result<Tensor> {
860 let mut toks = Vec::new();
861 for (dim, (start, len)) in logits.chunk(logits.dims()[0], 0)?.iter().zip(context_lens) {
862 toks.push(dim.narrow(1, start, len)?);
863 }
864 Tensor::cat(&toks, 0)
865}
866
#[cfg(test)]
mod tests {
    use crate::MessageContent;
    use either::Either;
    use indexmap::IndexMap;
    use serde_json::Value;

    /// Builds an `IndexMap` with `Value::String` values from `key => value` pairs.
    macro_rules! hashmap {
        (@single $($x:tt)*) => (());
        (@count $($rest:expr),*) => (<[()]>::len(&[$(hashmap!(@single $rest)),*]));

        ($($key:expr => $value:expr,)+) => { hashmap!($($key => $value),+) };
        ($($key:expr => $value:expr),*) => {
            {
                let _cap = hashmap!(@count $($key),*);
                let mut _map = ::indexmap::IndexMap::with_capacity(_cap);
                $(
                    let _ = _map.insert($key, Value::String($value));
                )*
                _map
            }
        };
    }

    /// Renders `inputs` with each template and compares the result to the matching
    /// expected output.
    ///
    /// Templates whose `has_system` flag is `false` are rendered without the first
    /// (system) message. Templates and expected outputs are paired with `zip`, so only the
    /// first `min(len)` templates are checked. All mismatches are printed, each tagged with
    /// its template's index, before the function panics.
    #[track_caller]
    fn test_with_inputs(
        templates: &[(bool, &str, &str, &str, &str)],
        expected_outputs: &[&str],
        inputs: Vec<IndexMap<String, MessageContent>>,
    ) {
        use crate::pipeline::chat_template::ChatTemplateValue;

        use super::chat_template::apply_chat_template_to;
        let mut failed: Vec<(usize, String)> = Vec::new();
        let n_checked = templates.len().min(expected_outputs.len());
        for (template_idx, ((has_system, bos, eos, unk, template), expected)) in
            templates.iter().zip(expected_outputs).enumerate()
        {
            let output = match apply_chat_template_to(
                if !has_system {
                    inputs[1..].to_vec()
                } else {
                    inputs.clone()
                },
                true,
                None,
                &ChatTemplateValue(Either::Left(template.to_string())),
                Some(bos.to_string()),
                Some(eos.to_string()),
                Some(unk.to_string()),
                Vec::new(),
            ) {
                Ok(v) => v,
                Err(e) => {
                    failed.push((template_idx, format!("Failed with {e}.")));
                    continue;
                }
            };
            if output != *expected {
                failed.push((
                    template_idx,
                    format!(
                        "Expected: `{}` \n\nGot: `{}`",
                        expected.replace('\n', "\\n"),
                        output.replace('\n', "\\n")
                    ),
                ));
            }
        }
        if !failed.is_empty() {
            for (template_idx, line) in &failed {
                println!("------------ Template {template_idx} ------------");
                println!("{line}");
            }
            println!("------------------------");
            panic!("{}/{n_checked} chat templates failed.", failed.len());
        }
    }

    #[test]
    fn test_chat_templates() {
        // NOTE(review): the last template has no matching expected output, so
        // `test_with_inputs` never checks it here; the same template is exercised by
        // `test_image_chat_templates`.
        let templates = [
            (true, "<s>", "</s>", "<unk>", "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"),
            (false, "<s>", "</s>", "<unk>", "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"),
            (true, "<s>", "</s>", "<unk>", "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}"),
            (false, "<s>", "</s>", "<unk>", "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"),
            (false, "<bos>", "<eos>", "<unk>", "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}"),
            (true, "<s>", "</s>", "<unk>", "{% for message in messages %}{{message['role'].capitalize()}}{% if message['content'][0]['type'] == 'image' %}{{':'}}{% else %}{{': '}}{% endif %}{% for line in message['content'] %}{% if line['type'] == 'text' %}{{line['text']}}{% elif line['type'] == 'image' %}{{ '<image>' }}{% endif %}{% endfor %}<end_of_utterance>\n{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"),
        ];
        let expected_outputs = [
            "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi there<|im_end|>\n<|im_start|>user\nWho are you<|im_end|>\n<|im_start|>assistant\n I am an assistant <|im_end|>\n<|im_start|>user\nAnother question<|im_end|>\n<|im_start|>assistant\n",
            "<s>[INST] Hello [/INST]Hi there</s> [INST] Who are you [/INST] I am an assistant </s> [INST] Another question [/INST]",
            "<s>[INST] <<SYS>>\nYou are a helpful assistant\n<</SYS>>\n\nHello [/INST] Hi there </s><s>[INST] Who are you [/INST] I am an assistant </s><s>[INST] Another question [/INST]",
            "<s>[INST] Hello [/INST]Hi there</s>[INST] Who are you [/INST] I am an assistant </s>[INST] Another question [/INST]",
            "<bos><start_of_turn>user\nHello<end_of_turn>\n<start_of_turn>model\nHi there<end_of_turn>\n<start_of_turn>user\nWho are you<end_of_turn>\n<start_of_turn>model\nI am an assistant<end_of_turn>\n<start_of_turn>user\nAnother question<end_of_turn>\n<start_of_turn>model\n",
        ];
        let messages = [
            ["system", "You are a helpful assistant"],
            ["user", "Hello"],
            ["assistant", "Hi there"],
            ["user", "Who are you"],
            ["assistant", " I am an assistant "],
            ["user", "Another question"],
        ];
        let mut inputs = Vec::new();
        for [role, content] in messages {
            let mut message: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
                IndexMap::new();
            message.insert("role".to_string(), Either::Left(role.to_string()));
            message.insert("content".to_string(), Either::Left(content.to_string()));
            inputs.push(message);
        }
        test_with_inputs(&templates, &expected_outputs, inputs);
    }

    #[test]
    fn test_image_chat_templates() {
        let templates = [
            (true, "<s>", "</s>", "<unk>", "{% for message in messages %}{{message['role'].capitalize()}}{% if message['content'][0]['type'] == 'image' %}{{':'}}{% else %}{{': '}}{% endif %}{% for line in message['content'] %}{% if line['type'] == 'text' %}{{line['text']}}{% elif line['type'] == 'image' %}{{ '<image>' }}{% endif %}{% endfor %}<end_of_utterance>\n{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"),
        ];
        let expected_outputs = [
            "System: You are a helpful assistant<end_of_utterance>\nUser:<image>Hello, please describe the above.<end_of_utterance>\nAssistant: Hi there<end_of_utterance>\nUser:<image>This is me, who are you<end_of_utterance>\nAssistant: I am an assistant <end_of_utterance>\nUser:<image>Another question, what is this?<end_of_utterance>\nAssistant:",
        ];

        let mut inputs = Vec::new();

        let mut message: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
            IndexMap::new();
        message.insert("role".to_string(), Either::Left("system".to_string()));
        message.insert(
            "content".to_string(),
            Either::Right(vec![hashmap! {
                "type".to_string() => "text".to_string(),
                "text".to_string() => "You are a helpful assistant".to_string()
            }]),
        );
        inputs.push(message);

        let mut message: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
            IndexMap::new();
        message.insert("role".to_string(), Either::Left("user".to_string()));
        message.insert(
            "content".to_string(),
            Either::Right(vec![
                hashmap! {
                    "type".to_string() => "image".to_string()
                },
                hashmap! {
                    "type".to_string() => "text".to_string(),
                    "text".to_string() => "Hello, please describe the above.".to_string()
                },
            ]),
        );
        inputs.push(message);

        let mut message: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
            IndexMap::new();
        message.insert("role".to_string(), Either::Left("assistant".to_string()));
        message.insert(
            "content".to_string(),
            Either::Right(vec![hashmap! {
                "type".to_string() => "text".to_string(),
                "text".to_string() => "Hi there".to_string()
            }]),
        );
        inputs.push(message);

        let mut message: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
            IndexMap::new();
        message.insert("role".to_string(), Either::Left("user".to_string()));
        message.insert(
            "content".to_string(),
            Either::Right(vec![
                hashmap! {
                    "type".to_string() => "image".to_string()
                },
                hashmap! {
                    "type".to_string() => "text".to_string(),
                    "text".to_string() => "This is me, who are you".to_string()
                },
            ]),
        );
        inputs.push(message);

        let mut message: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
            IndexMap::new();
        message.insert("role".to_string(), Either::Left("assistant".to_string()));
        message.insert(
            "content".to_string(),
            Either::Right(vec![hashmap! {
                "type".to_string() => "text".to_string(),
                "text".to_string() => " I am an assistant ".to_string()
            }]),
        );
        inputs.push(message);

        let mut message: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
            IndexMap::new();
        message.insert("role".to_string(), Either::Left("user".to_string()));
        message.insert(
            "content".to_string(),
            Either::Right(vec![
                hashmap! {
                    "type".to_string() => "image".to_string()
                },
                hashmap! {
                    "type".to_string() => "text".to_string(),
                    "text".to_string() => "Another question, what is this?".to_string()
                },
            ]),
        );
        inputs.push(message);

        test_with_inputs(&templates, &expected_outputs, inputs);
    }
}