mistralrs_core/pipeline/
mod.rs

1mod amoe;
2mod auto;
3pub mod chat_template;
4mod diffusion;
5mod embedding;
6mod ggml;
7mod gguf;
8mod inputs_processor;
9mod isq;
10pub(crate) mod llg;
11mod loaders;
12mod macros;
13mod normal;
14mod paths;
15mod processing;
16mod response;
17mod sampling;
18mod speculative;
19mod speech;
20mod vision;
21
22pub use super::diffusion_models::DiffusionGenerationParams;
23use crate::amoe::{AnyMoeConfig, AnyMoeExpertType, AnyMoeTrainingInputs, AnyMoeTrainingResult};
24use crate::device_map::DeviceMapper;
25use crate::paged_attention::{CacheConfig, CacheEngine, ModelConfigLike};
26use crate::prefix_cacher::PrefixCacheManagerV2;
27pub use amoe::{AnyMoeLoader, AnyMoePipeline};
28pub use auto::{AutoLoader, AutoLoaderBuilder};
29use chat_template::ChatTemplate;
30pub use diffusion::{DiffusionLoader, DiffusionLoaderBuilder};
31pub use embedding::{EmbeddingLoader, EmbeddingLoaderBuilder, EmbeddingSpecificConfig};
32pub use ggml::{GGMLLoader, GGMLLoaderBuilder, GGMLSpecificConfig};
33pub use gguf::{GGUFLoader, GGUFLoaderBuilder, GGUFSpecificConfig};
34use image::DynamicImage;
35pub use inputs_processor::InputProcessorOutput;
36pub(crate) use isq::IsqModelLoader;
37pub use isq::{parse_isq_value, IsqModel, IsqOrganization, UQFF_MULTI_FILE_DELIMITER};
38use llguidance::toktrie::TokEnv;
39pub use loaders::{
40    AdapterKind, AutoDeviceMapParams, AutoEmbeddingLoader, AutoNormalLoader, AutoVisionLoader,
41    DeepSeekV2Loader, DeepSeekV3Loader, DeviceMappedModelLoader, DiffusionLoaderType,
42    DiffusionModel, DiffusionModelLoader, EmbeddingGemmaLoader, EmbeddingLoaderType,
43    EmbeddingModel, EmbeddingModelLoader, EmbeddingModelPaths, EmbeddingModule,
44    EmbeddingModulePaths, EmbeddingModuleType, FluxLoader, GLM4Loader, Gemma2Loader, Gemma3Loader,
45    Gemma3nLoader, GemmaLoader, Idefics2Loader, Idefics3Loader, LLaVALoader, LLaVANextLoader,
46    LlamaLoader, Loader, LocalModelPaths, MiniCpmOLoader, Mistral3Loader, MistralLoader,
47    MixtralLoader, ModelKind, ModelPaths, NormalLoaderType, NormalLoadingMetadata, NormalModel,
48    NormalModelLoader, Phi2Loader, Phi3Loader, Phi3VLoader, Phi3_5MoELoader, Phi4MMLoader,
49    PrettyName, QuantizationKind, Qwen2Loader, Qwen2VLLoader, Qwen2_5VLLoader,
50    Qwen3EmbeddingLoader, Qwen3Loader, Qwen3MoELoader, Qwen3VLLoader, SmolLm3Loader,
51    Starcoder2Loader, TokenSource, VLlama4Loader, VLlamaLoader, VisionLoaderType, VisionModel,
52    VisionModelLoader,
53};
54use mistralrs_quant::IsqType;
55pub use normal::{NormalLoader, NormalLoaderBuilder, NormalSpecificConfig};
56pub(crate) use paths::{get_chat_template, get_model_paths, get_xlora_paths};
57pub use paths::{AdapterPaths, LoraAdapterPaths};
58pub(crate) use processing::{
59    apply_chat_template, BasicProcessor, MessagesAction, Processor, ProcessorCreator,
60};
61use rand_isaac::Isaac64Rng;
62pub use speculative::{SpeculativeConfig, SpeculativeLoader, SpeculativePipeline};
63pub use speech::{SpeechLoader, SpeechPipeline};
64use std::any::Any;
65use std::collections::HashMap;
66use std::fmt::Debug;
67use std::sync::Arc;
68use std::time::{Duration, Instant};
69use tokenizers::Tokenizer;
70pub use vision::{VisionLoader, VisionLoaderBuilder, VisionSpecificConfig};
71
72use anyhow::Result;
73use candle_core::{DType, Device, IndexOp, Tensor, Var};
74
75use crate::sequence::Sequence;
76
77pub use self::inputs_processor::{
78    text_models_inputs_processor, InputsProcessor, InputsProcessorType,
79};
80use self::text_models_inputs_processor::PagedAttentionMeta;
81pub use crate::kv_cache::{
82    Cache, CacheManager, EitherCache, KvCache, LayerCaches, NormalCache, NormalCacheType,
83};
84
/// A kind of data that can appear in a model's input or output modality list
/// (see [`Modalities`]).
///
/// `Debug` is implemented by hand so that each variant prints as an emoji followed by
/// its name rather than the bare variant identifier.
#[derive(Clone, PartialEq, Eq)]
pub enum SupportedModality {
    Text,
    Audio,
    Vision,
    Embedding,
}

