// mistralrs_core/pipeline/mod.rs

1mod amoe;
2mod auto;
3pub mod chat_template;
4mod diffusion;
5mod ggml;
6mod gguf;
7mod inputs_processor;
8mod isq;
9pub(crate) mod llg;
10mod loaders;
11mod macros;
12mod normal;
13mod paths;
14mod processing;
15mod response;
16mod sampling;
17mod speculative;
18mod speech;
19mod vision;
20
21pub use super::diffusion_models::DiffusionGenerationParams;
22use crate::amoe::{AnyMoeConfig, AnyMoeExpertType, AnyMoeTrainingInputs, AnyMoeTrainingResult};
23use crate::device_map::DeviceMapper;
24use crate::paged_attention::{CacheConfig, CacheEngine, ModelConfigLike};
25use crate::prefix_cacher::PrefixCacheManagerV2;
26pub use amoe::{AnyMoeLoader, AnyMoePipeline};
27pub use auto::{AutoLoader, AutoLoaderBuilder};
28use chat_template::ChatTemplate;
29pub use diffusion::{DiffusionLoader, DiffusionLoaderBuilder};
30pub use ggml::{GGMLLoader, GGMLLoaderBuilder, GGMLSpecificConfig};
31pub use gguf::{GGUFLoader, GGUFLoaderBuilder, GGUFSpecificConfig};
32use image::DynamicImage;
33pub use inputs_processor::InputProcessorOutput;
34pub(crate) use isq::IsqModelLoader;
35pub use isq::{parse_isq_value, IsqModel, IsqOrganization, UQFF_MULTI_FILE_DELIMITER};
36use llguidance::toktrie::TokEnv;
37pub use loaders::{
38    AdapterKind, AutoDeviceMapParams, AutoNormalLoader, AutoVisionLoader, DeepSeekV2Loader,
39    DeepSeekV3Loader, DeviceMappedModelLoader, DiffusionLoaderType, DiffusionModel,
40    DiffusionModelLoader, FluxLoader, GLM4Loader, Gemma2Loader, Gemma3Loader, GemmaLoader,
41    Idefics2Loader, Idefics3Loader, LLaVALoader, LLaVANextLoader, LlamaLoader, Loader,
42    LocalModelPaths, MiniCpmOLoader, Mistral3Loader, MistralLoader, MixtralLoader, ModelKind,
43    ModelPaths, NormalLoaderType, NormalLoadingMetadata, NormalModel, NormalModelLoader,
44    Phi2Loader, Phi3Loader, Phi3VLoader, Phi3_5MoELoader, Phi4MMLoader, PrettyName,
45    QuantizationKind, Qwen2Loader, Qwen2VLLoader, Qwen2_5VLLoader, Qwen3Loader, Qwen3MoELoader,
46    Starcoder2Loader, TokenSource, VLlama4Loader, VLlamaLoader, VisionLoaderType, VisionModel,
47    VisionModelLoader,
48};
49use mistralrs_quant::IsqType;
50pub use normal::{NormalLoader, NormalLoaderBuilder, NormalSpecificConfig};
51pub(crate) use paths::{get_chat_template, get_model_paths, get_xlora_paths};
52pub use paths::{AdapterPaths, LoraAdapterPaths};
53pub(crate) use processing::{
54    apply_chat_template, BasicProcessor, MessagesAction, Processor, ProcessorCreator,
55};
56use rand_isaac::Isaac64Rng;
57pub use speculative::{SpeculativeConfig, SpeculativeLoader, SpeculativePipeline};
58pub use speech::{SpeechLoader, SpeechPipeline};
59use std::any::Any;
60use std::collections::HashMap;
61use std::fmt::Debug;
62use std::num::NonZeroUsize;
63use std::sync::Arc;
64use std::time::{Duration, Instant};
65use tokenizers::Tokenizer;
66pub use vision::{VisionLoader, VisionLoaderBuilder, VisionSpecificConfig};
67
68use anyhow::Result;
69use candle_core::{DType, Device, IndexOp, Tensor, Var};
70
71use crate::sequence::Sequence;
72
73pub use self::inputs_processor::{
74    text_models_inputs_processor, InputsProcessor, InputsProcessorType,
75};
76use self::text_models_inputs_processor::PagedAttentionMeta;
77pub use crate::kv_cache::{
78    Cache, CacheManager, EitherCache, KvCache, LayerCaches, NormalCache, NormalCacheType,
79};
80
/// A kind of data that a model can take as input or produce as output.
#[derive(Clone, PartialEq, Eq)]
pub enum SupportedModality {
    /// Text.
    Text,
    /// Audio.
    Audio,
    /// Vision (images).
    Vision,
}

impl Debug for SupportedModality {
    /// Writes an emoji-prefixed label such as `📝 Text`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Text => "📝 Text",
            Self::Audio => "🔊 Audio",
            Self::Vision => "🖼️ Vision",
        };
        f.write_str(label)
    }
}

