mistralrs_core/pipeline/mod.rs

mod amoe;
mod cache_manager;
pub mod chat_template;
mod diffusion;
mod ggml;
mod gguf;
mod inputs_processor;
mod isq;
pub(crate) mod llg;
mod loaders;
mod macros;
mod normal;
mod paths;
mod processing;
mod response;
mod sampling;
mod speculative;
mod vision;

pub use super::diffusion_models::DiffusionGenerationParams;
use crate::amoe::{AnyMoeConfig, AnyMoeExpertType, AnyMoeTrainingInputs, AnyMoeTrainingResult};
use crate::device_map::DeviceMapper;
use crate::paged_attention::{CacheConfig, CacheEngine, ModelConfigLike};
use crate::prefix_cacher::PrefixCacheManagerV2;
pub use amoe::{AnyMoeLoader, AnyMoePipeline};
use chat_template::ChatTemplate;
pub use diffusion::{DiffusionLoader, DiffusionLoaderBuilder, DiffusionSpecificConfig};
pub use ggml::{GGMLLoader, GGMLLoaderBuilder, GGMLSpecificConfig};
pub use gguf::{GGUFLoader, GGUFLoaderBuilder, GGUFSpecificConfig};
use image::DynamicImage;
pub use inputs_processor::InputProcessorOutput;
pub(crate) use isq::IsqModelLoader;
pub use isq::{parse_isq_value, IsqModel, IsqOrganization, UQFF_MULTI_FILE_DELIMITER};
use llguidance::toktrie::TokEnv;
pub use loaders::{
    AdapterKind, AutoDeviceMapParams, AutoLoader, DeepSeekV2Loader, DeepSeekV3Loader,
    DeviceMappedModelLoader, DiffusionLoaderType, DiffusionModel, DiffusionModelLoader, FluxLoader,
    Gemma2Loader, Gemma3Loader, GemmaLoader, Idefics2Loader, Idefics3Loader, LLaVALoader,
    LLaVANextLoader, LlamaLoader, Loader, LocalModelPaths, MiniCpmOLoader, Mistral3Loader,
    MistralLoader, MixtralLoader, ModelKind, ModelPaths, NormalLoaderType, NormalLoadingMetadata,
    NormalModel, NormalModelLoader, Phi2Loader, Phi3Loader, Phi3VLoader, Phi3_5MoELoader,
    Phi4MMLoader, PrettyName, QuantizationKind, Qwen2Loader, Qwen2VLLoader, Qwen2_5VLLoader,
    Qwen3Loader, Qwen3MoELoader, Starcoder2Loader, TokenSource, VLlama4Loader, VLlamaLoader,
    VisionLoaderType, VisionModel, VisionModelLoader,
};
use mistralrs_quant::IsqType;
pub use normal::{NormalLoader, NormalLoaderBuilder, NormalSpecificConfig};
pub(crate) use paths::{
    get_chat_template, get_model_paths, get_xlora_paths, AdapterPaths, LoraAdapterPaths,
};
pub(crate) use processing::{
    apply_chat_template, BasicProcessor, MessagesAction, Processor, ProcessorCreator,
};
use rand_isaac::Isaac64Rng;
pub use speculative::{SpeculativeConfig, SpeculativeLoader, SpeculativePipeline};
use std::any::Any;
use std::collections::HashMap;
use std::num::NonZeroUsize;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokenizers::Tokenizer;
pub use vision::{VisionLoader, VisionLoaderBuilder, VisionSpecificConfig};

use anyhow::Result;
use candle_core::{DType, Device, IndexOp, Tensor, Var};

use crate::sequence::Sequence;

pub use self::cache_manager::{
    Cache, CacheManager, EitherCache, KvCache, LayerCaches, NormalCache, NormalCacheType,
};
pub use self::inputs_processor::{
    text_models_inputs_processor, InputsProcessor, InputsProcessorType,
};
use self::text_models_inputs_processor::PagedAttentionMeta;
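
/// Static, per-model metadata consulted by the engine on every step: sequence limits,
/// cache configuration, stop tokens, and PagedAttention state.
///
/// A minimal sketch of consulting it during scheduling (illustrative only; assumes a
/// `pipeline: &dyn Pipeline` and a `seq_len` in scope):
/// ```ignore
/// let meta = pipeline.get_metadata();
/// if seq_len > meta.max_seq_len {
///     // Refuse or truncate the request before running the model.
/// }
/// ```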
pub struct GeneralMetadata {
    pub max_seq_len: usize,
    /// Only None if it doesn't make sense for the model
    pub llg_factory: Option<Arc<llguidance::ParserFactory>>,
    pub no_kv_cache: bool,
    pub no_prefix_cache: bool,
    pub num_hidden_layers: usize,
    pub eos_tok: Vec<u32>,
    pub kind: ModelKind,
    // TODO: Replace is_xlora queries to check via kind instead:
    pub is_xlora: bool,
    pub activation_dtype: DType,
    pub sliding_window: Option<usize>,
    // PagedAttention stuff
    pub cache_config: Option<CacheConfig>,
    pub cache_engine: Option<CacheEngine>,
    pub prompt_chunksize: Option<NonZeroUsize>,
    pub model_metadata: Option<Arc<dyn ModelConfigLike + Send + Sync>>,
}

impl GeneralMetadata {
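    /// Returns the tokenizer environment of the constraint parser factory, if one was
    /// built for this model.
    ///
    /// A minimal usage sketch (illustrative; assumes `metadata: Arc<GeneralMetadata>`):
    /// ```ignore
    /// if let Some(tok_env) = metadata.tok_env() {
    ///     // Hand tok_env to llguidance to construct grammar-constrained parsers.
    /// }
    /// ```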
    pub fn tok_env(&self) -> Option<TokEnv> {
        self.llg_factory.as_ref().map(|f| f.tok_env().clone())
    }
}
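
/// An instruction for manipulating the model KV cache before or after a step.
///
/// A minimal sketch of a typical prompt-step pairing (illustrative only): reset the
/// cache before the prompt runs, then clone it back out to the sequences afterwards.
/// ```ignore
/// let pre_op = CacheInstruction::Reset {
///     load_preallocated_cache: true,
///     reset_non_granular: false,
/// };
/// let post_op = CacheInstruction::Out;
/// ```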
pub enum CacheInstruction {
    In,
    Out,
    /// `load_preallocated_cache` means to load the preallocated cache, if applicable.
    Reset {
        load_preallocated_cache: bool,
        reset_non_granular: bool,
    },
    Nothing,
}

pub trait PreProcessingMixin: MetadataMixin {
    fn get_processor(&self) -> Arc<dyn Processor> {
        Arc::new(BasicProcessor)
    }
    /// Only None if it doesn't make sense for the model
    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>>;
    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>>;
}

pub trait IsqPipelineMixin {
    fn re_isq_model(&mut self, dtype: IsqType) -> Result<()>;
}

pub trait CacheManagerMixin {
    /// Clone the cache FROM the sequences' cache TO the model cache. Only called for completion seqs.
    /// This is not guaranteed to be called for each completion step.
    fn clone_in_cache(&self, seqs: &mut [&mut Sequence]);
    /// Clone the cache FROM the model cache TO the sequences. Called for prompt and completion seqs.
    /// This is not guaranteed to be called for each step.
    fn clone_out_cache(&self, seqs: &mut [&mut Sequence]);
    /// Set the model cache to all None. Only called for prompt seqs.
    /// This is not guaranteed to be called for each prompt step.
    /// This may also reset the non-granular state if applicable.
    fn set_none_cache(
        &self,
        seqs: &mut [&mut Sequence],
        reset_non_granular: bool,
        modify_draft_cache: bool,
        load_preallocated_cache: bool,
    );
    fn cache(&self) -> &EitherCache;
    fn do_preallocated_cache(&self) -> bool {
        matches!(self.cache(), EitherCache::Normal(_))
    }
}

pub trait MetadataMixin {
    fn device(&self) -> Device;
    /// Only None if it doesn't make sense for the model
    fn tokenizer(&self) -> Option<Arc<Tokenizer>>;
    fn name(&self) -> String;
    fn reset_non_granular_state(&self);
    fn get_metadata(&self) -> Arc<GeneralMetadata>;
    fn device_mapper(&self) -> Option<&dyn DeviceMapper>;
}

/// Implemented by the base model of an AnyMoe.
pub trait AnyMoePipelineMixin {
    /// Get vars for each gating layer
    fn amoe_layer_vars(&self) -> Vec<Vec<Var>> {
        unreachable!()
    }
    fn amoe_finish_training(&mut self, _gate_model_id: Option<String>) -> candle_core::Result<()> {
        unreachable!()
    }
    fn amoe_base_model_trainable_params(&self) -> usize {
        unreachable!()
    }
    fn amoe_supported(&self) -> bool {
        false
    }
    /// Per-layer cached outputs.
    fn amoe_take_cached_gating_outputs(&mut self) -> Vec<Tensor> {
        unreachable!()
    }
    /// Inject the MoE layers
    #[allow(clippy::too_many_arguments)]
    fn amoe_create_layers(
        &mut self,
        _model_ids: Vec<String>,
        _token: &TokenSource,
        _revision: Option<String>,
        _match_regex: &str,
        _config: AnyMoeConfig,
        _dtype: DType,
        _dev: &Device,
        (_prefix, _mlp): (String, String),
        _layers: Vec<usize>,
        _expert_type: AnyMoeExpertType,
        _silent: bool,
        _gate_model_id: Option<String>,
    ) -> candle_core::Result<()> {
        unreachable!()
    }
    /// Pre-train the gating layers
    #[allow(clippy::too_many_arguments)]
    fn amoe_pre_train(
        &self,
        _inputs: AnyMoeTrainingInputs,
        (_prefix, _mlp): (String, String),
        _model_ids: Vec<String>,
        _token: TokenSource,
        _revision: Option<String>,
        _layers: Vec<usize>,
        _silent: bool,
    ) -> Result<Option<AnyMoeTrainingResult>, candle_core::Error> {
        unreachable!()
    }
}

/// Category of the model. This can also be used to extract model-category specific tools,
/// such as the vision model prompt prefixer.
#[derive(Clone)]
pub enum ModelCategory {
    Text,
    Vision {
        has_conv2d: bool,
        prefixer: Arc<dyn VisionPromptPrefixer>,
    },
    Diffusion,
}

impl PartialEq for ModelCategory {
    fn eq(&self, other: &Self) -> bool {
        match (self, other) {
            (Self::Text, Self::Text) => true,
            (Self::Vision { .. }, Self::Vision { .. }) => true,
            (Self::Diffusion, Self::Diffusion) => true,
            (Self::Text, _) => false,
            (Self::Vision { .. }, _) => false,
            (Self::Diffusion, _) => false,
        }
    }
}

/// Prepend a vision tag appropriate for the model to the prompt. Image indexing is assumed to start at 0.
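///
/// A hypothetical implementation sketch (the `<image_N>` tag format below is
/// illustrative, not any real model's syntax):
/// ```ignore
/// struct SimplePrefixer;
/// impl VisionPromptPrefixer for SimplePrefixer {
///     fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
///         let tags: String = image_indexes.iter().map(|i| format!("<image_{i}>")).collect();
///         format!("{tags}{prompt}")
///     }
/// }
/// ```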
pub trait VisionPromptPrefixer: Send + Sync {
    /// Prefix for inclusion in messages (may do nothing if the chat template handles it).
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String;
}

pub enum CacheBackendMetadata<'a> {
    DefaultInstructions {
        pre_op: CacheInstruction,
        post_op: CacheInstruction,
    },
    PagedAttention {
        metadata: PagedAttentionMeta<'a>,
        blocks_to_swap_in: HashMap<usize, usize>,
        blocks_to_swap_out: HashMap<usize, usize>,
        blocks_to_copy: HashMap<usize, Vec<usize>>,
    },
}
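
/// The result of a single model forward pass, before any per-sequence indexing.
///
/// A minimal consumption sketch (illustrative only):
/// ```ignore
/// match result {
///     ForwardInputsResult::RawLogits { logits } => { /* return logits verbatim */ }
///     ForwardInputsResult::CausalGeneration { logits } => { /* sample next tokens */ }
///     ForwardInputsResult::Image { images } => { /* send generated images */ }
/// }
/// ```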
#[derive(Clone, Debug)]
pub enum ForwardInputsResult {
    RawLogits { logits: Tensor },
    CausalGeneration { logits: Tensor },
    Image { images: Vec<DynamicImage> },
}

impl ForwardInputsResult {
    fn index_bs(&self, bs_idx: usize) -> candle_core::Result<Self> {
        match self {
            Self::CausalGeneration { logits } => Ok(Self::CausalGeneration {
                logits: logits.i(bs_idx)?,
            }),
            Self::RawLogits { logits } => Ok(Self::RawLogits {
                logits: logits.i(bs_idx)?,
            }),
            Self::Image { images } => Ok(Self::Image {
                images: vec![images[bs_idx].clone()],
            }),
        }
    }

    fn to_device(&self, device: &Device) -> candle_core::Result<Self> {
        match self {
            Self::CausalGeneration { logits } => Ok(Self::CausalGeneration {
                logits: logits.to_device(device)?,
            }),
            Self::RawLogits { logits } => Ok(Self::RawLogits {
                logits: logits.to_device(device)?,
            }),
            Self::Image { .. } => Ok(self.clone()),
        }
    }
}
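
/// The core inference interface: a loaded model together with its tokenization,
/// input-processing, and cache-management machinery.
///
/// A minimal sketch of driving one prompt step (illustrative; assumes `pipeline`,
/// `seqs`, `prefix_cacher`, and `rng` are already constructed):
/// ```ignore
/// let exec_time = pipeline
///     .step(
///         &mut seqs,
///         true,  // is_prompt
///         false, // return_raw_logits
///         &mut prefix_cacher,
///         false, // disable_eos_stop
///         rng.clone(),
///         CacheBackendMetadata::DefaultInstructions {
///             pre_op: CacheInstruction::Reset {
///                 load_preallocated_cache: true,
///                 reset_non_granular: false,
///             },
///             post_op: CacheInstruction::Out,
///         },
///     )
///     .await?;
/// ```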
#[async_trait::async_trait]
pub trait Pipeline:
    Send
    + Sync
    + PreProcessingMixin
    + IsqPipelineMixin
    + CacheManagerMixin
    + MetadataMixin
    + AnyMoePipelineMixin
{
    fn forward_inputs(
        &mut self,
        inputs: Box<dyn Any>,
        return_raw_logits: bool,
    ) -> Result<ForwardInputsResult, candle_core::Error>;

    /// Returns the total model execution time.
    #[allow(clippy::too_many_arguments)]
    async fn step(
        &mut self,
        input_seqs: &mut [&mut Sequence],
        is_prompt: bool,
        return_raw_logits: bool,
        prefix_cacher: &mut PrefixCacheManagerV2,
        disable_eos_stop: bool,
        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
        backend_metadata: CacheBackendMetadata<'_>,
    ) -> Result<Duration, candle_core::Error> {
        match backend_metadata {
            CacheBackendMetadata::DefaultInstructions { pre_op, post_op } => {
                let inputs_iter = self.get_processor().inputs_processor().process_inputs(
                    self.tokenizer(),
                    input_seqs,
                    is_prompt,
                    self.get_metadata().is_xlora,
                    &self.device(),
                    self.get_metadata().no_kv_cache,
                    None,
                    return_raw_logits,
                    self.get_input_processor_config(),
                    None,
                    self.get_metadata().prompt_chunksize,
                    self.device_mapper(),
                );

                let mut logits = vec![None; input_seqs.len()];
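                // Prompts may be processed in several chunks of size `prompt_chunksize`;
                // `raw_out_logits` below collects per-chunk CPU logits whenever raw logits
                // were requested instead of sampling.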
                let prompt_chunksize = self
                    .get_metadata()
                    .prompt_chunksize
                    .map(NonZeroUsize::get)
                    .unwrap_or(1);
                let len_inputs = input_seqs
                    .iter()
                    .map(|seq| seq.get_toks().len().div_ceil(prompt_chunksize))
                    .max()
                    .unwrap();
                let mut raw_out_logits = vec![vec![None; len_inputs]; input_seqs.len()];

                let mut exec_duration = Duration::ZERO;
                for (i, inputs) in inputs_iter.into_iter().enumerate() {
                    let InputProcessorOutput {
                        inputs,
                        seq_indices,
                    } = inputs.map_err(candle_core::Error::msg)?;
                    if i == 0 {
                        match pre_op {
                            CacheInstruction::In => self.clone_in_cache(input_seqs),
                            CacheInstruction::Nothing => (),
                            CacheInstruction::Reset {
                                load_preallocated_cache,
                                reset_non_granular,
                            } => self.set_none_cache(
                                input_seqs,
                                reset_non_granular,
                                false,
                                load_preallocated_cache,
                            ),
                            _ => unreachable!("Unreachable PRE cache op."),
                        }
                    }

                    let start = Instant::now();
                    let raw_logits = self.forward_inputs(inputs, return_raw_logits)?;
                    let end = Instant::now();
                    exec_duration += end.duration_since(start);

                    for (logit_idx, seq_idx) in seq_indices.into_iter().enumerate() {
                        if let ForwardInputsResult::RawLogits { logits } = &raw_logits {
                            raw_out_logits[seq_idx][i] =
                                Some(logits.i(logit_idx)?.to_device(&Device::Cpu)?);
                        } else {
                            logits[seq_idx] = Some(raw_logits.index_bs(logit_idx)?);
                        }
                    }
                }

                match post_op {
                    CacheInstruction::Out => self.clone_out_cache(input_seqs),
                    CacheInstruction::Nothing => (),
                    CacheInstruction::Reset {
                        load_preallocated_cache,
                        reset_non_granular,
                    } => self.set_none_cache(
                        input_seqs,
                        reset_non_granular,
                        false,
                        load_preallocated_cache,
                    ),
                    _ => unreachable!("Unreachable POST cache op."),
                }

                if raw_out_logits[0][0].is_some() {
                    let start = Instant::now();
                    response::send_raw_responses(
                        input_seqs,
                        raw_out_logits
                            .into_iter()
                            .map(|raw| raw.into_iter().flatten().collect::<Vec<_>>())
                            .collect(),
                    )
                    .await?;
                    let end = Instant::now();
                    exec_duration += end.duration_since(start);

                    return Ok(exec_duration);
                }

                let start = Instant::now();
                let logits = logits
                    .into_iter()
                    .map(|l| {
                        l.expect("Did not get any inputs. This is shocking.")
                            .to_device(&Device::Cpu)
                    })
                    .collect::<candle_core::Result<Vec<_>>>()?;

                match &logits[0] {
                    ForwardInputsResult::RawLogits { .. } => unreachable!(),
                    ForwardInputsResult::CausalGeneration { .. } => {
                        self.sample_causal_gen(
                            input_seqs,
                            logits
                                .into_iter()
                                .map(|r| {
                                    #[allow(irrefutable_let_patterns)]
                                    let ForwardInputsResult::CausalGeneration { logits } = r
                                    else {
                                        unreachable!(
                                            "All results must have same type, `CausalGeneration`"
                                        )
                                    };
                                    logits
                                })
                                .collect::<Vec<_>>(),
                            prefix_cacher,
                            disable_eos_stop,
                            rng,
                        )
                        .await?;
                    }
                    ForwardInputsResult::Image { .. } => {
                        response::send_image_responses(
                            input_seqs,
                            logits
                                .into_iter()
                                .map(|r| {
                                    #[allow(irrefutable_let_patterns)]
                                    let ForwardInputsResult::Image { images } = r
                                    else {
                                        unreachable!(
                                            "All results must have same type, `Image`"
                                        )
                                    };
                                    images
                                        .into_iter()
                                        .next()
                                        .expect("Must have at least 1 element.")
                                })
                                .collect::<Vec<_>>(),
                        )
                        .await?;
                    }
                }
                let end = Instant::now();
                exec_duration += end.duration_since(start);

                Ok(exec_duration)
            }
            CacheBackendMetadata::PagedAttention {
                metadata,
                blocks_to_copy,
                blocks_to_swap_in,
                blocks_to_swap_out,
            } => {
                // Cloning might be bad?
                self.get_metadata()
                    .cache_engine
                    .as_ref()
                    .expect("PagedAttention must have cache engines.")
                    .execute_scheduler_ops(
                        &blocks_to_swap_in,
                        &blocks_to_swap_out,
                        &blocks_to_copy,
                    )?;
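
                // With PagedAttention, the block swaps/copies above are the only cache
                // management needed, so unlike the default backend this branch has no
                // pre/post cache instructions.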
                let inputs_iter = self.get_processor().inputs_processor().process_inputs(
                    self.tokenizer(),
                    input_seqs,
                    is_prompt,
                    self.get_metadata().is_xlora,
                    &self.device(),
                    self.get_metadata().no_kv_cache,
                    None,
                    return_raw_logits,
                    self.get_input_processor_config(),
                    Some(metadata),
                    self.get_metadata().prompt_chunksize,
                    self.device_mapper(),
                );

                let mut logits = vec![None; input_seqs.len()];
                let prompt_chunksize = self
                    .get_metadata()
                    .prompt_chunksize
                    .map(NonZeroUsize::get)
                    .unwrap_or(1);
                let len_inputs = input_seqs
                    .iter()
                    .map(|seq| seq.get_toks().len().div_ceil(prompt_chunksize))
                    .max()
                    .unwrap();
                let mut raw_out_logits = vec![vec![None; len_inputs]; input_seqs.len()];

                let mut exec_duration = Duration::ZERO;
                for (i, inputs) in inputs_iter.into_iter().enumerate() {
                    let InputProcessorOutput {
                        inputs,
                        seq_indices,
                    } = inputs.map_err(candle_core::Error::msg)?;

                    let start = Instant::now();
                    let raw_logits = self.forward_inputs(inputs, return_raw_logits)?;
                    let end = Instant::now();
                    exec_duration += end.duration_since(start);

                    for (logit_idx, seq_idx) in seq_indices.into_iter().enumerate() {
                        if let ForwardInputsResult::RawLogits { logits } = &raw_logits {
                            raw_out_logits[seq_idx][i] =
                                Some(logits.i(logit_idx)?.to_device(&Device::Cpu)?);
                        } else {
                            logits[seq_idx] = Some(raw_logits.index_bs(logit_idx)?);
                        }
                    }
                }

                if raw_out_logits[0][0].is_some() {
                    let start = Instant::now();
                    response::send_raw_responses(
                        input_seqs,
                        raw_out_logits
                            .into_iter()
                            .map(|raw| raw.into_iter().flatten().collect::<Vec<_>>())
                            .collect(),
                    )
                    .await?;
                    let end = Instant::now();
                    exec_duration += end.duration_since(start);

                    return Ok(exec_duration);
                }

                let start = Instant::now();
                let logits = logits
                    .into_iter()
                    .map(|l| {
                        l.expect("Did not get any inputs. This is shocking.")
                            .to_device(&Device::Cpu)
                    })
                    .collect::<candle_core::Result<Vec<_>>>()?;

                match &logits[0] {
                    ForwardInputsResult::RawLogits { .. } => unreachable!(),
                    ForwardInputsResult::CausalGeneration { .. } => {
                        self.sample_causal_gen(
                            input_seqs,
                            logits
                                .into_iter()
                                .map(|r| {
                                    #[allow(irrefutable_let_patterns)]
                                    let ForwardInputsResult::CausalGeneration { logits } = r
                                    else {
                                        unreachable!("All results must have same type")
                                    };
                                    logits
                                })
                                .collect::<Vec<_>>(),
                            prefix_cacher,
                            disable_eos_stop,
                            rng,
                        )
                        .await?;
                    }
                    ForwardInputsResult::Image { .. } => {
                        response::send_image_responses(
                            input_seqs,
                            logits
                                .into_iter()
                                .map(|r| {
                                    #[allow(irrefutable_let_patterns)]
                                    let ForwardInputsResult::Image { images } = r
                                    else {
                                        unreachable!(
                                            "All results must have same type, `Image`"
                                        )
                                    };
                                    images
                                        .into_iter()
                                        .next()
                                        .expect("Must have at least 1 element.")
                                })
                                .collect::<Vec<_>>(),
                        )
                        .await?;
                    }
                }
                let end = Instant::now();
                exec_duration += end.duration_since(start);

                Ok(exec_duration)
            }
        }
    }

    async fn sample_causal_gen(
        &self,
        seqs: &mut [&mut Sequence],
        logits: Vec<Tensor>,
        prefix_cacher: &mut PrefixCacheManagerV2,
        disable_eos_stop: bool,
        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
    ) -> Result<(), candle_core::Error>;

    fn category(&self) -> ModelCategory;
}
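
/// For each sequence in the batch, keep only the logits of its sampled context window.
///
/// A minimal shape sketch (illustrative): all `len`s must be equal so the slices can be
/// concatenated back along the batch dimension.
/// ```ignore
/// // logits: (2, seq_len, vocab); keep position 3 of seq 0 and position 1 of seq 1.
/// let out = extract_logits(&logits, vec![(3, 1), (1, 1)])?; // shape (2, 1, vocab)
/// ```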
pub(crate) fn extract_logits(
    logits: &Tensor,
    context_lens: Vec<(usize, usize)>,
) -> candle_core::Result<Tensor> {
    let mut toks = Vec::new();
    for (dim, (start, len)) in logits.chunk(logits.dims()[0], 0)?.iter().zip(context_lens) {
        toks.push(dim.narrow(1, start, len)?);
    }
    Tensor::cat(&toks, 0)
}

#[cfg(test)]
mod tests {
    use crate::MessageContent;
    use either::Either;
    use indexmap::IndexMap;
    use serde_json::Value;

    macro_rules! hashmap {
        (@single $($x:tt)*) => (());
        (@count $($rest:expr),*) => (<[()]>::len(&[$(hashmap!(@single $rest)),*]));

        ($($key:expr => $value:expr,)+) => { hashmap!($($key => $value),+) };
        ($($key:expr => $value:expr),*) => {
            {
                let _cap = hashmap!(@count $($key),*);
                let mut _map = ::indexmap::IndexMap::with_capacity(_cap);
                $(
                    let _ = _map.insert($key, Value::String($value));
                )*
                _map
            }
        };
    }

    #[track_caller]
    fn test_with_inputs(
        templates: &[(bool, &str, &str, &str, &str)],
        expected_outputs: &[&str],
        inputs: Vec<IndexMap<String, MessageContent>>,
    ) {
        use crate::pipeline::chat_template::ChatTemplateValue;

        use super::chat_template::apply_chat_template_to;
        let mut failed = Vec::new();
        let n_templates = templates.len();
        for ((has_system, bos, eos, unk, template), expected) in
            templates.iter().zip(expected_outputs)
        {
            let output = match apply_chat_template_to(
                if !has_system {
                    inputs[1..].to_vec()
                } else {
                    inputs.clone()
                },
                true,
                None,
                &ChatTemplateValue(Either::Left(template.to_string())),
                Some(bos.to_string()),
                Some(eos.to_string()),
                Some(unk.to_string()),
                Vec::new(),
            ) {
                Ok(v) => v,
                Err(e) => {
                    failed.push(format!("Failed with {e}."));
                    continue;
                }
            };
            if output != *expected {
                failed.push(format!(
                    "Expected: `{}` \n\nGot:      `{}`",
                    expected.replace('\n', "\\n"),
                    output.replace('\n', "\\n")
                ));
            }
        }
        if !failed.is_empty() {
            for (i, line) in failed.iter().enumerate() {
                println!("------------ Template {i} ------------");
                println!("{line}");
            }
            println!("------------------------");
            panic!("{}/{n_templates} chat templates failed.", failed.len());
        }
    }

    #[test]
    /// Generating these cases:
    /// ```py
    /// >>> t=transformers.AutoTokenizer.from_pretrained(...)
    /// # If non-system prompt model
    /// >>> t.apply_chat_template([{"role":"user","content":"Hello"},{"role":"assistant","content":"Hi there"},{"role":"user","content":"Who are you"},{"role":"assistant","content":"   I am an assistant   "},{"role":"user","content":"Another question"}], add_generation_prompt=True, tokenize=False)
    /// # If system prompt model
    /// >>> t.apply_chat_template([{"role":"system","content":"You are a helpful assistant"},{"role":"user","content":"Hello"},{"role":"assistant","content":"Hi there"},{"role":"user","content":"Who are you"},{"role":"assistant","content":"   I am an assistant   "},{"role":"user","content":"Another question"}], add_generation_prompt=True, tokenize=False)
    /// ```
    fn test_chat_templates() {
        let templates = [
            // ChatML: https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B
            (true, "<s>", "</s>", "<unk>", "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"),
            // mistralai/Mistral-7B-Instruct-v0.1
            (false, "<s>", "</s>", "<unk>", "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"),
            // meta-llama/Llama-2-13b-chat-hf
            (true, "<s>", "</s>", "<unk>", "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' '  + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}"),
            // mistralai/Mixtral-8x7B-Instruct-v0.1
            (false, "<s>", "</s>", "<unk>", "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"),
            // google/gemma-7b-it
            (false, "<bos>", "<eos>", "<unk>", "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}"),
            // HuggingFaceM4/idefics2-8b-chatty
            (true, "<s>", "</s>", "<unk>", "{% for message in messages %}{{message['role'].capitalize()}}{% if message['content'][0]['type'] == 'image' %}{{':'}}{% else %}{{': '}}{% endif %}{% for line in message['content'] %}{% if line['type'] == 'text' %}{{line['text']}}{% elif line['type'] == 'image' %}{{ '<image>' }}{% endif %}{% endfor %}<end_of_utterance>\n{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"),
        ];
        let expected_outputs = [
            // ChatML: https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B
            "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi there<|im_end|>\n<|im_start|>user\nWho are you<|im_end|>\n<|im_start|>assistant\n   I am an assistant   <|im_end|>\n<|im_start|>user\nAnother question<|im_end|>\n<|im_start|>assistant\n",
            // mistralai/Mistral-7B-Instruct-v0.1
            "<s>[INST] Hello [/INST]Hi there</s> [INST] Who are you [/INST]   I am an assistant   </s> [INST] Another question [/INST]",
            // meta-llama/Llama-2-13b-chat-hf
            "<s>[INST] <<SYS>>\nYou are a helpful assistant\n<</SYS>>\n\nHello [/INST] Hi there </s><s>[INST] Who are you [/INST] I am an assistant </s><s>[INST] Another question [/INST]",
            // mistralai/Mixtral-8x7B-Instruct-v0.1
            "<s>[INST] Hello [/INST]Hi there</s>[INST] Who are you [/INST]   I am an assistant   </s>[INST] Another question [/INST]",
            // google/gemma-7b-it
            "<bos><start_of_turn>user\nHello<end_of_turn>\n<start_of_turn>model\nHi there<end_of_turn>\n<start_of_turn>user\nWho are you<end_of_turn>\n<start_of_turn>model\nI am an assistant<end_of_turn>\n<start_of_turn>user\nAnother question<end_of_turn>\n<start_of_turn>model\n",
        ];
        let messages = [
            ["system", "You are a helpful assistant"],
            ["user", "Hello"],
            ["assistant", "Hi there"],
            ["user", "Who are you"],
            ["assistant", "   I am an assistant   "],
            ["user", "Another question"],
        ];
        let mut inputs = Vec::new();
        for [role, content] in messages {
            let mut message: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
                IndexMap::new();
            message.insert("role".to_string(), Either::Left(role.to_string()));
            message.insert("content".to_string(), Either::Left(content.to_string()));
            inputs.push(message);
        }
        test_with_inputs(&templates, &expected_outputs, inputs);
    }

    #[test]
    /// Generating these cases:
    /// ```py
    /// >>> processor=transformers.AutoProcessor.from_pretrained(...)
    /// >>> processor.apply_chat_template([
    ///         {"role":"system","content":[{"type":"text", "text": "You are a helpful assistant"}]},
    ///         {"role":"user","content":[{"type":"image"}, {"type":"text", "text": "Hello, please describe the above."}]},
    ///         {"role":"assistant","content":[{"type":"text", "text": "Hi there"}]},
    ///         {"role":"user","content":[{"type":"image"}, {"type":"text", "text": "This is me, who are you"}]},
    ///         {"role":"assistant","content":[{"type":"text", "text": "   I am an assistant   "}]},
    ///         {"role":"user","content":[{"type":"image"}, {"type":"text", "text": "Another question, what is this?"}]}
    ///     ], add_generation_prompt=True, tokenize=False)
    /// ```
    fn test_image_chat_templates() {
        let templates = [
            // HuggingFaceM4/idefics2-8b-chatty
            (true, "<s>", "</s>", "<unk>", "{% for message in messages %}{{message['role'].capitalize()}}{% if message['content'][0]['type'] == 'image' %}{{':'}}{% else %}{{': '}}{% endif %}{% for line in message['content'] %}{% if line['type'] == 'text' %}{{line['text']}}{% elif line['type'] == 'image' %}{{ '<image>' }}{% endif %}{% endfor %}<end_of_utterance>\n{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"),
        ];
        let expected_outputs = [
            // HuggingFaceM4/idefics2-8b-chatty
            "System: You are a helpful assistant<end_of_utterance>\nUser:<image>Hello, please describe the above.<end_of_utterance>\nAssistant: Hi there<end_of_utterance>\nUser:<image>This is me, who are you<end_of_utterance>\nAssistant:    I am an assistant   <end_of_utterance>\nUser:<image>Another question, what is this?<end_of_utterance>\nAssistant:",
        ];

        let mut inputs = Vec::new();

        let mut message: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
            IndexMap::new();
        message.insert("role".to_string(), Either::Left("system".to_string()));
        message.insert(
            "content".to_string(),
            Either::Right(vec![hashmap! {
                "type".to_string() => "text".to_string(),
                "text".to_string() => "You are a helpful assistant".to_string()
            }]),
        );
        inputs.push(message);

        let mut message: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
            IndexMap::new();
        message.insert("role".to_string(), Either::Left("user".to_string()));
        message.insert(
            "content".to_string(),
            Either::Right(vec![
                hashmap! {
                    "type".to_string() => "image".to_string()
                },
                hashmap! {
                    "type".to_string() => "text".to_string(),
                    "text".to_string() => "Hello, please describe the above.".to_string()
                },
            ]),
        );
        inputs.push(message);

        let mut message: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
            IndexMap::new();
        message.insert("role".to_string(), Either::Left("assistant".to_string()));
        message.insert(
            "content".to_string(),
            Either::Right(vec![hashmap! {
                "type".to_string() => "text".to_string(),
                "text".to_string() => "Hi there".to_string()
            }]),
        );
        inputs.push(message);

        let mut message: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
            IndexMap::new();
        message.insert("role".to_string(), Either::Left("user".to_string()));
        message.insert(
            "content".to_string(),
            Either::Right(vec![
                hashmap! {
                    "type".to_string() => "image".to_string()
                },
                hashmap! {
                    "type".to_string() => "text".to_string(),
                    "text".to_string() => "This is me, who are you".to_string()
                },
            ]),
        );
        inputs.push(message);

        let mut message: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
            IndexMap::new();
        message.insert("role".to_string(), Either::Left("assistant".to_string()));
        message.insert(
            "content".to_string(),
            Either::Right(vec![hashmap! {
                "type".to_string() => "text".to_string(),
                "text".to_string() => "   I am an assistant   ".to_string()
            }]),
        );
        inputs.push(message);

        let mut message: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
            IndexMap::new();
        message.insert("role".to_string(), Either::Left("user".to_string()));
        message.insert(
            "content".to_string(),
            Either::Right(vec![
                hashmap! {
                    "type".to_string() => "image".to_string()
                },
                hashmap! {
                    "type".to_string() => "text".to_string(),
                    "text".to_string() => "Another question, what is this?".to_string()
                },
            ]),
        );
        inputs.push(message);

        test_with_inputs(&templates, &expected_outputs, inputs);
    }
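
    // Two small additional tests sketching pure helpers in this module. They are
    // illustrative: the tensor shapes and the `NoopPrefixer` type are assumptions of
    // the tests themselves, not part of any model.
    #[test]
    fn test_extract_logits_narrows_context() -> candle_core::Result<()> {
        use candle_core::{DType, Device, Tensor};
        // Batch of 2 sequences, seq_len 4, vocab 3; keep one position per sequence.
        let logits = Tensor::zeros((2, 4, 3), DType::F32, &Device::Cpu)?;
        let out = super::extract_logits(&logits, vec![(3, 1), (1, 1)])?;
        assert_eq!(out.dims(), &[2, 1, 3]);
        Ok(())
    }

    #[test]
    fn test_model_category_eq_is_variant_based() {
        use super::{ModelCategory, VisionPromptPrefixer};
        use std::sync::Arc;
        // Equality compares variants only; Vision's fields are deliberately ignored.
        struct NoopPrefixer;
        impl VisionPromptPrefixer for NoopPrefixer {
            fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
                prompt.to_string()
            }
        }
        let a = ModelCategory::Vision {
            has_conv2d: true,
            prefixer: Arc::new(NoopPrefixer),
        };
        let b = ModelCategory::Vision {
            has_conv2d: false,
            prefixer: Arc::new(NoopPrefixer),
        };
        assert!(a == b);
        assert!(ModelCategory::Text != ModelCategory::Diffusion);
    }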
}