impl Debug for SupportedModality {
    /// Writes the emoji-prefixed label of this modality to `f`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Pick the label first, then emit it once.
        let label = match self {
            Self::Text => "📝 Text",
            Self::Audio => "🔊 Audio",
            Self::Vision => "🖼️ Vision",
            Self::Embedding => "🔢 Embedding",
        };
        f.write_str(label)
    }
}
103
/// The input and output modality lists attached to a model's [`GeneralMetadata`].
#[derive(Debug, Clone)]
pub struct Modalities {
    /// Modalities listed on the input side.
    pub input: Vec<SupportedModality>,
    /// Modalities listed on the output side.
    pub output: Vec<SupportedModality>,
}
109
/// Model-level metadata, handed out behind an `Arc` by [`MetadataMixin::get_metadata`].
pub struct GeneralMetadata {
    // NOTE(review): presumably the maximum sequence length in tokens — confirm with loaders.
    pub max_seq_len: usize,
    /// Only None if it doesn't make sense for the model
    pub llg_factory: Option<Arc<llguidance::ParserFactory>>,
    /// Forwarded to the inputs processor by `Pipeline::step`.
    pub no_kv_cache: bool,
    pub no_prefix_cache: bool,
    pub num_hidden_layers: usize,
    // NOTE(review): presumably the end-of-sequence token ids — confirm against the sampler.
    pub eos_tok: Vec<u32>,
    pub kind: ModelKind,
    // TODO: Replace is_xlora queries to check via kind instead:
    /// Forwarded to the inputs processor by `Pipeline::step`.
    pub is_xlora: bool,
    pub activation_dtype: DType,
    pub sliding_window: Option<usize>,
    // PagedAttention stuff
    pub cache_config: Option<CacheConfig>,
    /// `Pipeline::step` expects this to be `Some` on the PagedAttention path and panics
    /// otherwise.
    pub cache_engine: Option<CacheEngine>,
    pub model_metadata: Option<Arc<dyn ModelConfigLike + Send + Sync>>,
    /// Input and output modalities recorded for the model.
    pub modalities: Modalities,
}
129
130impl GeneralMetadata {
131    pub fn tok_env(&self) -> Option<TokEnv> {
132        self.llg_factory.as_ref().map(|f| f.tok_env().clone())
133    }
134}
135
/// Cache operation carried by [`CacheBackendMetadata::DefaultInstructions`] and applied by
/// `Pipeline::step` before (`pre_op`) or after (`post_op`) the forward pass.
pub enum CacheInstruction {
    /// As a pre-op, `step` calls `clone_in_cache`. Not accepted as a post-op.
    In,
    /// As a post-op, `step` calls `clone_out_cache`. Not accepted as a pre-op.
    Out,
    /// load_preallocated_cache means to load the preallocated cache, if applicable.
    ///
    /// `step` handles this by calling `set_none_cache` with both flags (and
    /// `modify_draft_cache = false`).
    Reset {
        load_preallocated_cache: bool,
        reset_non_granular: bool,
    },
    /// No cache operation is performed.
    Nothing,
}
146
/// Pre-processing hooks of a pipeline: the message/input processor, the chat template and
/// an opaque configuration object for the input processor.
pub trait PreProcessingMixin: MetadataMixin {
    /// Returns the processor to use; the default is a fresh [`BasicProcessor`].
    fn get_processor(&self) -> Arc<dyn Processor> {
        Arc::new(BasicProcessor)
    }
    /// Only None if it doesn't make sense for the model
    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>>;
    /// Type-erased configuration that `Pipeline::step` forwards to the inputs processor.
    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>>;
}
155
/// Hook for re-applying ISQ to an already loaded model.
pub trait IsqPipelineMixin {
    /// Re-runs ISQ on the model with the given `dtype`.
    // NOTE(review): presumably re-quantizes the weights in place — confirm with implementors.
    fn re_isq_model(&mut self, dtype: IsqType) -> Result<()>;
}
159
/// Operations for moving KV cache state between the model and the sequences it serves.
pub trait CacheManagerMixin {
    /// Clone the cache FROM the sequences' cache TO the model cache. Only called for completion seqs.
    /// It is not a guarantee that this will be called for each completion step.
    fn clone_in_cache(&self, seqs: &mut [&mut Sequence]);
    /// Clone the cache FROM the model cache TO the sequences. Called for prompt and completion seqs.
    /// It is not a guarantee that this will be called for each step.
    fn clone_out_cache(&self, seqs: &mut [&mut Sequence]);
    /// Set the model cache to all None. Only called for prompt seqs.
    /// It is not a guarantee that this will be called for each prompt step.
    /// This may also reset the non granular state if applicable.
    fn set_none_cache(
        &self,
        seqs: &mut [&mut Sequence],
        reset_non_granular: bool,
        modify_draft_cache: bool,
        load_preallocated_cache: bool,
    );
    /// Returns the model's cache.
    fn cache(&self) -> &EitherCache;
    /// Returns `true` exactly when [`Self::cache`] is the `EitherCache::Normal` variant.
    fn do_preallocated_cache(&self) -> bool {
        matches!(self.cache(), EitherCache::Normal(_))
    }
}
182
/// Accessors for a pipeline's device, tokenizer, name and shared metadata.
pub trait MetadataMixin {
    /// Device handed to the inputs processor by `Pipeline::step`.
    fn device(&self) -> Device;
    /// Only None if it doesn't make sense for the model
    fn tokenizer(&self) -> Option<Arc<Tokenizer>>;
    fn name(&self) -> String;
    fn reset_non_granular_state(&self);
    /// Shared handle to the model's [`GeneralMetadata`].
    fn get_metadata(&self) -> Arc<GeneralMetadata>;
    /// Device mapper, if any; forwarded to the inputs processor by `Pipeline::step`.
    fn device_mapper(&self) -> Option<&dyn DeviceMapper>;
}
192
/// Implemented by the base model of an AnyMoe.
///
/// Every default method except [`Self::amoe_supported`] panics via `unreachable!()`, and
/// `amoe_supported` defaults to `false`; an implementor that reports support is expected
/// to override the remaining methods.
pub trait AnyMoePipelineMixin {
    /// Get vars for each gating layer
    fn amoe_layer_vars(&self) -> Vec<Vec<Var>> {
        unreachable!()
    }
    /// Finish training; the default panics.
    fn amoe_finish_training(&mut self, _gate_model_id: Option<String>) -> candle_core::Result<()> {
        unreachable!()
    }
    /// Number of trainable parameters of the base model; the default panics.
    fn amoe_base_model_trainable_params(&self) -> usize {
        unreachable!()
    }
    /// Whether this pipeline supports AnyMoE; defaults to `false`.
    fn amoe_supported(&self) -> bool {
        false
    }
    /// Per-layer cached outputs.
    fn amoe_take_cached_gating_outputs(&mut self) -> Vec<Tensor> {
        unreachable!()
    }
    /// Inject the MoE layers
    #[allow(clippy::too_many_arguments)]
    fn amoe_create_layers(
        &mut self,
        _model_ids: Vec<String>,
        _token: &TokenSource,
        _revision: Option<String>,
        _match_regex: &str,
        _config: AnyMoeConfig,
        _dtype: DType,
        _dev: &Device,
        (_prefix, _mlp): (String, String),
        _layers: Vec<usize>,
        _expert_type: AnyMoeExpertType,
        _silent: bool,
        _gate_model_id: Option<String>,
    ) -> candle_core::Result<()> {
        unreachable!()
    }
    /// Pre-train the gating layers
    #[allow(clippy::too_many_arguments)]
    fn amoe_pre_train(
        &self,
        _inputs: AnyMoeTrainingInputs,
        (_prefix, _mlp): (String, String),
        _model_ids: Vec<String>,
        _token: TokenSource,
        _revision: Option<String>,
        _layers: Vec<usize>,
        _silent: bool,
    ) -> Result<Option<AnyMoeTrainingResult>, candle_core::Error> {
        unreachable!()
    }
}
246
/// Category of the model. This can also be used to extract model-category specific tools,
/// such as the vision model prompt prefixer.
///
/// `Debug` and `PartialEq` are implemented by hand; both ignore the `prefixer` payload of
/// [`ModelCategory::Vision`].
#[derive(Clone)]
pub enum ModelCategory {
    Text,
    Vision {
        /// Prompt prefixer carried alongside the vision category.
        prefixer: Arc<dyn MultimodalPromptPrefixer>,
    },
    Diffusion,
    Audio,
    Speech,
    Embedding,
}
260
261impl std::fmt::Debug for ModelCategory {
262    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
263        match self {
264            ModelCategory::Text => write!(f, "ModelCategory::Text"),
265            ModelCategory::Vision { .. } => write!(f, "ModelCategory::Vision {{ prefixer: .. }}"),
266            ModelCategory::Diffusion => write!(f, "ModelCategory::Diffusion"),
267            ModelCategory::Audio => write!(f, "ModelCategory::Audio"),
268            ModelCategory::Speech => write!(f, "ModelCategory::Speech"),
269            ModelCategory::Embedding => write!(f, "ModelCategory::Embedding"),
270        }
271    }
272}
273
274impl PartialEq for ModelCategory {
275    fn eq(&self, other: &Self) -> bool {
276        match (self, other) {
277            (Self::Text, Self::Text) => true,
278            (Self::Vision { .. }, Self::Vision { .. }) => true,
279            (Self::Audio, Self::Audio) => true,
280            (Self::Speech, Self::Speech) => true,
281            (Self::Diffusion, Self::Diffusion) => true,
282            (Self::Embedding, Self::Embedding) => true,
283            (
284                Self::Text
285                | Self::Vision { .. }
286                | Self::Diffusion
287                | Self::Audio
288                | Self::Speech
289                | Self::Embedding,
290                _,
291            ) => false,
292        }
293    }
294}
295
/// Prepends the media tags appropriate for a model to a prompt. Image indexing is assumed
/// to start at 0.
///
/// Both methods default to returning the prompt unchanged.
pub trait MultimodalPromptPrefixer: Send + Sync {
    /// Image prefix for inclusion in messages (may do nothing if the chat template handles it).
    fn prefix_image(&self, _image_indices: Vec<usize>, prompt: &str) -> String {
        prompt.to_owned()
    }
    /// Audio prefix for inclusion in messages (may do nothing if the chat template handles it).
    fn prefix_audio(&self, _audio_indexes: Vec<usize>, prompt: &str) -> String {
        prompt.to_owned()
    }
}
307
/// Selects which cache path `Pipeline::step` takes and carries the data that path needs.
pub enum CacheBackendMetadata {
    /// Default cache handling: `pre_op` is applied before the first forward pass and
    /// `post_op` after the forward passes have finished.
    DefaultInstructions {
        pre_op: CacheInstruction,
        post_op: CacheInstruction,
    },
    /// PagedAttention path.
    PagedAttention {
        /// Forwarded to the inputs processor.
        metadata: PagedAttentionMeta,
        /// Passed to the cache engine's `execute_scheduler_ops` before inputs are processed.
        // NOTE(review): presumably maps a source block to its destination blocks — confirm.
        blocks_to_copy: HashMap<usize, Vec<usize>>,
    },
}
318
/// Output of `Pipeline::forward_inputs`, either for a whole batch or (after `index_bs`)
/// for a single batch entry.
#[derive(Clone, Debug)]
pub enum ForwardInputsResult {
    /// Raw logits; `Pipeline::step` slices these per sequence, moves them to the CPU and
    /// sends them as raw responses.
    RawLogits {
        logits: Tensor,
    },
    /// Embeddings; `Pipeline::step` slices these per sequence, converts them to `f32` and
    /// sends them as embedding responses.
    Embeddings {
        embeddings: Tensor,
    },
    /// Logits that `Pipeline::step` hands to `sample_causal_gen`.
    CausalGeneration {
        logits: Tensor,
    },
    /// Generated images; after `index_bs` each result holds exactly one image.
    Image {
        images: Vec<DynamicImage>,
    },
    /// Generated audio; the three vectors are indexed together by batch entry, and
    /// `Pipeline::step` asserts each per-sequence result holds exactly one element of each.
    Speech {
        pcms: Vec<Arc<Vec<f32>>>,
        // NOTE(review): presumably the sample rate of each PCM buffer — confirm.
        rates: Vec<usize>,
        // NOTE(review): presumably the channel count of each PCM buffer — confirm.
        channels: Vec<usize>,
    },
}
339
340impl ForwardInputsResult {
341    fn index_bs(&self, bs_idx: usize) -> candle_core::Result<Self> {
342        match self {
343            Self::CausalGeneration { logits } => Ok(Self::CausalGeneration {
344                logits: logits.i(bs_idx)?,
345            }),
346            Self::Embeddings { embeddings } => Ok(Self::Embeddings {
347                embeddings: embeddings.i(bs_idx)?,
348            }),
349            Self::RawLogits { logits } => Ok(Self::RawLogits {
350                logits: logits.i(bs_idx)?,
351            }),
352            Self::Image { images } => Ok(Self::Image {
353                images: vec![images[bs_idx].clone()],
354            }),
355            Self::Speech {
356                pcms,
357                rates,
358                channels,
359            } => Ok(Self::Speech {
360                pcms: vec![pcms[bs_idx].clone()],
361                rates: vec![rates[bs_idx]],
362                channels: vec![channels[bs_idx]],
363            }),
364        }
365    }
366
367    fn to_device(&self, device: &Device) -> candle_core::Result<Self> {
368        match self {
369            Self::CausalGeneration { logits } => Ok(Self::CausalGeneration {
370                logits: logits.to_device(device)?,
371            }),
372            Self::RawLogits { logits } => Ok(Self::RawLogits {
373                logits: logits.to_device(device)?,
374            }),
375            Self::Embeddings { embeddings } => Ok(Self::Embeddings {
376                embeddings: embeddings.to_device(device)?,
377            }),
378            Self::Image { .. } => Ok(self.clone()),
379            Self::Speech { .. } => Ok(self.clone()),
380        }
381    }
382}
383
/// Crate-internal, serde-serializable wrapper around a list of strings.
// NOTE(review): presumably a cached listing of a model repository's file names — the
// producer and consumer are not visible here; confirm.
#[derive(serde::Serialize, serde::Deserialize)]
pub(crate) struct FileListCache {
    files: Vec<String>,
}
388
389#[async_trait::async_trait]
390pub trait Pipeline:
391    Send
392    + Sync
393    + PreProcessingMixin
394    + IsqPipelineMixin
395    + CacheManagerMixin
396    + MetadataMixin
397    + AnyMoePipelineMixin
398{
399    fn forward_inputs(
400        &mut self,
401        inputs: Box<dyn Any>,
402        return_raw_logits: bool,
403    ) -> Result<ForwardInputsResult, candle_core::Error>;
404
405    /// Returns the total of model execution time.
406    #[allow(clippy::too_many_arguments)]
407    async fn step(
408        &mut self,
409        input_seqs: &mut [&mut Sequence],
410        is_prompt: bool,
411        return_raw_logits: bool,
412        prefix_cacher: &mut PrefixCacheManagerV2,
413        disable_eos_stop: bool,
414        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
415        backend_metadata: CacheBackendMetadata,
416    ) -> Result<Duration, candle_core::Error> {
417        match backend_metadata {
418            CacheBackendMetadata::DefaultInstructions { pre_op, post_op } => {
419                let inputs_iter =
420                    std::iter::once(self.get_processor().inputs_processor().process_inputs(
421                        self.tokenizer(),
422                        input_seqs,
423                        is_prompt,
424                        self.get_metadata().is_xlora,
425                        &self.device(),
426                        self.get_metadata().no_kv_cache,
427                        None,
428                        return_raw_logits,
429                        self.get_input_processor_config(),
430                        None,
431                        self.device_mapper(),
432                    ));
433
434                let mut logits = vec![None; input_seqs.len()];
435                let len_inputs = 1;
436                let mut raw_out_logits = vec![vec![None; len_inputs]; input_seqs.len()];
437                let mut embedding_logits = vec![None; input_seqs.len()];
438
439                let mut exec_duration = Duration::ZERO;
440                for (i, inputs) in inputs_iter.into_iter().enumerate() {
441                    let InputProcessorOutput {
442                        inputs,
443                        seq_indices,
444                    } = inputs.map_err(candle_core::Error::msg)?;
445                    if i == 0 {
446                        match pre_op {
447                            CacheInstruction::In => self.clone_in_cache(input_seqs),
448                            CacheInstruction::Nothing => (),
449                            CacheInstruction::Reset {
450                                load_preallocated_cache,
451                                reset_non_granular,
452                            } => self.set_none_cache(
453                                input_seqs,
454                                reset_non_granular,
455                                false,
456                                load_preallocated_cache,
457                            ),
458                            _ => unreachable!("Unreachable PRE cache op."),
459                        }
460                    }
461
462                    let start = Instant::now();
463                    let raw_logits = self.forward_inputs(inputs, return_raw_logits)?;
464                    let end = Instant::now();
465                    exec_duration += end.duration_since(start);
466
467                    for (logit_idx, seq_idx) in seq_indices.into_iter().enumerate() {
468                        if let ForwardInputsResult::RawLogits { logits } = &raw_logits {
469                            raw_out_logits[seq_idx][i] =
470                                Some(logits.i(logit_idx)?.to_device(&Device::Cpu)?);
471                        } else if let ForwardInputsResult::Embeddings { embeddings } = &raw_logits {
472                            embedding_logits[seq_idx] =
473                                Some(embeddings.i(logit_idx)?.to_device(&Device::Cpu)?);
474                        } else {
475                            logits[seq_idx] = Some(raw_logits.index_bs(logit_idx)?);
476                        }
477                    }
478                }
479
480                match post_op {
481                    CacheInstruction::Out => self.clone_out_cache(input_seqs),
482                    CacheInstruction::Nothing => (),
483                    CacheInstruction::Reset {
484                        load_preallocated_cache,
485                        reset_non_granular,
486                    } => self.set_none_cache(
487                        input_seqs,
488                        reset_non_granular,
489                        false,
490                        load_preallocated_cache,
491                    ),
492                    _ => unreachable!("Unreachable POST cache op."),
493                }
494
495                if raw_out_logits[0][0].is_some() {
496                    let start = Instant::now();
497                    response::send_raw_responses(
498                        input_seqs,
499                        raw_out_logits
500                            .into_iter()
501                            .map(|raw| raw.into_iter().flatten().collect::<Vec<_>>())
502                            .collect(),
503                    )
504                    .await?;
505                    let end = Instant::now();
506                    exec_duration += end.duration_since(start);
507
508                    return Ok(exec_duration);
509                }
510                if embedding_logits[0].is_some() {
511                    let start = Instant::now();
512                    response::send_embedding_responses(
513                        input_seqs,
514                        embedding_logits
515                            .into_iter()
516                            .map(|raw| {
517                                raw.unwrap()
518                                    .to_dtype(DType::F32)
519                                    .unwrap()
520                                    .to_vec1::<f32>()
521                                    .unwrap()
522                            })
523                            .collect(),
524                    )
525                    .await?;
526                    let end = Instant::now();
527                    exec_duration += end.duration_since(start);
528
529                    return Ok(exec_duration);
530                }
531
532                let start = Instant::now();
533                let logits_on_cpu = logits.len() > 1;
534                let logits = logits
535                    .into_iter()
536                    .map(|l| {
537                        let l = l.expect("Did not get any inputs. This is shocking.");
538                        if logits_on_cpu {
539                            l.to_device(&Device::Cpu)
540                        } else {
541                            Ok(l)
542                        }
543                    })
544                    .collect::<candle_core::Result<Vec<_>>>()?;
545
546                match &logits[0] {
547                    ForwardInputsResult::RawLogits { .. }
548                    | ForwardInputsResult::Embeddings { .. } => unreachable!(),
549                    ForwardInputsResult::CausalGeneration { .. } => {
550                        self.sample_causal_gen(
551                            input_seqs,
552                            logits
553                                .into_iter()
554                                .map(|r| {
555                                    #[allow(irrefutable_let_patterns)]
556                                    let ForwardInputsResult::CausalGeneration { logits } = r
557                                    else {
558                                        unreachable!(
559                                            "All results must have same type, `CausalGeneration`"
560                                        )
561                                    };
562                                    logits
563                                })
564                                .collect::<Vec<_>>(),
565                            prefix_cacher,
566                            disable_eos_stop,
567                            rng,
568                        )
569                        .await?;
570                    }
571                    ForwardInputsResult::Image { .. } => {
572                        response::send_image_responses(
573                            input_seqs,
574                            logits
575                                .into_iter()
576                                .map(|r| {
577                                    #[allow(irrefutable_let_patterns)]
578                                    let ForwardInputsResult::Image { images } = r
579                                    else {
580                                        unreachable!("All results must have same type, `Image`")
581                                    };
582                                    images
583                                        .into_iter()
584                                        .next()
585                                        .expect("Must have at least 1 element.")
586                                })
587                                .collect::<Vec<_>>(),
588                        )
589                        .await?;
590                    }
591                    ForwardInputsResult::Speech { .. } => {
592                        let rates = logits
593                            .iter()
594                            .map(|r| {
595                                #[allow(irrefutable_let_patterns)]
596                                let ForwardInputsResult::Speech { rates, .. } = r
597                                else {
598                                    unreachable!("All results must have same type, `Speech`")
599                                };
600                                assert_eq!(rates.len(), 1, "Each sequence must have 1 PCM output.");
601                                *rates.first().unwrap()
602                            })
603                            .collect::<Vec<_>>();
604                        let channels = logits
605                            .iter()
606                            .map(|r| {
607                                #[allow(irrefutable_let_patterns)]
608                                let ForwardInputsResult::Speech { channels, .. } = r
609                                else {
610                                    unreachable!("All results must have same type, `Speech`")
611                                };
612                                assert_eq!(
613                                    channels.len(),
614                                    1,
615                                    "Each sequence must have 1 PCM output."
616                                );
617                                *channels.first().unwrap()
618                            })
619                            .collect::<Vec<_>>();
620                        let pcms = logits
621                            .into_iter()
622                            .map(|r| {
623                                #[allow(irrefutable_let_patterns)]
624                                let ForwardInputsResult::Speech { pcms, .. } = r
625                                else {
626                                    unreachable!("All results must have same type, `Speech`")
627                                };
628                                assert_eq!(pcms.len(), 1, "Each sequence must have 1 PCM output.");
629                                pcms.into_iter().nth(0).unwrap()
630                            })
631                            .collect::<Vec<_>>();
632                        response::send_speech_responses(input_seqs, &pcms, &rates, &channels)
633                            .await?;
634                    }
635                }
636                let end = Instant::now();
637                exec_duration += end.duration_since(start);
638
639                Ok(exec_duration)
640            }
641            CacheBackendMetadata::PagedAttention {
642                metadata,
643                blocks_to_copy,
644            } => {
645                // Cloning might be bad?
646                self.get_metadata()
647                    .cache_engine
648                    .as_ref()
649                    .expect("PagedAttention must have cache engines.")
650                    .execute_scheduler_ops(&blocks_to_copy)?;
651
652                let inputs_iter =
653                    std::iter::once(self.get_processor().inputs_processor().process_inputs(
654                        self.tokenizer(),
655                        input_seqs,
656                        is_prompt,
657                        self.get_metadata().is_xlora,
658                        &self.device(),
659                        self.get_metadata().no_kv_cache,
660                        None,
661                        return_raw_logits,
662                        self.get_input_processor_config(),
663                        Some(metadata),
664                        self.device_mapper(),
665                    ));
666
667                let mut logits = vec![None; input_seqs.len()];
668                let len_inputs = 1;
669                let mut raw_out_logits = vec![vec![None; len_inputs]; input_seqs.len()];
670                let mut embedding_logits = vec![None; input_seqs.len()];
671
672                let mut exec_duration = Duration::ZERO;
673                for (i, inputs) in inputs_iter.into_iter().enumerate() {
674                    let InputProcessorOutput {
675                        inputs,
676                        seq_indices,
677                    } = inputs.map_err(candle_core::Error::msg)?;
678
679                    let start = Instant::now();
680                    let raw_logits = self.forward_inputs(inputs, return_raw_logits)?;
681                    let end = Instant::now();
682                    exec_duration += end.duration_since(start);
683
684                    for (logit_idx, seq_idx) in seq_indices.into_iter().enumerate() {
685                        if let ForwardInputsResult::RawLogits { logits } = &raw_logits {
686                            raw_out_logits[seq_idx][i] =
687                                Some(logits.i(logit_idx)?.to_device(&Device::Cpu)?);
688                        } else if let ForwardInputsResult::Embeddings { embeddings } = &raw_logits {
689                            embedding_logits[seq_idx] =
690                                Some(embeddings.i(logit_idx)?.to_device(&Device::Cpu)?);
691                        } else {
692                            logits[seq_idx] = Some(raw_logits.index_bs(logit_idx)?);
693                        }
694                    }
695                }
696
697                if raw_out_logits[0][0].is_some() {
698                    let start = Instant::now();
699                    response::send_raw_responses(
700                        input_seqs,
701                        raw_out_logits
702                            .into_iter()
703                            .map(|raw| raw.into_iter().flatten().collect::<Vec<_>>())
704                            .collect(),
705                    )
706                    .await?;
707                    let end = Instant::now();
708                    exec_duration += end.duration_since(start);
709
710                    return Ok(exec_duration);
711                }
712                if embedding_logits[0].is_some() {
713                    let start = Instant::now();
714                    response::send_embedding_responses(
715                        input_seqs,
716                        embedding_logits
717                            .into_iter()
718                            .map(|raw| {
719                                raw.unwrap()
720                                    .to_dtype(DType::F32)
721                                    .unwrap()
722                                    .to_vec1::<f32>()
723                                    .unwrap()
724                            })
725                            .collect(),
726                    )
727                    .await?;
728                    let end = Instant::now();
729                    exec_duration += end.duration_since(start);
730
731                    return Ok(exec_duration);
732                }
733
734                let start = Instant::now();
735                let logits_on_cpu = logits.len() > 1;
736                let logits = logits
737                    .into_iter()
738                    .map(|l| {
739                        let l = l.expect("Did not get any inputs. This is shocking.");
740                        if logits_on_cpu {
741                            l.to_device(&Device::Cpu)
742                        } else {
743                            Ok(l)
744                        }
745                    })
746                    .collect::<candle_core::Result<Vec<_>>>()?;
747
748                match &logits[0] {
749                    ForwardInputsResult::RawLogits { .. }
750                    | ForwardInputsResult::Embeddings { .. } => unreachable!(),
751                    ForwardInputsResult::CausalGeneration { .. } => {
752                        self.sample_causal_gen(
753                            input_seqs,
754                            logits
755                                .into_iter()
756                                .map(|r| {
757                                    #[allow(irrefutable_let_patterns)]
758                                    let ForwardInputsResult::CausalGeneration { logits } = r
759                                    else {
760                                        unreachable!("All results must have same type")
761                                    };
762                                    logits
763                                })
764                                .collect::<Vec<_>>(),
765                            prefix_cacher,
766                            disable_eos_stop,
767                            rng,
768                        )
769                        .await?;
770                    }
771                    ForwardInputsResult::Image { .. } => {
772                        response::send_image_responses(
773                            input_seqs,
774                            logits
775                                .into_iter()
776                                .map(|r| {
777                                    #[allow(irrefutable_let_patterns)]
778                                    let ForwardInputsResult::Image { images } = r
779                                    else {
780                                        unreachable!("All results must have same type, `Image`")
781                                    };
782                                    images
783                                        .into_iter()
784                                        .next()
785                                        .expect("Must have at least 1 element.")
786                                })
787                                .collect::<Vec<_>>(),
788                        )
789                        .await?;
790                    }
791                    ForwardInputsResult::Speech { .. } => {
792                        let rates = logits
793                            .iter()
794                            .map(|r| {
795                                #[allow(irrefutable_let_patterns)]
796                                let ForwardInputsResult::Speech { rates, .. } = r
797                                else {
798                                    unreachable!("All results must have same type, `Speech`")
799                                };
800                                assert_eq!(rates.len(), 1, "Each sequence must have 1 PCM output.");
801                                *rates.first().unwrap()
802                            })
803                            .collect::<Vec<_>>();
804                        let channels = logits
805                            .iter()
806                            .map(|r| {
807                                #[allow(irrefutable_let_patterns)]
808                                let ForwardInputsResult::Speech { channels, .. } = r
809                                else {
810                                    unreachable!("All results must have same type, `Speech`")
811                                };
812                                assert_eq!(
813                                    channels.len(),
814                                    1,
815                                    "Each sequence must have 1 PCM output."
816                                );
817                                *channels.first().unwrap()
818                            })
819                            .collect::<Vec<_>>();
820                        let pcms = logits
821                            .into_iter()
822                            .map(|r| {
823                                #[allow(irrefutable_let_patterns)]
824                                let ForwardInputsResult::Speech { pcms, .. } = r
825                                else {
826                                    unreachable!("All results must have same type, `Speech`")
827                                };
828                                assert_eq!(pcms.len(), 1, "Each sequence must have 1 PCM output.");
829                                pcms.into_iter().nth(0).unwrap()
830                            })
831                            .collect::<Vec<_>>();
832                        response::send_speech_responses(input_seqs, &pcms, &rates, &channels)
833                            .await?;
834                    }
835                }
836                let end = Instant::now();
837                exec_duration += end.duration_since(start);
838
839                Ok(exec_duration)
840            }
841        }
842    }
843
    /// Required hook that consumes the logits produced for causal-generation results.
    ///
    /// No default body is provided here; every implementor supplies its own. Within this
    /// trait it is awaited from the `ForwardInputsResult::CausalGeneration` arm, which
    /// hands over one `Tensor` per forward result together with the sequences being
    /// processed, the prefix cache manager, the EOS flag and the shared RNG.
    ///
    /// NOTE(review): presumably this samples the next token for each sequence in `seqs`
    /// and may update `prefix_cacher` — confirm against the implementations.
    ///
    /// # Errors
    /// Returns a `candle_core::Error`; the visible caller propagates it with `?`.
    async fn sample_causal_gen(
        &self,
        seqs: &mut [&mut Sequence],
        logits: Vec<Tensor>,
        prefix_cacher: &mut PrefixCacheManagerV2,
        disable_eos_stop: bool,
        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
    ) -> Result<(), candle_core::Error>;
852
    /// Reports the [`ModelCategory`] of this pipeline implementation.
    fn category(&self) -> ModelCategory;
854}
855
856pub(crate) fn extract_logits(
857    logits: &Tensor,
858    context_lens: Vec<(usize, usize)>,
859) -> candle_core::Result<Tensor> {
860    let mut toks = Vec::new();
861    for (dim, (start, len)) in logits.chunk(logits.dims()[0], 0)?.iter().zip(context_lens) {
862        toks.push(dim.narrow(1, start, len)?);
863    }
864    Tensor::cat(&toks, 0)
865}
866
#[cfg(test)]
mod tests {
    use crate::MessageContent;
    use either::Either;
    use indexmap::IndexMap;
    use serde_json::Value;