/// The modalities a model accepts (`input`) and emits (`output`).
#[derive(Debug, Clone)]
pub struct Modalities {
    pub input: Vec<SupportedModality>,
    pub output: Vec<SupportedModality>,
}
103
104pub struct GeneralMetadata {
105    pub max_seq_len: usize,
106    /// Only None if it doesn't make sense for the model
107    pub llg_factory: Option<Arc<llguidance::ParserFactory>>,
108    pub no_kv_cache: bool,
109    pub no_prefix_cache: bool,
110    pub num_hidden_layers: usize,
111    pub eos_tok: Vec<u32>,
112    pub kind: ModelKind,
113    // TODO: Replace is_xlora queries to check via kind instead:
114    pub is_xlora: bool,
115    pub activation_dtype: DType,
116    pub sliding_window: Option<usize>,
117    // PagedAttention stuff
118    pub cache_config: Option<CacheConfig>,
119    pub cache_engine: Option<CacheEngine>,
120    pub prompt_chunksize: Option<NonZeroUsize>,
121    pub model_metadata: Option<Arc<dyn ModelConfigLike + Send + Sync>>,
122    pub modalities: Modalities,
123}
124
125impl GeneralMetadata {
126    pub fn tok_env(&self) -> Option<TokEnv> {
127        self.llg_factory.as_ref().map(|f| f.tok_env().clone())
128    }
129}
130
/// What to do with the model's cache around the forward passes of one step.
///
/// As a pre-op, `In`, `Reset` and `Nothing` are accepted.
/// As a post-op, `Out`, `Reset` and `Nothing` are accepted.
pub enum CacheInstruction {
    /// Clone the sequences' caches into the model cache.
    In,
    /// Clone the model cache back out to the sequences.
    Out,
    /// Clear the model cache.
    Reset {
        /// Load the preallocated cache, if applicable.
        load_preallocated_cache: bool,
        /// Also reset the pipeline's non-granular state.
        reset_non_granular: bool,
    },
    /// Leave the cache untouched.
    Nothing,
}
141
142pub trait PreProcessingMixin: MetadataMixin {
143    fn get_processor(&self) -> Arc<dyn Processor> {
144        Arc::new(BasicProcessor)
145    }
146    /// Only None if it doesnt make sense for the model
147    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>>;
148    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>>;
149}
150
151pub trait IsqPipelineMixin {
152    fn re_isq_model(&mut self, dtype: IsqType) -> Result<()>;
153}
154
155pub trait CacheManagerMixin {
156    /// Clone the cache FROM the sequences' cache TO the model cache. Only called for completion seqs.
157    /// It is not a guarantee that this will be called for each completion step.
158    fn clone_in_cache(&self, seqs: &mut [&mut Sequence]);
159    /// Clone the cache FROM the model cache TO the sequences. Called for prompt and completion seqs.
160    /// It is not a guarantee that this will be called for each step.
161    fn clone_out_cache(&self, seqs: &mut [&mut Sequence]);
162    /// Set the model cache to all None. Only called for prompt seqs.
163    /// It is not a guarantee that this will be called for each prompt step.
164    /// This may also reset the non granular state if applicable.
165    fn set_none_cache(
166        &self,
167        seqs: &mut [&mut Sequence],
168        reset_non_granular: bool,
169        modify_draft_cache: bool,
170        load_preallocated_cache: bool,
171    );
172    fn cache(&self) -> &EitherCache;
173    fn do_preallocated_cache(&self) -> bool {
174        matches!(self.cache(), EitherCache::Normal(_))
175    }
176}
177
178pub trait MetadataMixin {
179    fn device(&self) -> Device;
180    /// Only None if it doesnt make sense for the model
181    fn tokenizer(&self) -> Option<Arc<Tokenizer>>;
182    fn name(&self) -> String;
183    fn reset_non_granular_state(&self);
184    fn get_metadata(&self) -> Arc<GeneralMetadata>;
185    fn device_mapper(&self) -> Option<&dyn DeviceMapper>;
186}
187
188/// Implemented by the base model of an AnyMoe.
189pub trait AnyMoePipelineMixin {
190    /// Get vars for each gating layer
191    fn amoe_layer_vars(&self) -> Vec<Vec<Var>> {
192        unreachable!()
193    }
194    fn amoe_finish_training(&mut self, _gate_model_id: Option<String>) -> candle_core::Result<()> {
195        unreachable!()
196    }
197    fn amoe_base_model_trainable_params(&self) -> usize {
198        unreachable!()
199    }
200    fn amoe_supported(&self) -> bool {
201        false
202    }
203    /// Per-layer cached outputs.
204    fn amoe_take_cached_gating_outputs(&mut self) -> Vec<Tensor> {
205        unreachable!()
206    }
207    /// Inject the MoE layers
208    #[allow(clippy::too_many_arguments)]
209    fn amoe_create_layers(
210        &mut self,
211        _model_ids: Vec<String>,
212        _token: &TokenSource,
213        _revision: Option<String>,
214        _match_regex: &str,
215        _config: AnyMoeConfig,
216        _dtype: DType,
217        _dev: &Device,
218        (_prefix, _mlp): (String, String),
219        _layers: Vec<usize>,
220        _expert_type: AnyMoeExpertType,
221        _silent: bool,
222        _gate_model_id: Option<String>,
223    ) -> candle_core::Result<()> {
224        unreachable!()
225    }
226    /// Pre-train the gating layers
227    #[allow(clippy::too_many_arguments)]
228    fn amoe_pre_train(
229        &self,
230        _inputs: AnyMoeTrainingInputs,
231        (_prefix, _mlp): (String, String),
232        _model_ids: Vec<String>,
233        _token: TokenSource,
234        _revision: Option<String>,
235        _layers: Vec<usize>,
236        _silent: bool,
237    ) -> Result<Option<AnyMoeTrainingResult>, candle_core::Error> {
238        unreachable!()
239    }
240}
241
/// Category of the model.
///
/// It can also carry category-specific tools, such as the vision model's prompt prefixer.
#[derive(Clone)]
pub enum ModelCategory {
    Text,
    Vision {
        prefixer: Arc<dyn MultimodalPromptPrefixer>,
    },
    Diffusion,
    Audio,
    Speech,
}

