mistralrs_core/pipeline/
mod.rs

1mod amoe;
2mod auto;
3pub mod chat_template;
4mod diffusion;
5mod ggml;
6mod gguf;
7mod inputs_processor;
8mod isq;
9pub(crate) mod llg;
10mod loaders;
11mod macros;
12mod normal;
13mod paths;
14mod processing;
15mod response;
16mod sampling;
17mod speculative;
18mod speech;
19mod vision;
20
21pub use super::diffusion_models::DiffusionGenerationParams;
22use crate::amoe::{AnyMoeConfig, AnyMoeExpertType, AnyMoeTrainingInputs, AnyMoeTrainingResult};
23use crate::device_map::DeviceMapper;
24use crate::paged_attention::{CacheConfig, CacheEngine, ModelConfigLike};
25use crate::prefix_cacher::PrefixCacheManagerV2;
26pub use amoe::{AnyMoeLoader, AnyMoePipeline};
27pub use auto::{AutoLoader, AutoLoaderBuilder};
28use chat_template::ChatTemplate;
29pub use diffusion::{DiffusionLoader, DiffusionLoaderBuilder};
30pub use ggml::{GGMLLoader, GGMLLoaderBuilder, GGMLSpecificConfig};
31pub use gguf::{GGUFLoader, GGUFLoaderBuilder, GGUFSpecificConfig};
32use image::DynamicImage;
33pub use inputs_processor::InputProcessorOutput;
34pub(crate) use isq::IsqModelLoader;
35pub use isq::{parse_isq_value, IsqModel, IsqOrganization, UQFF_MULTI_FILE_DELIMITER};
36use llguidance::toktrie::TokEnv;
37pub use loaders::{
38    AdapterKind, AutoDeviceMapParams, AutoNormalLoader, AutoVisionLoader, DeepSeekV2Loader,
39    DeepSeekV3Loader, DeviceMappedModelLoader, DiffusionLoaderType, DiffusionModel,
40    DiffusionModelLoader, FluxLoader, GLM4Loader, Gemma2Loader, Gemma3Loader, Gemma3nLoader,
41    GemmaLoader, Idefics2Loader, Idefics3Loader, LLaVALoader, LLaVANextLoader, LlamaLoader, Loader,
42    LocalModelPaths, MiniCpmOLoader, Mistral3Loader, MistralLoader, MixtralLoader, ModelKind,
43    ModelPaths, NormalLoaderType, NormalLoadingMetadata, NormalModel, NormalModelLoader,
44    Phi2Loader, Phi3Loader, Phi3VLoader, Phi3_5MoELoader, Phi4MMLoader, PrettyName,
45    QuantizationKind, Qwen2Loader, Qwen2VLLoader, Qwen2_5VLLoader, Qwen3Loader, Qwen3MoELoader,
46    SmolLm3Loader, Starcoder2Loader, TokenSource, VLlama4Loader, VLlamaLoader, VisionLoaderType,
47    VisionModel, VisionModelLoader,
48};
49use mistralrs_quant::IsqType;
50pub use normal::{NormalLoader, NormalLoaderBuilder, NormalSpecificConfig};
51pub(crate) use paths::{get_chat_template, get_model_paths, get_xlora_paths};
52pub use paths::{AdapterPaths, LoraAdapterPaths};
53pub(crate) use processing::{
54    apply_chat_template, BasicProcessor, MessagesAction, Processor, ProcessorCreator,
55};
56use rand_isaac::Isaac64Rng;
57pub use speculative::{SpeculativeConfig, SpeculativeLoader, SpeculativePipeline};
58pub use speech::{SpeechLoader, SpeechPipeline};
59use std::any::Any;
60use std::collections::HashMap;
61use std::fmt::Debug;
62use std::sync::Arc;
63use std::time::{Duration, Instant};
64use tokenizers::Tokenizer;
65pub use vision::{VisionLoader, VisionLoaderBuilder, VisionSpecificConfig};
66
67use anyhow::Result;
68use candle_core::{DType, Device, IndexOp, Tensor, Var};
69
70use crate::sequence::Sequence;
71
72pub use self::inputs_processor::{
73    text_models_inputs_processor, InputsProcessor, InputsProcessorType,
74};
75use self::text_models_inputs_processor::PagedAttentionMeta;
76pub use crate::kv_cache::{
77    Cache, CacheManager, EitherCache, KvCache, LayerCaches, NormalCache, NormalCacheType,
78};
79
/// A modality (kind of data) that a model may support.
#[derive(Clone, PartialEq, Eq)]
pub enum SupportedModality {
    /// Text.
    Text,
    /// Audio.
    Audio,
    /// Vision.
    Vision,
}