    /// Builds an `::indexmap::IndexMap` from `key => value` pairs.
    ///
    /// Despite the name, the result is an insertion-ordered `IndexMap`, and every value is
    /// wrapped in `Value::String` before it is inserted. A trailing comma after the last
    /// pair is accepted.
    macro_rules! hashmap {
        // Internal: swallow any tokens and expand to the unit value.
        (@single $($x:tt)*) => (());
        // Internal: count the arguments as the length of an array holding one `()` each.
        (@count $($rest:expr),*) => (<[()]>::len(&[$(hashmap!(@single $rest)),*]));

        // Trailing-comma form: forward to the rule below without the final comma.
        ($($key:expr => $value:expr,)+) => { hashmap!($($key => $value),+) };
        ($($key:expr => $value:expr),*) => {
            {
                // Reserve room for every pair up front.
                let _cap = hashmap!(@count $($key),*);
                let mut _map = ::indexmap::IndexMap::with_capacity(_cap);
                $(
                    let _ = _map.insert($key, Value::String($value));
                )*
                _map
            }
        };
    }

    /// Renders every chat template against `inputs` and checks it against its expected output.
    ///
    /// Each entry of `templates` is `(has_system, bos, eos, unk, template)`. When
    /// `has_system` is `false`, the first message of `inputs` (the system message in this
    /// module's tests) is dropped before rendering. Failures are collected rather than
    /// aborting on the first one, so a single run reports every broken template.
    ///
    /// # Panics
    /// Panics if `templates` and `expected_outputs` differ in length, or if any template
    /// fails to render or renders something other than its expected output. Each failure
    /// is printed under the index the template has in `templates`.
    #[track_caller]
    fn test_with_inputs(
        templates: &[(bool, &str, &str, &str, &str)],
        expected_outputs: &[&str],
        inputs: Vec<IndexMap<String, MessageContent>>,
    ) {
        use super::chat_template::apply_chat_template_to;
        use crate::pipeline::chat_template::ChatTemplateValue;

        // `zip` would otherwise silently skip templates that have no expected output.
        assert_eq!(
            templates.len(),
            expected_outputs.len(),
            "every chat template needs exactly one expected output"
        );

        let n_templates = templates.len();
        // Each failure remembers the index of the template that produced it.
        let mut failed: Vec<(usize, String)> = Vec::new();
        for (idx, ((has_system, bos, eos, unk, template), expected)) in
            templates.iter().zip(expected_outputs).enumerate()
        {
            let messages = if *has_system {
                inputs.clone()
            } else {
                inputs[1..].to_vec()
            };
            let output = match apply_chat_template_to(
                messages,
                true,
                None,
                &ChatTemplateValue(Either::Left(template.to_string())),
                Some(bos.to_string()),
                Some(eos.to_string()),
                Some(unk.to_string()),
                Vec::new(),
            ) {
                Ok(v) => v,
                Err(e) => {
                    failed.push((idx, format!("Failed with {e}.")));
                    continue;
                }
            };
            if output != *expected {
                failed.push((
                    idx,
                    format!(
                        "Expected: `{}` \n\nGot:      `{}`",
                        expected.replace('\n', "\\n"),
                        output.replace('\n', "\\n")
                    ),
                ));
            }
        }
        if !failed.is_empty() {
            for (idx, line) in &failed {
                println!("------------ Template {idx} ------------");
                println!("{line}");
            }
            println!("------------------------");
            panic!("{}/{n_templates} chat templates failed.", failed.len());
        }
    }