impl std::fmt::Debug for ModelCategory {
    /// Writes `ModelCategory::<Variant>`. The `Vision` prefixer is shown as `..`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let variant = match self {
            ModelCategory::Text => "Text",
            ModelCategory::Vision { .. } => "Vision { prefixer: .. }",
            ModelCategory::Diffusion => "Diffusion",
            ModelCategory::Audio => "Audio",
            ModelCategory::Speech => "Speech",
        };
        write!(f, "ModelCategory::{variant}")
    }
}

impl PartialEq for ModelCategory {
    /// Categories are equal when they are the same variant. The `Vision` prefixer is
    /// ignored.
    fn eq(&self, other: &Self) -> bool {
        std::mem::discriminant(self) == std::mem::discriminant(other)
    }
}

/// Adds a model-appropriate multimodal tag to a prompt.
///
/// Image indices are assumed to start at 0.
pub trait MultimodalPromptPrefixer: Send + Sync {
    /// Prefix for inclusion in messages. The default returns the prompt unchanged, for
    /// models whose chat template already handles images.
    fn prefix_image(&self, _image_indices: Vec<usize>, prompt: &str) -> String {
        prompt.to_owned()
    }
    /// Prefix for inclusion in messages. The default returns the prompt unchanged, for
    /// models whose chat template already handles audio.
    fn prefix_audio(&self, _audio_indexes: Vec<usize>, prompt: &str) -> String {
        prompt.to_owned()
    }
}
294
295pub enum CacheBackendMetadata {
296    DefaultInstructions {
297        pre_op: CacheInstruction,
298        post_op: CacheInstruction,
299    },
300    PagedAttention {
301        metadata: PagedAttentionMeta,
302        blocks_to_swap_in: HashMap<usize, usize>,
303        blocks_to_swap_out: HashMap<usize, usize>,
304        blocks_to_copy: HashMap<usize, Vec<usize>>,
305    },
306}
307
308#[derive(Clone, Debug)]
309pub enum ForwardInputsResult {
310    RawLogits {
311        logits: Tensor,
312    },
313    CausalGeneration {
314        logits: Tensor,
315    },
316    Image {
317        images: Vec<DynamicImage>,
318    },
319    Speech {
320        pcms: Vec<Arc<Vec<f32>>>,
321        rates: Vec<usize>,
322        channels: Vec<usize>,
323    },
324}
325
326impl ForwardInputsResult {
327    fn index_bs(&self, bs_idx: usize) -> candle_core::Result<Self> {
328        match self {
329            Self::CausalGeneration { logits } => Ok(Self::CausalGeneration {
330                logits: logits.i(bs_idx)?,
331            }),
332            Self::RawLogits { logits } => Ok(Self::RawLogits {
333                logits: logits.i(bs_idx)?,
334            }),
335            Self::Image { images } => Ok(Self::Image {
336                images: vec![images[bs_idx].clone()],
337            }),
338            Self::Speech {
339                pcms,
340                rates,
341                channels,
342            } => Ok(Self::Speech {
343                pcms: vec![pcms[bs_idx].clone()],
344                rates: vec![rates[bs_idx]],
345                channels: vec![channels[bs_idx]],
346            }),
347        }
348    }
349
350    fn to_device(&self, device: &Device) -> candle_core::Result<Self> {
351        match self {
352            Self::CausalGeneration { logits } => Ok(Self::CausalGeneration {
353                logits: logits.to_device(device)?,
354            }),
355            Self::RawLogits { logits } => Ok(Self::RawLogits {
356                logits: logits.to_device(device)?,
357            }),
358            Self::Image { .. } => Ok(self.clone()),
359            Self::Speech { .. } => Ok(self.clone()),
360        }
361    }
362}
363
364#[derive(serde::Serialize, serde::Deserialize)]
365pub(crate) struct FileListCache {
366    files: Vec<String>,
367}
368
369#[async_trait::async_trait]
370pub trait Pipeline:
371    Send
372    + Sync
373    + PreProcessingMixin
374    + IsqPipelineMixin
375    + CacheManagerMixin
376    + MetadataMixin
377    + AnyMoePipelineMixin
378{
379    fn forward_inputs(
380        &mut self,
381        inputs: Box<dyn Any>,
382        return_raw_logits: bool,
383    ) -> Result<ForwardInputsResult, candle_core::Error>;
384
385    /// Returns the total of model execution time.
386    #[allow(clippy::too_many_arguments)]
387    async fn step(
388        &mut self,
389        input_seqs: &mut [&mut Sequence],
390        is_prompt: bool,
391        return_raw_logits: bool,
392        prefix_cacher: &mut PrefixCacheManagerV2,
393        disable_eos_stop: bool,
394        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
395        backend_metadata: CacheBackendMetadata,
396    ) -> Result<Duration, candle_core::Error> {
397        match backend_metadata {
398            CacheBackendMetadata::DefaultInstructions { pre_op, post_op } => {
399                let inputs_iter = self.get_processor().inputs_processor().process_inputs(
400                    self.tokenizer(),
401                    input_seqs,
402                    is_prompt,
403                    self.get_metadata().is_xlora,
404                    &self.device(),
405                    self.get_metadata().no_kv_cache,
406                    None,
407                    return_raw_logits,
408                    self.get_input_processor_config(),
409                    None,
410                    self.get_metadata().prompt_chunksize,
411                    self.device_mapper(),
412                );
413
414                let mut logits = vec![None; input_seqs.len()];
415                let prompt_chunksize = self
416                    .get_metadata()
417                    .prompt_chunksize
418                    .map(NonZeroUsize::get)
419                    .unwrap_or(1);
420                let len_inputs = input_seqs
421                    .iter()
422                    .map(|seq| seq.get_toks().len().div_ceil(prompt_chunksize))
423                    .max()
424                    .unwrap();
425                let mut raw_out_logits = vec![vec![None; len_inputs]; input_seqs.len()];
426
427                let mut exec_duration = Duration::ZERO;
428                for (i, inputs) in inputs_iter.into_iter().enumerate() {
429                    let InputProcessorOutput {
430                        inputs,
431                        seq_indices,
432                    } = inputs.map_err(candle_core::Error::msg)?;
433                    if i == 0 {
434                        match pre_op {
435                            CacheInstruction::In => self.clone_in_cache(input_seqs),
436                            CacheInstruction::Nothing => (),
437                            CacheInstruction::Reset {
438                                load_preallocated_cache,
439                                reset_non_granular,
440                            } => self.set_none_cache(
441                                input_seqs,
442                                reset_non_granular,
443                                false,
444                                load_preallocated_cache,
445                            ),
446                            _ => unreachable!("Unreachable PRE cache op."),
447                        }
448                    }
449
450                    let start = Instant::now();
451                    let raw_logits = self.forward_inputs(inputs, return_raw_logits)?;
452                    let end = Instant::now();
453                    exec_duration += end.duration_since(start);
454
455                    for (logit_idx, seq_idx) in seq_indices.into_iter().enumerate() {
456                        if let ForwardInputsResult::RawLogits { logits } = &raw_logits {
457                            raw_out_logits[seq_idx][i] =
458                                Some(logits.i(logit_idx)?.to_device(&Device::Cpu)?);
459                        } else {
460                            logits[seq_idx] = Some(raw_logits.index_bs(logit_idx)?);
461                        }
462                    }
463                }
464
465                match post_op {
466                    CacheInstruction::Out => self.clone_out_cache(input_seqs),
467                    CacheInstruction::Nothing => (),
468                    CacheInstruction::Reset {
469                        load_preallocated_cache,
470                        reset_non_granular,
471                    } => self.set_none_cache(
472                        input_seqs,
473                        reset_non_granular,
474                        false,
475                        load_preallocated_cache,
476                    ),
477                    _ => unreachable!("Unreachable POST cache op."),
478                }
479
480                if raw_out_logits[0][0].is_some() {
481                    let start = Instant::now();
482                    response::send_raw_responses(
483                        input_seqs,
484                        raw_out_logits
485                            .into_iter()
486                            .map(|raw| raw.into_iter().flatten().collect::<Vec<_>>())
487                            .collect(),
488                    )
489                    .await?;
490                    let end = Instant::now();
491                    exec_duration += end.duration_since(start);
492
493                    return Ok(exec_duration);
494                }
495
496                let start = Instant::now();
497                let logits_on_cpu = logits.len() > 1;
498                let logits = logits
499                    .into_iter()
500                    .map(|l| {
501                        let l = l.expect("Did not get any inputs. This is shocking.");
502                        if logits_on_cpu {
503                            l.to_device(&Device::Cpu)
504                        } else {
505                            Ok(l)
506                        }
507                    })
508                    .collect::<candle_core::Result<Vec<_>>>()?;
509
510                match &logits[0] {
511                    ForwardInputsResult::RawLogits { .. } => unreachable!(),
512                    ForwardInputsResult::CausalGeneration { .. } => {
513                        self.sample_causal_gen(
514                            input_seqs,
515                            logits
516                                .into_iter()
517                                .map(|r| {
518                                    #[allow(irrefutable_let_patterns)]
519                                    let ForwardInputsResult::CausalGeneration { logits } = r
520                                    else {
521                                        unreachable!(
522                                            "All results must have same type, `CausalGeneration`"
523                                        )
524                                    };
525                                    logits
526                                })
527                                .collect::<Vec<_>>(),
528                            prefix_cacher,
529                            disable_eos_stop,
530                            rng,
531                        )
532                        .await?;
533                    }
534                    ForwardInputsResult::Image { .. } => {
535                        response::send_image_responses(
536                            input_seqs,
537                            logits
538                                .into_iter()
539                                .map(|r| {
540                                    #[allow(irrefutable_let_patterns)]
541                                    let ForwardInputsResult::Image { images } = r
542                                    else {
543                                        unreachable!("All results must have same type, `Image`")
544                                    };
545                                    images
546                                        .into_iter()
547                                        .next()
548                                        .expect("Must have at least 1 element.")
549                                })
550                                .collect::<Vec<_>>(),
551                        )
552                        .await?;
553                    }
554                    ForwardInputsResult::Speech { .. } => {
555                        let rates = logits
556                            .iter()
557                            .map(|r| {
558                                #[allow(irrefutable_let_patterns)]
559                                let ForwardInputsResult::Speech { rates, .. } = r
560                                else {
561                                    unreachable!("All results must have same type, `Speech`")
562                                };
563                                assert_eq!(rates.len(), 1, "Each sequence must have 1 PCM output.");
564                                *rates.first().unwrap()
565                            })
566                            .collect::<Vec<_>>();
567                        let channels = logits
568                            .iter()
569                            .map(|r| {
570                                #[allow(irrefutable_let_patterns)]
571                                let ForwardInputsResult::Speech { channels, .. } = r
572                                else {
573                                    unreachable!("All results must have same type, `Speech`")
574                                };
575                                assert_eq!(
576                                    channels.len(),
577                                    1,
578                                    "Each sequence must have 1 PCM output."
579                                );
580                                *channels.first().unwrap()
581                            })
582                            .collect::<Vec<_>>();
583                        let pcms = logits
584                            .into_iter()
585                            .map(|r| {
586                                #[allow(irrefutable_let_patterns)]
587                                let ForwardInputsResult::Speech { pcms, .. } = r
588                                else {
589                                    unreachable!("All results must have same type, `Speech`")
590                                };
591                                assert_eq!(pcms.len(), 1, "Each sequence must have 1 PCM output.");
592                                pcms.into_iter().nth(0).unwrap()
593                            })
594                            .collect::<Vec<_>>();
595                        response::send_speech_responses(input_seqs, &pcms, &rates, &channels)
596                            .await?;
597                    }
598                }
599                let end = Instant::now();
600                exec_duration += end.duration_since(start);
601
602                Ok(exec_duration)
603            }
604            CacheBackendMetadata::PagedAttention {
605                metadata,
606                blocks_to_copy,
607                blocks_to_swap_in,
608                blocks_to_swap_out,
609            } => {
610                // Cloning might be bad?
611                self.get_metadata()
612                    .cache_engine
613                    .as_ref()
614                    .expect("PagedAttention must have cache engines.")
615                    .execute_scheduler_ops(
616                        &blocks_to_swap_in,
617                        &blocks_to_swap_out,
618                        &blocks_to_copy,
619                    )?;
620
621                let inputs_iter = self.get_processor().inputs_processor().process_inputs(
622                    self.tokenizer(),
623                    input_seqs,
624                    is_prompt,
625                    self.get_metadata().is_xlora,
626                    &self.device(),
627                    self.get_metadata().no_kv_cache,
628                    None,
629                    return_raw_logits,
630                    self.get_input_processor_config(),
631                    Some(metadata),
632                    self.get_metadata().prompt_chunksize,
633                    self.device_mapper(),
634                );
635
636                let mut logits = vec![None; input_seqs.len()];
637                let prompt_chunksize = self
638                    .get_metadata()
639                    .prompt_chunksize
640                    .map(NonZeroUsize::get)
641                    .unwrap_or(1);
642                let len_inputs = input_seqs
643                    .iter()
644                    .map(|seq| seq.get_toks().len().div_ceil(prompt_chunksize))
645                    .max()
646                    .unwrap();
647                let mut raw_out_logits = vec![vec![None; len_inputs]; input_seqs.len()];
648
649                let mut exec_duration = Duration::ZERO;
650                for (i, inputs) in inputs_iter.into_iter().enumerate() {
651                    let InputProcessorOutput {
652                        inputs,
653                        seq_indices,
654                    } = inputs.map_err(candle_core::Error::msg)?;
655
656                    let start = Instant::now();
657                    let raw_logits = self.forward_inputs(inputs, return_raw_logits)?;
658                    let end = Instant::now();
659                    exec_duration += end.duration_since(start);
660
661                    for (logit_idx, seq_idx) in seq_indices.into_iter().enumerate() {
662                        if let ForwardInputsResult::RawLogits { logits } = &raw_logits {
663                            raw_out_logits[seq_idx][i] =
664                                Some(logits.i(logit_idx)?.to_device(&Device::Cpu)?);
665                        } else {
666                            logits[seq_idx] = Some(raw_logits.index_bs(logit_idx)?);
667                        }
668                    }
669                }
670
671                if raw_out_logits[0][0].is_some() {
672                    let start = Instant::now();
673                    response::send_raw_responses(
674                        input_seqs,
675                        raw_out_logits
676                            .into_iter()
677                            .map(|raw| raw.into_iter().flatten().collect::<Vec<_>>())
678                            .collect(),
679                    )
680                    .await?;
681                    let end = Instant::now();
682                    exec_duration += end.duration_since(start);
683
684                    return Ok(exec_duration);
685                }
686
687                let start = Instant::now();
688                let logits_on_cpu = logits.len() > 1;
689                let logits = logits
690                    .into_iter()
691                    .map(|l| {
692                        let l = l.expect("Did not get any inputs. This is shocking.");
693                        if logits_on_cpu {
694                            l.to_device(&Device::Cpu)
695                        } else {
696                            Ok(l)
697                        }
698                    })
699                    .collect::<candle_core::Result<Vec<_>>>()?;
700
701                match &logits[0] {
702                    ForwardInputsResult::RawLogits { .. } => unreachable!(),
703                    ForwardInputsResult::CausalGeneration { .. } => {
704                        self.sample_causal_gen(
705                            input_seqs,
706                            logits
707                                .into_iter()
708                                .map(|r| {
709                                    #[allow(irrefutable_let_patterns)]
710                                    let ForwardInputsResult::CausalGeneration { logits } = r
711                                    else {
712                                        unreachable!("All results must have same type")
713                                    };
714                                    logits
715                                })
716                                .collect::<Vec<_>>(),
717                            prefix_cacher,
718                            disable_eos_stop,
719                            rng,
720                        )
721                        .await?;
722                    }
723                    ForwardInputsResult::Image { .. } => {
724                        response::send_image_responses(
725                            input_seqs,
726                            logits
727                                .into_iter()
728                                .map(|r| {
729                                    #[allow(irrefutable_let_patterns)]
730                                    let ForwardInputsResult::Image { images } = r
731                                    else {
732                                        unreachable!("All results must have same type, `Image`")
733                                    };
734                                    images
735                                        .into_iter()
736                                        .next()
737                                        .expect("Must have at least 1 element.")
738                                })
739                                .collect::<Vec<_>>(),
740                        )
741                        .await?;
742                    }
743                    ForwardInputsResult::Speech { .. } => {
744                        let rates = logits
745                            .iter()
746                            .map(|r| {
747                                #[allow(irrefutable_let_patterns)]
748                                let ForwardInputsResult::Speech { rates, .. } = r
749                                else {
750                                    unreachable!("All results must have same type, `Speech`")
751                                };
752                                assert_eq!(rates.len(), 1, "Each sequence must have 1 PCM output.");
753                                *rates.first().unwrap()
754                            })
755                            .collect::<Vec<_>>();
756                        let channels = logits
757                            .iter()
758                            .map(|r| {
759                                #[allow(irrefutable_let_patterns)]
760                                let ForwardInputsResult::Speech { channels, .. } = r
761                                else {
762                                    unreachable!("All results must have same type, `Speech`")
763                                };
764                                assert_eq!(
765                                    channels.len(),
766                                    1,
767                                    "Each sequence must have 1 PCM output."
768                                );
769                                *channels.first().unwrap()
770                            })
771                            .collect::<Vec<_>>();
772                        let pcms = logits
773                            .into_iter()
774                            .map(|r| {
775                                #[allow(irrefutable_let_patterns)]
776                                let ForwardInputsResult::Speech { pcms, .. } = r
777                                else {
778                                    unreachable!("All results must have same type, `Speech`")
779                                };
780                                assert_eq!(pcms.len(), 1, "Each sequence must have 1 PCM output.");
781                                pcms.into_iter().nth(0).unwrap()
782                            })
783                            .collect::<Vec<_>>();
784                        response::send_speech_responses(input_seqs, &pcms, &rates, &channels)
785                            .await?;
786                    }
787                }
788                let end = Instant::now();
789                exec_duration += end.duration_since(start);
790
791                Ok(exec_duration)
792            }
793        }
794    }
795
    /// Sampling hook for causal (text) generation, implemented by each pipeline.
    ///
    /// * `seqs` - the sequences being stepped; passed mutably so implementors can update them.
    /// * `logits` - model outputs to sample from. NOTE(review): presumably one tensor per
    ///   entry of `seqs`, in the same order — confirm against the caller.
    /// * `prefix_cacher` - the prefix cache manager, passed mutably to the implementor.
    /// * `disable_eos_stop` - presumably suppresses stopping on EOS tokens — confirm in the
    ///   implementations.
    /// * `rng` - RNG shared behind an `Arc<Mutex<_>>`.
    ///
    /// # Errors
    /// Returns a `candle_core::Error` if sampling fails.
    async fn sample_causal_gen(
        &self,
        seqs: &mut [&mut Sequence],
        logits: Vec<Tensor>,
        prefix_cacher: &mut PrefixCacheManagerV2,
        disable_eos_stop: bool,
        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
    ) -> Result<(), candle_core::Error>;
804
    /// Returns the [`ModelCategory`] of this pipeline.
    fn category(&self) -> ModelCategory;
806}
807
808pub(crate) fn extract_logits(
809    logits: &Tensor,
810    context_lens: Vec<(usize, usize)>,
811) -> candle_core::Result<Tensor> {
812    let mut toks = Vec::new();
813    for (dim, (start, len)) in logits.chunk(logits.dims()[0], 0)?.iter().zip(context_lens) {
814        toks.push(dim.narrow(1, start, len)?);
815    }
816    Tensor::cat(&toks, 0)
817}
818
#[cfg(test)]
mod tests {
    use crate::MessageContent;
    use either::Either;
    use indexmap::IndexMap;
    use serde_json::Value;