// `Debug` is written by hand so that every variant is rendered as an emoji
// followed by the variant name instead of the bare derived name.
impl Debug for SupportedModality {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Text => "📝 Text",
            Self::Audio => "🔊 Audio",
            Self::Vision => "🖼️ Vision",
        };
        f.write_str(label)
    }
}
96
97#[derive(Debug, Clone)]
98pub struct Modalities {
99    pub input: Vec<SupportedModality>,
100    pub output: Vec<SupportedModality>,
101}
102
103pub struct GeneralMetadata {
104    pub max_seq_len: usize,
105    /// Only None if it doesn't make sense for the model
106    pub llg_factory: Option<Arc<llguidance::ParserFactory>>,
107    pub no_kv_cache: bool,
108    pub no_prefix_cache: bool,
109    pub num_hidden_layers: usize,
110    pub eos_tok: Vec<u32>,
111    pub kind: ModelKind,
112    // TODO: Replace is_xlora queries to check via kind instead:
113    pub is_xlora: bool,
114    pub activation_dtype: DType,
115    pub sliding_window: Option<usize>,
116    // PagedAttention stuff
117    pub cache_config: Option<CacheConfig>,
118    pub cache_engine: Option<CacheEngine>,
119    pub model_metadata: Option<Arc<dyn ModelConfigLike + Send + Sync>>,
120    pub modalities: Modalities,
121}
122
123impl GeneralMetadata {
124    pub fn tok_env(&self) -> Option<TokEnv> {
125        self.llg_factory.as_ref().map(|f| f.tok_env().clone())
126    }
127}
128
/// A cache operation that [`Pipeline::step`] performs before (`pre_op`) or after
/// (`post_op`) the forward pass when driven with
/// `CacheBackendMetadata::DefaultInstructions`.
pub enum CacheInstruction {
    /// As a pre-op, `step` calls `clone_in_cache`. Not a valid post-op.
    In,
    /// As a post-op, `step` calls `clone_out_cache`. Not a valid pre-op.
    Out,
    /// `step` calls `set_none_cache` with these flags.
    /// load_preallocated_cache means to load the preallocated cache, if applicable.
    Reset {
        load_preallocated_cache: bool,
        reset_non_granular: bool,
    },
    /// Leave the cache untouched.
    Nothing,
}
139
140pub trait PreProcessingMixin: MetadataMixin {
141    fn get_processor(&self) -> Arc<dyn Processor> {
142        Arc::new(BasicProcessor)
143    }
144    /// Only None if it doesnt make sense for the model
145    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>>;
146    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>>;
147}
148
149pub trait IsqPipelineMixin {
150    fn re_isq_model(&mut self, dtype: IsqType) -> Result<()>;
151}
152
153pub trait CacheManagerMixin {
154    /// Clone the cache FROM the sequences' cache TO the model cache. Only called for completion seqs.
155    /// It is not a guarantee that this will be called for each completion step.
156    fn clone_in_cache(&self, seqs: &mut [&mut Sequence]);
157    /// Clone the cache FROM the model cache TO the sequences. Called for prompt and completion seqs.
158    /// It is not a guarantee that this will be called for each step.
159    fn clone_out_cache(&self, seqs: &mut [&mut Sequence]);
160    /// Set the model cache to all None. Only called for prompt seqs.
161    /// It is not a guarantee that this will be called for each prompt step.
162    /// This may also reset the non granular state if applicable.
163    fn set_none_cache(
164        &self,
165        seqs: &mut [&mut Sequence],
166        reset_non_granular: bool,
167        modify_draft_cache: bool,
168        load_preallocated_cache: bool,
169    );
170    fn cache(&self) -> &EitherCache;
171    fn do_preallocated_cache(&self) -> bool {
172        matches!(self.cache(), EitherCache::Normal(_))
173    }
174}
175
176pub trait MetadataMixin {
177    fn device(&self) -> Device;
178    /// Only None if it doesnt make sense for the model
179    fn tokenizer(&self) -> Option<Arc<Tokenizer>>;
180    fn name(&self) -> String;
181    fn reset_non_granular_state(&self);
182    fn get_metadata(&self) -> Arc<GeneralMetadata>;
183    fn device_mapper(&self) -> Option<&dyn DeviceMapper>;
184}
185
186/// Implemented by the base model of an AnyMoe.
187pub trait AnyMoePipelineMixin {
188    /// Get vars for each gating layer
189    fn amoe_layer_vars(&self) -> Vec<Vec<Var>> {
190        unreachable!()
191    }
192    fn amoe_finish_training(&mut self, _gate_model_id: Option<String>) -> candle_core::Result<()> {
193        unreachable!()
194    }
195    fn amoe_base_model_trainable_params(&self) -> usize {
196        unreachable!()
197    }
198    fn amoe_supported(&self) -> bool {
199        false
200    }
201    /// Per-layer cached outputs.
202    fn amoe_take_cached_gating_outputs(&mut self) -> Vec<Tensor> {
203        unreachable!()
204    }
205    /// Inject the MoE layers
206    #[allow(clippy::too_many_arguments)]
207    fn amoe_create_layers(
208        &mut self,
209        _model_ids: Vec<String>,
210        _token: &TokenSource,
211        _revision: Option<String>,
212        _match_regex: &str,
213        _config: AnyMoeConfig,
214        _dtype: DType,
215        _dev: &Device,
216        (_prefix, _mlp): (String, String),
217        _layers: Vec<usize>,
218        _expert_type: AnyMoeExpertType,
219        _silent: bool,
220        _gate_model_id: Option<String>,
221    ) -> candle_core::Result<()> {
222        unreachable!()
223    }
224    /// Pre-train the gating layers
225    #[allow(clippy::too_many_arguments)]
226    fn amoe_pre_train(
227        &self,
228        _inputs: AnyMoeTrainingInputs,
229        (_prefix, _mlp): (String, String),
230        _model_ids: Vec<String>,
231        _token: TokenSource,
232        _revision: Option<String>,
233        _layers: Vec<usize>,
234        _silent: bool,
235    ) -> Result<Option<AnyMoeTrainingResult>, candle_core::Error> {
236        unreachable!()
237    }
238}
239
240/// Category of the model. This can also be used to extract model-category specific tools,
241/// such as the vision model prompt prefixer.
242#[derive(Clone)]
243pub enum ModelCategory {
244    Text,
245    Vision {
246        prefixer: Arc<dyn MultimodalPromptPrefixer>,
247    },
248    Diffusion,
249    Audio,
250    Speech,
251}
252
253impl std::fmt::Debug for ModelCategory {
254    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
255        match self {
256            ModelCategory::Text => write!(f, "ModelCategory::Text"),
257            ModelCategory::Vision { .. } => write!(f, "ModelCategory::Vision {{ prefixer: .. }}"),
258            ModelCategory::Diffusion => write!(f, "ModelCategory::Diffusion"),
259            ModelCategory::Audio => write!(f, "ModelCategory::Audio"),
260            ModelCategory::Speech => write!(f, "ModelCategory::Speech"),
261        }
262    }
263}
264
265impl PartialEq for ModelCategory {
266    fn eq(&self, other: &Self) -> bool {
267        match (self, other) {
268            (Self::Text, Self::Text) => true,
269            (Self::Vision { .. }, Self::Vision { .. }) => true,
270            (Self::Audio, Self::Audio) => true,
271            (Self::Speech, Self::Speech) => true,
272            (Self::Diffusion, Self::Diffusion) => true,
273            (
274                Self::Text | Self::Vision { .. } | Self::Diffusion | Self::Audio | Self::Speech,
275                _,
276            ) => false,
277        }
278    }
279}
280
/// Prepends the tags appropriate for the model to a prompt. Image indexing is
/// assumed to start at 0.
pub trait MultimodalPromptPrefixer: Send + Sync {
    /// Prefix for inclusion in messages (may do nothing if the chat template handles it).
    /// The default returns `prompt` unchanged.
    fn prefix_image(&self, _image_indices: Vec<usize>, prompt: &str) -> String {
        prompt.to_owned()
    }
    /// Prefix for inclusion in messages (may do nothing if the chat template handles it).
    /// The default returns `prompt` unchanged.
    fn prefix_audio(&self, _audio_indexes: Vec<usize>, prompt: &str) -> String {
        prompt.to_owned()
    }
}
292
293pub enum CacheBackendMetadata {
294    DefaultInstructions {
295        pre_op: CacheInstruction,
296        post_op: CacheInstruction,
297    },
298    PagedAttention {
299        metadata: PagedAttentionMeta,
300        blocks_to_swap_in: HashMap<usize, usize>,
301        blocks_to_swap_out: HashMap<usize, usize>,
302        blocks_to_copy: HashMap<usize, Vec<usize>>,
303    },
304}
305
306#[derive(Clone, Debug)]
307pub enum ForwardInputsResult {
308    RawLogits {
309        logits: Tensor,
310    },
311    CausalGeneration {
312        logits: Tensor,
313    },
314    Image {
315        images: Vec<DynamicImage>,
316    },
317    Speech {
318        pcms: Vec<Arc<Vec<f32>>>,
319        rates: Vec<usize>,
320        channels: Vec<usize>,
321    },
322}
323
324impl ForwardInputsResult {
325    fn index_bs(&self, bs_idx: usize) -> candle_core::Result<Self> {
326        match self {
327            Self::CausalGeneration { logits } => Ok(Self::CausalGeneration {
328                logits: logits.i(bs_idx)?,
329            }),
330            Self::RawLogits { logits } => Ok(Self::RawLogits {
331                logits: logits.i(bs_idx)?,
332            }),
333            Self::Image { images } => Ok(Self::Image {
334                images: vec![images[bs_idx].clone()],
335            }),
336            Self::Speech {
337                pcms,
338                rates,
339                channels,
340            } => Ok(Self::Speech {
341                pcms: vec![pcms[bs_idx].clone()],
342                rates: vec![rates[bs_idx]],
343                channels: vec![channels[bs_idx]],
344            }),
345        }
346    }
347
348    fn to_device(&self, device: &Device) -> candle_core::Result<Self> {
349        match self {
350            Self::CausalGeneration { logits } => Ok(Self::CausalGeneration {
351                logits: logits.to_device(device)?,
352            }),
353            Self::RawLogits { logits } => Ok(Self::RawLogits {
354                logits: logits.to_device(device)?,
355            }),
356            Self::Image { .. } => Ok(self.clone()),
357            Self::Speech { .. } => Ok(self.clone()),
358        }
359    }
360}
361
362#[derive(serde::Serialize, serde::Deserialize)]
363pub(crate) struct FileListCache {
364    files: Vec<String>,
365}
366
367#[async_trait::async_trait]
368pub trait Pipeline:
369    Send
370    + Sync
371    + PreProcessingMixin
372    + IsqPipelineMixin
373    + CacheManagerMixin
374    + MetadataMixin
375    + AnyMoePipelineMixin
376{
377    fn forward_inputs(
378        &mut self,
379        inputs: Box<dyn Any>,
380        return_raw_logits: bool,
381    ) -> Result<ForwardInputsResult, candle_core::Error>;
382
383    /// Returns the total of model execution time.
384    #[allow(clippy::too_many_arguments)]
385    async fn step(
386        &mut self,
387        input_seqs: &mut [&mut Sequence],
388        is_prompt: bool,
389        return_raw_logits: bool,
390        prefix_cacher: &mut PrefixCacheManagerV2,
391        disable_eos_stop: bool,
392        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
393        backend_metadata: CacheBackendMetadata,
394    ) -> Result<Duration, candle_core::Error> {
395        match backend_metadata {
396            CacheBackendMetadata::DefaultInstructions { pre_op, post_op } => {
397                let inputs_iter =
398                    std::iter::once(self.get_processor().inputs_processor().process_inputs(
399                        self.tokenizer(),
400                        input_seqs,
401                        is_prompt,
402                        self.get_metadata().is_xlora,
403                        &self.device(),
404                        self.get_metadata().no_kv_cache,
405                        None,
406                        return_raw_logits,
407                        self.get_input_processor_config(),
408                        None,
409                        self.device_mapper(),
410                    ));
411
412                let mut logits = vec![None; input_seqs.len()];
413                let len_inputs = 1;
414                let mut raw_out_logits = vec![vec![None; len_inputs]; input_seqs.len()];
415
416                let mut exec_duration = Duration::ZERO;
417                for (i, inputs) in inputs_iter.into_iter().enumerate() {
418                    let InputProcessorOutput {
419                        inputs,
420                        seq_indices,
421                    } = inputs.map_err(candle_core::Error::msg)?;
422                    if i == 0 {
423                        match pre_op {
424                            CacheInstruction::In => self.clone_in_cache(input_seqs),
425                            CacheInstruction::Nothing => (),
426                            CacheInstruction::Reset {
427                                load_preallocated_cache,
428                                reset_non_granular,
429                            } => self.set_none_cache(
430                                input_seqs,
431                                reset_non_granular,
432                                false,
433                                load_preallocated_cache,
434                            ),
435                            _ => unreachable!("Unreachable PRE cache op."),
436                        }
437                    }
438
439                    let start = Instant::now();
440                    let raw_logits = self.forward_inputs(inputs, return_raw_logits)?;
441                    let end = Instant::now();
442                    exec_duration += end.duration_since(start);
443
444                    for (logit_idx, seq_idx) in seq_indices.into_iter().enumerate() {
445                        if let ForwardInputsResult::RawLogits { logits } = &raw_logits {
446                            raw_out_logits[seq_idx][i] =
447                                Some(logits.i(logit_idx)?.to_device(&Device::Cpu)?);
448                        } else {
449                            logits[seq_idx] = Some(raw_logits.index_bs(logit_idx)?);
450                        }
451                    }
452                }
453
454                match post_op {
455                    CacheInstruction::Out => self.clone_out_cache(input_seqs),
456                    CacheInstruction::Nothing => (),
457                    CacheInstruction::Reset {
458                        load_preallocated_cache,
459                        reset_non_granular,
460                    } => self.set_none_cache(
461                        input_seqs,
462                        reset_non_granular,
463                        false,
464                        load_preallocated_cache,
465                    ),
466                    _ => unreachable!("Unreachable POST cache op."),
467                }
468
469                if raw_out_logits[0][0].is_some() {
470                    let start = Instant::now();
471                    response::send_raw_responses(
472                        input_seqs,
473                        raw_out_logits
474                            .into_iter()
475                            .map(|raw| raw.into_iter().flatten().collect::<Vec<_>>())
476                            .collect(),
477                    )
478                    .await?;
479                    let end = Instant::now();
480                    exec_duration += end.duration_since(start);
481
482                    return Ok(exec_duration);
483                }
484
485                let start = Instant::now();
486                let logits_on_cpu = logits.len() > 1;
487                let logits = logits
488                    .into_iter()
489                    .map(|l| {
490                        let l = l.expect("Did not get any inputs. This is shocking.");
491                        if logits_on_cpu {
492                            l.to_device(&Device::Cpu)
493                        } else {
494                            Ok(l)
495                        }
496                    })
497                    .collect::<candle_core::Result<Vec<_>>>()?;
498
499                match &logits[0] {
500                    ForwardInputsResult::RawLogits { .. } => unreachable!(),
501                    ForwardInputsResult::CausalGeneration { .. } => {
502                        self.sample_causal_gen(
503                            input_seqs,
504                            logits
505                                .into_iter()
506                                .map(|r| {
507                                    #[allow(irrefutable_let_patterns)]
508                                    let ForwardInputsResult::CausalGeneration { logits } = r
509                                    else {
510                                        unreachable!(
511                                            "All results must have same type, `CausalGeneration`"
512                                        )
513                                    };
514                                    logits
515                                })
516                                .collect::<Vec<_>>(),
517                            prefix_cacher,
518                            disable_eos_stop,
519                            rng,
520                        )
521                        .await?;
522                    }
523                    ForwardInputsResult::Image { .. } => {
524                        response::send_image_responses(
525                            input_seqs,
526                            logits
527                                .into_iter()
528                                .map(|r| {
529                                    #[allow(irrefutable_let_patterns)]
530                                    let ForwardInputsResult::Image { images } = r
531                                    else {
532                                        unreachable!("All results must have same type, `Image`")
533                                    };
534                                    images
535                                        .into_iter()
536                                        .next()
537                                        .expect("Must have at least 1 element.")
538                                })
539                                .collect::<Vec<_>>(),
540                        )
541                        .await?;
542                    }
543                    ForwardInputsResult::Speech { .. } => {
544                        let rates = logits
545                            .iter()
546                            .map(|r| {
547                                #[allow(irrefutable_let_patterns)]
548                                let ForwardInputsResult::Speech { rates, .. } = r
549                                else {
550                                    unreachable!("All results must have same type, `Speech`")
551                                };
552                                assert_eq!(rates.len(), 1, "Each sequence must have 1 PCM output.");
553                                *rates.first().unwrap()
554                            })
555                            .collect::<Vec<_>>();
556                        let channels = logits
557                            .iter()
558                            .map(|r| {
559                                #[allow(irrefutable_let_patterns)]
560                                let ForwardInputsResult::Speech { channels, .. } = r
561                                else {
562                                    unreachable!("All results must have same type, `Speech`")
563                                };
564                                assert_eq!(
565                                    channels.len(),
566                                    1,
567                                    "Each sequence must have 1 PCM output."
568                                );
569                                *channels.first().unwrap()
570                            })
571                            .collect::<Vec<_>>();
572                        let pcms = logits
573                            .into_iter()
574                            .map(|r| {
575                                #[allow(irrefutable_let_patterns)]
576                                let ForwardInputsResult::Speech { pcms, .. } = r
577                                else {
578                                    unreachable!("All results must have same type, `Speech`")
579                                };
580                                assert_eq!(pcms.len(), 1, "Each sequence must have 1 PCM output.");
581                                pcms.into_iter().nth(0).unwrap()
582                            })
583                            .collect::<Vec<_>>();
584                        response::send_speech_responses(input_seqs, &pcms, &rates, &channels)
585                            .await?;
586                    }
587                }
588                let end = Instant::now();
589                exec_duration += end.duration_since(start);
590
591                Ok(exec_duration)
592            }
593            CacheBackendMetadata::PagedAttention {
594                metadata,
595                blocks_to_copy,
596                blocks_to_swap_in,
597                blocks_to_swap_out,
598            } => {
599                // Cloning might be bad?
600                self.get_metadata()
601                    .cache_engine
602                    .as_ref()
603                    .expect("PagedAttention must have cache engines.")
604                    .execute_scheduler_ops(
605                        &blocks_to_swap_in,
606                        &blocks_to_swap_out,
607                        &blocks_to_copy,
608                    )?;
609
610                let inputs_iter =
611                    std::iter::once(self.get_processor().inputs_processor().process_inputs(
612                        self.tokenizer(),
613                        input_seqs,
614                        is_prompt,
615                        self.get_metadata().is_xlora,
616                        &self.device(),
617                        self.get_metadata().no_kv_cache,
618                        None,
619                        return_raw_logits,
620                        self.get_input_processor_config(),
621                        Some(metadata),
622                        self.device_mapper(),
623                    ));
624
625                let mut logits = vec![None; input_seqs.len()];
626                let len_inputs = 1;
627                let mut raw_out_logits = vec![vec![None; len_inputs]; input_seqs.len()];
628
629                let mut exec_duration = Duration::ZERO;
630                for (i, inputs) in inputs_iter.into_iter().enumerate() {
631                    let InputProcessorOutput {
632                        inputs,
633                        seq_indices,
634                    } = inputs.map_err(candle_core::Error::msg)?;
635
636                    let start = Instant::now();
637                    let raw_logits = self.forward_inputs(inputs, return_raw_logits)?;
638                    let end = Instant::now();
639                    exec_duration += end.duration_since(start);
640
641                    for (logit_idx, seq_idx) in seq_indices.into_iter().enumerate() {
642                        if let ForwardInputsResult::RawLogits { logits } = &raw_logits {
643                            raw_out_logits[seq_idx][i] =
644                                Some(logits.i(logit_idx)?.to_device(&Device::Cpu)?);
645                        } else {
646                            logits[seq_idx] = Some(raw_logits.index_bs(logit_idx)?);
647                        }
648                    }
649                }
650
651                if raw_out_logits[0][0].is_some() {
652                    let start = Instant::now();
653                    response::send_raw_responses(
654                        input_seqs,
655                        raw_out_logits
656                            .into_iter()
657                            .map(|raw| raw.into_iter().flatten().collect::<Vec<_>>())
658                            .collect(),
659                    )
660                    .await?;
661                    let end = Instant::now();
662                    exec_duration += end.duration_since(start);
663
664                    return Ok(exec_duration);
665                }
666
667                let start = Instant::now();
668                let logits_on_cpu = logits.len() > 1;
669                let logits = logits
670                    .into_iter()
671                    .map(|l| {
672                        let l = l.expect("Did not get any inputs. This is shocking.");
673                        if logits_on_cpu {
674                            l.to_device(&Device::Cpu)
675                        } else {
676                            Ok(l)
677                        }
678                    })
679                    .collect::<candle_core::Result<Vec<_>>>()?;
680
681                match &logits[0] {
682                    ForwardInputsResult::RawLogits { .. } => unreachable!(),
683                    ForwardInputsResult::CausalGeneration { .. } => {
684                        self.sample_causal_gen(
685                            input_seqs,
686                            logits
687                                .into_iter()
688                                .map(|r| {
689                                    #[allow(irrefutable_let_patterns)]
690                                    let ForwardInputsResult::CausalGeneration { logits } = r
691                                    else {
692                                        unreachable!("All results must have same type")
693                                    };
694                                    logits
695                                })
696                                .collect::<Vec<_>>(),
697                            prefix_cacher,
698                            disable_eos_stop,
699                            rng,
700                        )
701                        .await?;
702                    }
703                    ForwardInputsResult::Image { .. } => {
704                        response::send_image_responses(
705                            input_seqs,
706                            logits
707                                .into_iter()
708                                .map(|r| {
709                                    #[allow(irrefutable_let_patterns)]
710                                    let ForwardInputsResult::Image { images } = r
711                                    else {
712                                        unreachable!("All results must have same type, `Image`")
713                                    };
714                                    images
715                                        .into_iter()
716                                        .next()
717                                        .expect("Must have at least 1 element.")
718                                })
719                                .collect::<Vec<_>>(),
720                        )
721                        .await?;
722                    }
723                    ForwardInputsResult::Speech { .. } => {
724                        let rates = logits
725                            .iter()
726                            .map(|r| {
727                                #[allow(irrefutable_let_patterns)]
728                                let ForwardInputsResult::Speech { rates, .. } = r
729                                else {
730                                    unreachable!("All results must have same type, `Speech`")
731                                };
732                                assert_eq!(rates.len(), 1, "Each sequence must have 1 PCM output.");
733                                *rates.first().unwrap()
734                            })
735                            .collect::<Vec<_>>();
736                        let channels = logits
737                            .iter()
738                            .map(|r| {
739                                #[allow(irrefutable_let_patterns)]
740                                let ForwardInputsResult::Speech { channels, .. } = r
741                                else {
742                                    unreachable!("All results must have same type, `Speech`")
743                                };
744                                assert_eq!(
745                                    channels.len(),
746                                    1,
747                                    "Each sequence must have 1 PCM output."
748                                );
749                                *channels.first().unwrap()
750                            })
751                            .collect::<Vec<_>>();
752                        let pcms = logits
753                            .into_iter()
754                            .map(|r| {
755                                #[allow(irrefutable_let_patterns)]
756                                let ForwardInputsResult::Speech { pcms, .. } = r
757                                else {
758                                    unreachable!("All results must have same type, `Speech`")
759                                };
760                                assert_eq!(pcms.len(), 1, "Each sequence must have 1 PCM output.");
761                                pcms.into_iter().nth(0).unwrap()
762                            })
763                            .collect::<Vec<_>>();
764                        response::send_speech_responses(input_seqs, &pcms, &rates, &channels)
765                            .await?;
766                    }
767                }
768                let end = Instant::now();
769                exec_duration += end.duration_since(start);
770
771                Ok(exec_duration)
772            }
773        }
774    }
775
    /// Required method (declared without a body): implementors define how the outputs of a
    /// causal-generation step are sampled.
    ///
    /// NOTE(review): the parameter roles below are inferred from names and types only; confirm
    /// them against an implementation.
    /// - `seqs`: the sequences to update; presumably one per entry of `logits`.
    /// - `logits`: presumably the model outputs to sample from.
    /// - `prefix_cacher`: the prefix cache manager, passed mutably.
    /// - `disable_eos_stop`: presumably suppresses stopping on EOS tokens.
    /// - `rng`: a shared, mutex-guarded random number generator.
    ///
    /// Returns `Ok(())` on success or a `candle_core::Error` on failure.
    async fn sample_causal_gen(
        &self,
        seqs: &mut [&mut Sequence],
        logits: Vec<Tensor>,
        prefix_cacher: &mut PrefixCacheManagerV2,
        disable_eos_stop: bool,
        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
    ) -> Result<(), candle_core::Error>;
784
    /// Returns the `ModelCategory` reported by the implementor.
    fn category(&self) -> ModelCategory;
786}
787
788pub(crate) fn extract_logits(
789    logits: &Tensor,
790    context_lens: Vec<(usize, usize)>,
791) -> candle_core::Result<Tensor> {
792    let mut toks = Vec::new();
793    for (dim, (start, len)) in logits.chunk(logits.dims()[0], 0)?.iter().zip(context_lens) {
794        toks.push(dim.narrow(1, start, len)?);
795    }
796    Tensor::cat(&toks, 0)
797}
798
799#[cfg(test)]
800mod tests {
801    use crate::MessageContent;
802    use either::Either;
803    use indexmap::IndexMap;
804    use serde_json::Value;
805
    // Builds a map from `key => value` pairs, wrapping every value in `Value::String`.
    // Despite the name, the result is an `indexmap::IndexMap`, not a `HashMap`.
    macro_rules! hashmap {
        // Internal rule: expands to `()` whatever tokens it is given.
        (@single $($x:tt)*) => (());
        // Internal rule: counts its arguments by building an array with one `()` per argument
        // and taking the length of that array.
        (@count $($rest:expr),*) => (<[()]>::len(&[$(hashmap!(@single $rest)),*]));

        // Every pair followed by a comma: forward to the comma-separated rule below.
        ($($key:expr => $value:expr,)+) => { hashmap!($($key => $value),+) };
        // Main rule: comma-separated pairs (possibly none).
        ($($key:expr => $value:expr),*) => {
            {
                // Allocate room for exactly as many entries as pairs written at the call site.
                let _cap = hashmap!(@count $($key),*);
                let mut _map = ::indexmap::IndexMap::with_capacity(_cap);
                $(
                    // The value returned by `insert` is deliberately discarded.
                    let _ = _map.insert($key, Value::String($value));
                )*
                _map
            }
        };
    }
822
823    #[cfg(test)]
824    #[track_caller]
825    fn test_with_inputs(
826        templates: &[(bool, &str, &str, &str, &str)],
827        expected_outputs: &[&str],
828        inputs: Vec<IndexMap<String, MessageContent>>,
829    ) {
830        use crate::pipeline::chat_template::ChatTemplateValue;
831
832        use super::chat_template::apply_chat_template_to;
833        let mut failed = Vec::new();
834        let n_templates = templates.len();
835        for ((has_system, bos, eos, unk, template), expected) in
836            templates.iter().zip(expected_outputs)
837        {
838            let output = match apply_chat_template_to(
839                if !has_system {
840                    inputs[1..].to_vec()
841                } else {
842                    inputs.clone()
843                },
844                true,
845                None,
846                &ChatTemplateValue(Either::Left(template.to_string())),
847                Some(bos.to_string()),
848                Some(eos.to_string()),
849                Some(unk.to_string()),
850                Vec::new(),
851            ) {
852                Ok(v) => v,
853                Err(e) => {
854                    failed.push(format!("Failed with {e}."));
855                    continue;
856                }
857            };
858            if output != *expected {
859                failed.push(format!(
860                    "Expected: `{}` \n\nGot:      `{}`",
861                    expected.replace('\n', "\\n"),
862                    output.replace('\n', "\\n")
863                ));
864            }
865        }
866        if !failed.is_empty() {
867            for (i, line) in failed.iter().enumerate() {
868                println!("------------ Template {i} ------------");
869                println!("{line}");
870            }
871            println!("------------------------");
872            panic!("{}/{n_templates} chat templates failed.", failed.len());
873        }
874    }
875
876    #[test]
877    /// Generating these cases:
878    /// ```py
879    /// >>> t=transformers.AutoTokenizer.from_pretrained(...)
880    /// # If non-system prompt model
881    /// >>> t.apply_chat_template([{"role":"user","content":"Hello"},{"role":"assistant","content":"Hi there"},{"role":"user","content":"Who are you"},{"role":"assistant","content":"   I am an assistant   "},{"role":"user","content":"Another question"}], add_generation_prompt=True, tokenize=False)
882    /// # If system prompt model
883    /// >>> t.apply_chat_template([{"role":"system","content":"You are a helpful assistant"},{"role":"user","content":"Hello"},{"role":"assistant","content":"Hi there"},{"role":"user","content":"Who are you"},{"role":"assistant","content":"   I am an assistant   "},{"role":"user","content":"Another question"}], add_generation_prompt=True, tokenize=False)
884    /// ```
885    fn test_chat_templates() {
886        let templates = [
887            // ChatML: https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B
888            (true, "<s>", "</s>", "<unk>", "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"),
889            // mistralai/Mistral-7B-Instruct-v0.1
890            (false, "<s>", "</s>", "<unk>", "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"),
891            // meta-llama/Llama-2-13b-chat-hf
892            (true, "<s>", "</s>", "<unk>", "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' '  + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}"),
893            // mistralai/Mixtral-8x7B-Instruct-v0.1
894            (false, "<s>", "</s>", "<unk>", "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"),
895            // google/gemma-7b-it
896            (false, "<bos>", "<eos>", "<unk>", "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}"),
897            // HuggingFaceM4/idefics2-8b-chatty
898            (true, "<s>", "</s>", "<unk>", "{% for message in messages %}{{message['role'].capitalize()}}{% if message['content'][0]['type'] == 'image' %}{{':'}}{% else %}{{': '}}{% endif %}{% for line in message['content'] %}{% if line['type'] == 'text' %}{{line['text']}}{% elif line['type'] == 'image' %}{{ '<image>' }}{% endif %}{% endfor %}<end_of_utterance>\n{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"),
899        ];
900        let expected_outputs = [
901            // ChatML: https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B
902            "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi there<|im_end|>\n<|im_start|>user\nWho are you<|im_end|>\n<|im_start|>assistant\n   I am an assistant   <|im_end|>\n<|im_start|>user\nAnother question<|im_end|>\n<|im_start|>assistant\n",
903            // mistralai/Mistral-7B-Instruct-v0.1
904            "<s>[INST] Hello [/INST]Hi there</s> [INST] Who are you [/INST]   I am an assistant   </s> [INST] Another question [/INST]",
905            // meta-llama/Llama-2-13b-chat-hf
906            "<s>[INST] <<SYS>>\nYou are a helpful assistant\n<</SYS>>\n\nHello [/INST] Hi there </s><s>[INST] Who are you [/INST] I am an assistant </s><s>[INST] Another question [/INST]",
907            // mistralai/Mixtral-8x7B-Instruct-v0.1
908            "<s>[INST] Hello [/INST]Hi there</s>[INST] Who are you [/INST]   I am an assistant   </s>[INST] Another question [/INST]",
909            // google/gemma-7b-it
910            "<bos><start_of_turn>user\nHello<end_of_turn>\n<start_of_turn>model\nHi there<end_of_turn>\n<start_of_turn>user\nWho are you<end_of_turn>\n<start_of_turn>model\nI am an assistant<end_of_turn>\n<start_of_turn>user\nAnother question<end_of_turn>\n<start_of_turn>model\n",
911        ];
912        let messages = [
913            ["system", "You are a helpful assistant"],
914            ["user", "Hello"],
915            ["assistant", "Hi there"],
916            ["user", "Who are you"],
917            ["assistant", "   I am an assistant   "],
918            ["user", "Another question"],
919        ];
920        let mut inputs = Vec::new();
921        for [role, content] in messages {
922            let mut message: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
923                IndexMap::new();
924            message.insert("role".to_string(), Either::Left(role.to_string()));
925            message.insert("content".to_string(), Either::Left(content.to_string()));
926            inputs.push(message);
927        }
928        test_with_inputs(&templates, &expected_outputs, inputs);
929    }
930
931    #[test]
932    /// Generating these cases:
933    /// ```py
934    /// >>> processor=transformers.AutoProcessor.from_pretrained(...)
935    /// >>> processor.apply_chat_template([
936    ///         {"role":"system","content":[{"type":"text", "text": "You are a helpful assistant"}]},
937    ///         {"role":"user","content":[{"type":"image"}, {"type":"text", "text": "Hello, please describe the above."}]},
938    ///         {"role":"assistant","content":[{"type":"text", "text": "Hi there"}]},
939    ///         {"role":"user","content":[{"type":"text", "text": "Who are you"}]},
940    ///         {"role":"assistant","content":[{"type":"text", "text": "   I am an assistant   "}]},
941    ///         {"role":"user","content":[{"type":"text", "text": "Another question"}]}
942    ///     ], add_generation_prompt=True, tokenize=False)
943    /// ```
944    fn test_image_chat_templates() {
945        let templates = [
946            // HuggingFaceM4/idefics2-8b-chatty
947            (true, "<s>", "</s>", "<unk>", "{% for message in messages %}{{message['role'].capitalize()}}{% if message['content'][0]['type'] == 'image' %}{{':'}}{% else %}{{': '}}{% endif %}{% for line in message['content'] %}{% if line['type'] == 'text' %}{{line['text']}}{% elif line['type'] == 'image' %}{{ '<image>' }}{% endif %}{% endfor %}<end_of_utterance>\n{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"),
948        ];
949        let expected_outputs = [
950            // HuggingFaceM4/idefics2-8b-chatty
951            "System: You are a helpful assistant<end_of_utterance>\nUser:<image>Hello, please describe the above.<end_of_utterance>\nAssistant: Hi there<end_of_utterance>\nUser:<image>This is me, who are you<end_of_utterance>\nAssistant:    I am an assistant   <end_of_utterance>\nUser:<image>Another question, what is this?<end_of_utterance>\nAssistant:",
952        ];
953
954        let mut inputs = Vec::new();
955
956        let mut message: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
957            IndexMap::new();
958        message.insert("role".to_string(), Either::Left("system".to_string()));
959        message.insert(
960            "content".to_string(),
961            Either::Right(vec![hashmap! {
962                "type".to_string() => "text".to_string(),
963                "text".to_string() => "You are a helpful assistant".to_string()
964            }]),
965        );
966        inputs.push(message);
967
968        let mut message: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
969            IndexMap::new();
970        message.insert("role".to_string(), Either::Left("user".to_string()));
971        message.insert(
972            "content".to_string(),
973            Either::Right(vec![
974                hashmap! {
975                    "type".to_string() => "image".to_string()
976                },
977                hashmap! {
978                    "type".to_string() => "text".to_string(),
979                    "text".to_string() => "Hello, please describe the above.".to_string()
980                },
981            ]),
982        );
983        inputs.push(message);
984
985        let mut message: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
986            IndexMap::new();
987        message.insert("role".to_string(), Either::Left("assistant".to_string()));
988        message.insert(
989            "content".to_string(),
990            Either::Right(vec![hashmap! {
991                "type".to_string() => "text".to_string(),
992                "text".to_string() => "Hi there".to_string()
993            }]),
994        );
995        inputs.push(message);
996
997        let mut message: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
998            IndexMap::new();
999        message.insert("role".to_string(), Either::Left("user".to_string()));
1000        message.insert(
1001            "content".to_string(),
1002            Either::Right(vec![
1003                hashmap! {
1004                    "type".to_string() => "image".to_string()
1005                },
1006                hashmap! {
1007                    "type".to_string() => "text".to_string(),
1008                    "text".to_string() => "This is me, who are you".to_string()
1009                },
1010            ]),
1011        );
1012        inputs.push(message);
1013
1014        let mut message: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
1015            IndexMap::new();
1016        message.insert("role".to_string(), Either::Left("assistant".to_string()));
1017        message.insert(
1018            "content".to_string(),
1019            Either::Right(vec![hashmap! {
1020                "type".to_string() => "text".to_string(),
1021                "text".to_string() => "   I am an assistant   ".to_string()
1022            }]),
1023        );
1024        inputs.push(message);
1025
1026        let mut message: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
1027            IndexMap::new();
1028        message.insert("role".to_string(), Either::Left("user".to_string()));
1029        message.insert(
1030            "content".to_string(),
1031            Either::Right(vec![
1032                hashmap! {
1033                    "type".to_string() => "image".to_string()
1034                },
1035                hashmap! {
1036                    "type".to_string() => "text".to_string(),
1037                    "text".to_string() => "Another question, what is this?".to_string()
1038                },
1039            ]),
1040        );
1041        inputs.push(message);
1042
1043        test_with_inputs(&templates, &expected_outputs, inputs);
1044    }
1045}