    /// Checks text-only chat templates against outputs generated with `transformers`.
    ///
    /// Generating these cases:
    /// ```py
    /// >>> t=transformers.AutoTokenizer.from_pretrained(...)
    /// # If non-system prompt model
    /// >>> t.apply_chat_template([{"role":"user","content":"Hello"},{"role":"assistant","content":"Hi there"},{"role":"user","content":"Who are you"},{"role":"assistant","content":"   I am an assistant   "},{"role":"user","content":"Another question"}], add_generation_prompt=True, tokenize=False)
    /// # If system prompt model
    /// >>> t.apply_chat_template([{"role":"system","content":"You are a helpful assistant"},{"role":"user","content":"Hello"},{"role":"assistant","content":"Hi there"},{"role":"user","content":"Who are you"},{"role":"assistant","content":"   I am an assistant   "},{"role":"user","content":"Another question"}], add_generation_prompt=True, tokenize=False)
    /// ```
    #[test]
    fn test_chat_templates() {
        // Entries are `(has_system, bos, eos, unk, template)`, in the same order as
        // `expected_outputs` below.
        let templates = [
            // ChatML: https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B
            (true, "<s>", "</s>", "<unk>", "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"),
            // mistralai/Mistral-7B-Instruct-v0.1
            (false, "<s>", "</s>", "<unk>", "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"),
            // meta-llama/Llama-2-13b-chat-hf
            (true, "<s>", "</s>", "<unk>", "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' '  + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}"),
            // mistralai/Mixtral-8x7B-Instruct-v0.1
            (false, "<s>", "</s>", "<unk>", "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"),
            // google/gemma-7b-it
            (false, "<bos>", "<eos>", "<unk>", "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}"),
        ];
        let expected_outputs = [
            // ChatML: https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B
            "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi there<|im_end|>\n<|im_start|>user\nWho are you<|im_end|>\n<|im_start|>assistant\n   I am an assistant   <|im_end|>\n<|im_start|>user\nAnother question<|im_end|>\n<|im_start|>assistant\n",
            // mistralai/Mistral-7B-Instruct-v0.1
            "<s>[INST] Hello [/INST]Hi there</s> [INST] Who are you [/INST]   I am an assistant   </s> [INST] Another question [/INST]",
            // meta-llama/Llama-2-13b-chat-hf
            "<s>[INST] <<SYS>>\nYou are a helpful assistant\n<</SYS>>\n\nHello [/INST] Hi there </s><s>[INST] Who are you [/INST] I am an assistant </s><s>[INST] Another question [/INST]",
            // mistralai/Mixtral-8x7B-Instruct-v0.1
            "<s>[INST] Hello [/INST]Hi there</s>[INST] Who are you [/INST]   I am an assistant   </s>[INST] Another question [/INST]",
            // google/gemma-7b-it
            "<bos><start_of_turn>user\nHello<end_of_turn>\n<start_of_turn>model\nHi there<end_of_turn>\n<start_of_turn>user\nWho are you<end_of_turn>\n<start_of_turn>model\nI am an assistant<end_of_turn>\n<start_of_turn>user\nAnother question<end_of_turn>\n<start_of_turn>model\n",
        ];
        let messages = [
            ["system", "You are a helpful assistant"],
            ["user", "Hello"],
            ["assistant", "Hi there"],
            ["user", "Who are you"],
            ["assistant", "   I am an assistant   "],
            ["user", "Another question"],
        ];
        let mut inputs = Vec::with_capacity(messages.len());
        for [role, content] in messages {
            let mut message: IndexMap<String, MessageContent> = IndexMap::new();
            message.insert("role".to_string(), Either::Left(role.to_string()));
            message.insert("content".to_string(), Either::Left(content.to_string()));
            inputs.push(message);
        }
        test_with_inputs(&templates, &expected_outputs, inputs);
    }