    /// Builds an insertion-ordered `IndexMap` from `key => value` pairs, wrapping every
    /// value in `Value::String`.
    macro_rules! hashmap {
        (@single $($x:tt)*) => (());
        (@count $($rest:expr),*) => (<[()]>::len(&[$(hashmap!(@single $rest)),*]));

        ($($key:expr => $value:expr,)+) => { hashmap!($($key => $value),+) };
        ($($key:expr => $value:expr),*) => {
            {
                let _cap = hashmap!(@count $($key),*);
                let mut _map = ::indexmap::IndexMap::with_capacity(_cap);
                $(
                    let _ = _map.insert($key, Value::String($value));
                )*
                _map
            }
        };
    }

    /// Builds one chat message, inserting the `role` entry followed by the `content` entry.
    fn message(role: &str, content: MessageContent) -> IndexMap<String, MessageContent> {
        let mut msg = IndexMap::new();
        msg.insert("role".to_string(), Either::Left(role.to_string()));
        msg.insert("content".to_string(), content);
        msg
    }

    /// A `{"type": "text", "text": text}` content part.
    fn text_part(text: &str) -> IndexMap<String, Value> {
        hashmap! {
            "type".to_string() => "text".to_string(),
            "text".to_string() => text.to_string()
        }
    }

    /// A `{"type": "image"}` content part.
    fn image_part() -> IndexMap<String, Value> {
        hashmap! {
            "type".to_string() => "image".to_string()
        }
    }