    /// Checks chat templates whose message content is a list of typed parts (text / image).
    ///
    /// Generating these cases:
    /// ```py
    /// >>> processor=transformers.AutoProcessor.from_pretrained(...)
    /// >>> processor.apply_chat_template([
    ///         {"role":"system","content":[{"type":"text", "text": "You are a helpful assistant"}]},
    ///         {"role":"user","content":[{"type":"image"}, {"type":"text", "text": "Hello, please describe the above."}]},
    ///         {"role":"assistant","content":[{"type":"text", "text": "Hi there"}]},
    ///         {"role":"user","content":[{"type":"image"}, {"type":"text", "text": "This is me, who are you"}]},
    ///         {"role":"assistant","content":[{"type":"text", "text": "   I am an assistant   "}]},
    ///         {"role":"user","content":[{"type":"image"}, {"type":"text", "text": "Another question, what is this?"}]}
    ///     ], add_generation_prompt=True, tokenize=False)
    /// ```
    #[test]
    fn test_image_chat_templates() {
        let templates = [
            // HuggingFaceM4/idefics2-8b-chatty
            (true, "<s>", "</s>", "<unk>", "{% for message in messages %}{{message['role'].capitalize()}}{% if message['content'][0]['type'] == 'image' %}{{':'}}{% else %}{{': '}}{% endif %}{% for line in message['content'] %}{% if line['type'] == 'text' %}{{line['text']}}{% elif line['type'] == 'image' %}{{ '<image>' }}{% endif %}{% endfor %}<end_of_utterance>\n{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"),
        ];
        let expected_outputs = [
            // HuggingFaceM4/idefics2-8b-chatty
            "System: You are a helpful assistant<end_of_utterance>\nUser:<image>Hello, please describe the above.<end_of_utterance>\nAssistant: Hi there<end_of_utterance>\nUser:<image>This is me, who are you<end_of_utterance>\nAssistant:    I am an assistant   <end_of_utterance>\nUser:<image>Another question, what is this?<end_of_utterance>\nAssistant:",
        ];

        // One `{"type": "text", "text": <body>}` content part.
        let text = |body: &str| {
            hashmap! {
                "type".to_string() => "text".to_string(),
                "text".to_string() => body.to_string()
            }
        };
        // One `{"type": "image"}` content part.
        let image = || {
            hashmap! {
                "type".to_string() => "image".to_string()
            }
        };
        // One chat message: `role` first, then the list of content parts.
        let message = |role: &str, parts: Vec<IndexMap<String, Value>>| {
            let mut message: IndexMap<String, MessageContent> = IndexMap::new();
            message.insert("role".to_string(), Either::Left(role.to_string()));
            message.insert("content".to_string(), Either::Right(parts));
            message
        };

        let inputs = vec![
            message("system", vec![text("You are a helpful assistant")]),
            message(
                "user",
                vec![image(), text("Hello, please describe the above.")],
            ),
            message("assistant", vec![text("Hi there")]),
            message("user", vec![image(), text("This is me, who are you")]),
            message("assistant", vec![text("   I am an assistant   ")]),
            message(
                "user",
                vec![image(), text("Another question, what is this?")],
            ),
        ];

        test_with_inputs(&templates, &expected_outputs, inputs);
    }
}