    /// Renders `inputs` through every template and compares each result with the
    /// corresponding entry of `expected_outputs`.
    ///
    /// Each template tuple is `(has_system, bos, eos, unk, template)`. When `has_system` is
    /// false, the first message of `inputs` (the system message, in every caller) is dropped
    /// before rendering.
    ///
    /// # Panics
    /// Panics if the two slices differ in length, or after reporting every template whose
    /// rendering failed or did not match.
    #[track_caller]
    fn test_with_inputs(
        templates: &[(bool, &str, &str, &str, &str)],
        expected_outputs: &[&str],
        inputs: Vec<IndexMap<String, MessageContent>>,
    ) {
        use crate::pipeline::chat_template::ChatTemplateValue;

        use super::chat_template::apply_chat_template_to;

        // A length mismatch would make `zip` silently skip templates.
        assert_eq!(
            templates.len(),
            expected_outputs.len(),
            "every chat template needs exactly one expected output"
        );
        let n_templates = templates.len();
        // (template index, failure description)
        let mut failed: Vec<(usize, String)> = Vec::new();
        for (idx, ((has_system, bos, eos, unk, template), expected)) in
            templates.iter().zip(expected_outputs).enumerate()
        {
            let messages = if *has_system {
                inputs.clone()
            } else {
                inputs[1..].to_vec()
            };
            let output = match apply_chat_template_to(
                messages,
                true,
                None,
                &ChatTemplateValue(Either::Left(template.to_string())),
                Some(bos.to_string()),
                Some(eos.to_string()),
                Some(unk.to_string()),
                Vec::new(),
            ) {
                Ok(v) => v,
                Err(e) => {
                    failed.push((idx, format!("Failed with {e}.")));
                    continue;
                }
            };
            if output != *expected {
                failed.push((
                    idx,
                    format!(
                        "Expected: `{}` \n\nGot:      `{}`",
                        expected.replace('\n', "\\n"),
                        output.replace('\n', "\\n")
                    ),
                ));
            }
        }
        if !failed.is_empty() {
            for (idx, line) in &failed {
                println!("------------ Template {idx} ------------");
                println!("{line}");
            }
            println!("------------------------");
            panic!("{}/{n_templates} chat templates failed.", failed.len());
        }
    }

    /// Generating these cases:
    /// ```py
    /// >>> t=transformers.AutoTokenizer.from_pretrained(...)
    /// # If non-system prompt model
    /// >>> t.apply_chat_template([{"role":"user","content":"Hello"},{"role":"assistant","content":"Hi there"},{"role":"user","content":"Who are you"},{"role":"assistant","content":"   I am an assistant   "},{"role":"user","content":"Another question"}], add_generation_prompt=True, tokenize=False)
    /// # If system prompt model
    /// >>> t.apply_chat_template([{"role":"system","content":"You are a helpful assistant"},{"role":"user","content":"Hello"},{"role":"assistant","content":"Hi there"},{"role":"user","content":"Who are you"},{"role":"assistant","content":"   I am an assistant   "},{"role":"user","content":"Another question"}], add_generation_prompt=True, tokenize=False)
    /// ```
    ///
    /// Templates that require list-of-parts content (e.g. HuggingFaceM4/idefics2-8b-chatty)
    /// are exercised by `test_image_chat_templates` instead.
    #[test]
    fn test_chat_templates() {
        let templates = [
            // ChatML: https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B
            (true, "<s>", "</s>", "<unk>", "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"),
            // mistralai/Mistral-7B-Instruct-v0.1
            (false, "<s>", "</s>", "<unk>", "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"),
            // meta-llama/Llama-2-13b-chat-hf
            (true, "<s>", "</s>", "<unk>", "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' '  + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}"),
            // mistralai/Mixtral-8x7B-Instruct-v0.1
            (false, "<s>", "</s>", "<unk>", "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"),
            // google/gemma-7b-it
            (false, "<bos>", "<eos>", "<unk>", "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}"),
        ];
        let expected_outputs = [
            // ChatML: https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B
            "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi there<|im_end|>\n<|im_start|>user\nWho are you<|im_end|>\n<|im_start|>assistant\n   I am an assistant   <|im_end|>\n<|im_start|>user\nAnother question<|im_end|>\n<|im_start|>assistant\n",
            // mistralai/Mistral-7B-Instruct-v0.1
            "<s>[INST] Hello [/INST]Hi there</s> [INST] Who are you [/INST]   I am an assistant   </s> [INST] Another question [/INST]",
            // meta-llama/Llama-2-13b-chat-hf
            "<s>[INST] <<SYS>>\nYou are a helpful assistant\n<</SYS>>\n\nHello [/INST] Hi there </s><s>[INST] Who are you [/INST] I am an assistant </s><s>[INST] Another question [/INST]",
            // mistralai/Mixtral-8x7B-Instruct-v0.1
            "<s>[INST] Hello [/INST]Hi there</s>[INST] Who are you [/INST]   I am an assistant   </s>[INST] Another question [/INST]",
            // google/gemma-7b-it
            "<bos><start_of_turn>user\nHello<end_of_turn>\n<start_of_turn>model\nHi there<end_of_turn>\n<start_of_turn>user\nWho are you<end_of_turn>\n<start_of_turn>model\nI am an assistant<end_of_turn>\n<start_of_turn>user\nAnother question<end_of_turn>\n<start_of_turn>model\n",
        ];
        let messages = [
            ["system", "You are a helpful assistant"],
            ["user", "Hello"],
            ["assistant", "Hi there"],
            ["user", "Who are you"],
            ["assistant", "   I am an assistant   "],
            ["user", "Another question"],
        ];
        let inputs = messages
            .into_iter()
            .map(|[role, content]| message(role, Either::Left(content.to_string())))
            .collect::<Vec<_>>();
        test_with_inputs(&templates, &expected_outputs, inputs);
    }

    /// Generating these cases:
    /// ```py
    /// >>> processor=transformers.AutoProcessor.from_pretrained(...)
    /// >>> processor.apply_chat_template([
    ///         {"role":"system","content":[{"type":"text", "text": "You are a helpful assistant"}]},
    ///         {"role":"user","content":[{"type":"image"}, {"type":"text", "text": "Hello, please describe the above."}]},
    ///         {"role":"assistant","content":[{"type":"text", "text": "Hi there"}]},
    ///         {"role":"user","content":[{"type":"image"}, {"type":"text", "text": "This is me, who are you"}]},
    ///         {"role":"assistant","content":[{"type":"text", "text": "   I am an assistant   "}]},
    ///         {"role":"user","content":[{"type":"image"}, {"type":"text", "text": "Another question, what is this?"}]}
    ///     ], add_generation_prompt=True, tokenize=False)
    /// ```
    #[test]
    fn test_image_chat_templates() {
        let templates = [
            // HuggingFaceM4/idefics2-8b-chatty
            (true, "<s>", "</s>", "<unk>", "{% for message in messages %}{{message['role'].capitalize()}}{% if message['content'][0]['type'] == 'image' %}{{':'}}{% else %}{{': '}}{% endif %}{% for line in message['content'] %}{% if line['type'] == 'text' %}{{line['text']}}{% elif line['type'] == 'image' %}{{ '<image>' }}{% endif %}{% endfor %}<end_of_utterance>\n{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"),
        ];
        let expected_outputs = [
            // HuggingFaceM4/idefics2-8b-chatty
            "System: You are a helpful assistant<end_of_utterance>\nUser:<image>Hello, please describe the above.<end_of_utterance>\nAssistant: Hi there<end_of_utterance>\nUser:<image>This is me, who are you<end_of_utterance>\nAssistant:    I am an assistant   <end_of_utterance>\nUser:<image>Another question, what is this?<end_of_utterance>\nAssistant:",
        ];

        let inputs = vec![
            message(
                "system",
                Either::Right(vec![text_part("You are a helpful assistant")]),
            ),
            message(
                "user",
                Either::Right(vec![
                    image_part(),
                    text_part("Hello, please describe the above."),
                ]),
            ),
            message("assistant", Either::Right(vec![text_part("Hi there")])),
            message(
                "user",
                Either::Right(vec![image_part(), text_part("This is me, who are you")]),
            ),
            message(
                "assistant",
                Either::Right(vec![text_part("   I am an assistant   ")]),
            ),
            message(
                "user",
                Either::Right(vec![
                    image_part(),
                    text_part("Another question, what is this?"),
                ]),
            ),
        ];

        test_with_inputs(&templates, &expected_outputs, inputs);
    }
}