// mistralrs_core/pipeline/mod.rs

mod amoe;
mod cache_manager;
pub mod chat_template;
mod diffusion;
mod ggml;
mod gguf;
mod inputs_processor;
mod isq;
pub(crate) mod llg;
mod loaders;
mod macros;
mod normal;
mod paths;
mod processing;
mod response;
mod sampling;
mod speculative;
mod vision;

pub use super::diffusion_models::DiffusionGenerationParams;
use crate::amoe::{AnyMoeConfig, AnyMoeExpertType, AnyMoeTrainingInputs, AnyMoeTrainingResult};
use crate::device_map::DeviceMapper;
use crate::paged_attention::{CacheConfig, CacheEngine, ModelConfigLike};
use crate::prefix_cacher::PrefixCacheManagerV2;
pub use amoe::{AnyMoeLoader, AnyMoePipeline};
use chat_template::ChatTemplate;
pub use diffusion::{DiffusionLoader, DiffusionLoaderBuilder, DiffusionSpecificConfig};
pub use ggml::{GGMLLoader, GGMLLoaderBuilder, GGMLSpecificConfig};
pub use gguf::{GGUFLoader, GGUFLoaderBuilder, GGUFSpecificConfig};
use image::DynamicImage;
pub use inputs_processor::InputProcessorOutput;
pub(crate) use isq::IsqModelLoader;
pub use isq::{parse_isq_value, IsqModel, IsqOrganization};
pub use loaders::{
    AdapterKind, AutoDeviceMapParams, AutoLoader, DeepSeekV2Loader, DeepSeekV3Loader,
    DeviceMappedModelLoader, DiffusionLoaderType, DiffusionModel, DiffusionModelLoader, FluxLoader,
    Gemma2Loader, Gemma3Loader, GemmaLoader, Idefics2Loader, Idefics3Loader, LLaVALoader,
    LLaVANextLoader, LlamaLoader, Loader, LocalModelPaths, MiniCpmOLoader, Mistral3Loader,
    MistralLoader, MixtralLoader, ModelKind, ModelPaths, NormalLoaderType, NormalLoadingMetadata,
    NormalModel, NormalModelLoader, Phi2Loader, Phi3Loader, Phi3VLoader, Phi3_5MoELoader,
    Phi4MMLoader, PrettyName, QuantizationKind, Qwen2Loader, Qwen2VLLoader, Qwen2_5VLLoader,
    Starcoder2Loader, TokenSource, VLlamaLoader, VisionLoaderType, VisionModel, VisionModelLoader,
};
use mistralrs_quant::IsqType;
pub use normal::{NormalLoader, NormalLoaderBuilder, NormalSpecificConfig};
pub(crate) use paths::{
    get_chat_template, get_model_paths, get_xlora_paths, AdapterPaths, LoraAdapterPaths,
};
pub(crate) use processing::{
    apply_chat_template, BasicProcessor, MessagesAction, Processor, ProcessorCreator,
};
use rand_isaac::Isaac64Rng;
pub use speculative::{SpeculativeConfig, SpeculativeLoader, SpeculativePipeline};
use std::any::Any;
use std::collections::HashMap;
use std::num::NonZeroUsize;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokenizers::Tokenizer;
pub use vision::{VisionLoader, VisionLoaderBuilder, VisionSpecificConfig};

use anyhow::Result;
use candle_core::{DType, Device, IndexOp, Tensor, Var};

use crate::sequence::Sequence;

pub use self::cache_manager::{
    Cache, CacheManager, EitherCache, KvCache, LayerCaches, NormalCache, NormalCacheType,
    RotatingCache, SingleCache,
};
pub use self::inputs_processor::{
    text_models_inputs_processor, InputsProcessor, InputsProcessorType,
};
use self::text_models_inputs_processor::PagedAttentionMeta;
pub struct GeneralMetadata {
    pub max_seq_len: usize,
    /// Only None if it doesn't make sense for the model
    pub tok_env: Option<llguidance::toktrie::TokEnv>,
    pub no_kv_cache: bool,
    pub no_prefix_cache: bool,
    pub num_hidden_layers: usize,
    pub eos_tok: Vec<u32>,
    pub kind: ModelKind,
    // TODO: Replace is_xlora queries to check via kind instead:
    pub is_xlora: bool,
    pub activation_dtype: DType,
    pub sliding_window: Option<usize>,
    // PagedAttention stuff
    pub cache_config: Option<CacheConfig>,
    pub cache_engine: Option<CacheEngine>,
    pub prompt_chunksize: Option<NonZeroUsize>,
    pub model_metadata: Option<Arc<dyn ModelConfigLike + Send + Sync>>,
}

pub enum CacheInstruction {
    /// Clone the cache from the sequences into the model cache.
    In,
    /// Clone the cache from the model cache back out to the sequences.
    Out,
    /// Reset the model cache. `load_preallocated_cache` means to load the
    /// preallocated cache, if applicable.
    Reset {
        load_preallocated_cache: bool,
        reset_non_granular: bool,
    },
    Nothing,
}
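
// A sketch of the common pairings, as driven by `Pipeline::step` below
// (illustrative only; the scheduler chooses the actual instructions): a
// prompt step typically resets the cache before the forward pass and clones
// it out afterwards, while a completion step clones the existing cache in
// and back out.
//
// let prompt_ops = (
//     CacheInstruction::Reset {
//         load_preallocated_cache: true,
//         reset_non_granular: false,
//     },
//     CacheInstruction::Out,
// );
// let completion_ops = (CacheInstruction::In, CacheInstruction::Out);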

pub trait PreProcessingMixin: MetadataMixin {
    fn get_processor(&self) -> Arc<dyn Processor> {
        Arc::new(BasicProcessor)
    }
    /// Only None if it doesn't make sense for the model
    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>>;
    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>>;
}

pub trait IsqPipelineMixin {
    fn re_isq_model(&mut self, dtype: IsqType) -> Result<()>;
}

pub trait CacheManagerMixin {
    /// Clone the cache FROM the sequences' cache TO the model cache. Only called for completion seqs.
    /// There is no guarantee that this will be called for each completion step.
    fn clone_in_cache(&self, seqs: &mut [&mut Sequence]);
    /// Clone the cache FROM the model cache TO the sequences. Called for prompt and completion seqs.
    /// There is no guarantee that this will be called for each step.
    fn clone_out_cache(&self, seqs: &mut [&mut Sequence]);
    /// Set the model cache to all None. Only called for prompt seqs.
    /// There is no guarantee that this will be called for each prompt step.
    /// This may also reset the non-granular state if applicable.
    fn set_none_cache(
        &self,
        seqs: &mut [&mut Sequence],
        reset_non_granular: bool,
        modify_draft_cache: bool,
        load_preallocated_cache: bool,
    );
    fn cache(&self) -> &EitherCache;
    fn do_preallocated_cache(&self) -> bool {
        matches!(self.cache(), EitherCache::Normal(_))
    }
}
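
// Illustrative call order for a single completion step, mirroring how
// `Pipeline::step` dispatches `CacheInstruction::In`/`Out` (the names
// `pipeline`, `seqs`, and `inputs` are hypothetical):
//
// pipeline.clone_in_cache(&mut seqs);
// let result = pipeline.forward_inputs(inputs, false)?;
// pipeline.clone_out_cache(&mut seqs);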

pub trait MetadataMixin {
    fn device(&self) -> Device;
    /// Only None if it doesn't make sense for the model
    fn tokenizer(&self) -> Option<Arc<Tokenizer>>;
    fn name(&self) -> String;
    fn reset_non_granular_state(&self);
    fn get_metadata(&self) -> Arc<GeneralMetadata>;
    fn device_mapper(&self) -> Option<&dyn DeviceMapper>;
}

/// Implemented by the base model of an AnyMoe.
pub trait AnyMoePipelineMixin {
    /// Get vars for each gating layer
    fn amoe_layer_vars(&self) -> Vec<Vec<Var>> {
        unreachable!()
    }
    fn amoe_finish_training(&mut self, _gate_model_id: Option<String>) -> candle_core::Result<()> {
        unreachable!()
    }
    fn amoe_base_model_trainable_params(&self) -> usize {
        unreachable!()
    }
    fn amoe_supported(&self) -> bool {
        false
    }
    /// Per-layer cached outputs.
    fn amoe_take_cached_gating_outputs(&mut self) -> Vec<Tensor> {
        unreachable!()
    }
    /// Inject the MoE layers
    #[allow(clippy::too_many_arguments)]
    fn amoe_create_layers(
        &mut self,
        _model_ids: Vec<String>,
        _token: &TokenSource,
        _revision: Option<String>,
        _match_regex: &str,
        _config: AnyMoeConfig,
        _dtype: DType,
        _dev: &Device,
        (_prefix, _mlp): (String, String),
        _layers: Vec<usize>,
        _expert_type: AnyMoeExpertType,
        _silent: bool,
        _gate_model_id: Option<String>,
    ) -> candle_core::Result<()> {
        unreachable!()
    }
    /// Pre-train the gating layers
    #[allow(clippy::too_many_arguments)]
    fn amoe_pre_train(
        &self,
        _inputs: AnyMoeTrainingInputs,
        (_prefix, _mlp): (String, String),
        _model_ids: Vec<String>,
        _token: TokenSource,
        _revision: Option<String>,
        _layers: Vec<usize>,
        _silent: bool,
    ) -> Result<Option<AnyMoeTrainingResult>, candle_core::Error> {
        unreachable!()
    }
}

/// Category of the model. This can also be used to extract model-category specific tools,
/// such as the vision model prompt prefixer.
#[derive(Clone)]
pub enum ModelCategory {
    Text,
    Vision {
        has_conv2d: bool,
        prefixer: Arc<dyn VisionPromptPrefixer>,
    },
    Diffusion,
}

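// `PartialEq` is implemented manually because the `Vision` payload holds an
// `Arc<dyn VisionPromptPrefixer>`, which has no meaningful notion of equality;
// only the variant itself is compared.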
impl PartialEq for ModelCategory {
    fn eq(&self, other: &Self) -> bool {
        match (self, other) {
            (Self::Text, Self::Text) => true,
            (Self::Vision { .. }, Self::Vision { .. }) => true,
            (Self::Diffusion, Self::Diffusion) => true,
            (Self::Text, _) => false,
            (Self::Vision { .. }, _) => false,
            (Self::Diffusion, _) => false,
        }
    }
}

/// Prepend a vision tag appropriate for the model to the prompt. Image indices are assumed to start at 0.
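///
/// A minimal illustrative implementation (hypothetical; real prefixers are
/// provided alongside each vision model):
///
/// ```ignore
/// struct ImageTagPrefixer;
///
/// impl VisionPromptPrefixer for ImageTagPrefixer {
///     fn prefix_image(&self, _image_index: usize, prompt: &str) -> String {
///         // `<image>` is an assumed placeholder tag; the real tag is model-specific.
///         format!("<image>{prompt}")
///     }
/// }
/// ```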
pub trait VisionPromptPrefixer: Send + Sync {
    /// Prefix for inclusion in messages (may do nothing if the chat template handles it).
    fn prefix_image(&self, image_index: usize, prompt: &str) -> String;
}

pub enum CacheBackendMetadata<'a> {
    DefaultInstructions {
        pre_op: CacheInstruction,
        post_op: CacheInstruction,
    },
    PagedAttention {
        metadata: PagedAttentionMeta<'a>,
        blocks_to_swap_in: HashMap<usize, usize>,
        blocks_to_swap_out: HashMap<usize, usize>,
        blocks_to_copy: HashMap<usize, Vec<usize>>,
    },
}

#[derive(Clone, Debug)]
pub enum ForwardInputsResult {
    /// Raw logits, returned without sampling when `return_raw_logits` is set.
    RawLogits { logits: Tensor },
    /// Logits to be sampled for causal generation.
    CausalGeneration { logits: Tensor },
    /// Generated images.
    Image { images: Vec<DynamicImage> },
}

impl ForwardInputsResult {
    fn index_bs(&self, bs_idx: usize) -> candle_core::Result<Self> {
        match self {
            Self::CausalGeneration { logits } => Ok(Self::CausalGeneration {
                logits: logits.i(bs_idx)?,
            }),
            Self::RawLogits { logits } => Ok(Self::RawLogits {
                logits: logits.i(bs_idx)?,
            }),
            Self::Image { images } => Ok(Self::Image {
                images: vec![images[bs_idx].clone()],
            }),
        }
    }

    fn to_device(&self, device: &Device) -> candle_core::Result<Self> {
        match self {
            Self::CausalGeneration { logits } => Ok(Self::CausalGeneration {
                logits: logits.to_device(device)?,
            }),
            Self::RawLogits { logits } => Ok(Self::RawLogits {
                logits: logits.to_device(device)?,
            }),
            Self::Image { .. } => Ok(self.clone()),
        }
    }
}

#[async_trait::async_trait]
pub trait Pipeline:
    Send
    + Sync
    + PreProcessingMixin
    + IsqPipelineMixin
    + CacheManagerMixin
    + MetadataMixin
    + AnyMoePipelineMixin
{
    fn forward_inputs(
        &mut self,
        inputs: Box<dyn Any>,
        return_raw_logits: bool,
    ) -> Result<ForwardInputsResult, candle_core::Error>;

    /// Returns the total model execution time.
    #[allow(clippy::too_many_arguments)]
    async fn step(
        &mut self,
        input_seqs: &mut [&mut Sequence],
        is_prompt: bool,
        return_raw_logits: bool,
        prefix_cacher: &mut PrefixCacheManagerV2,
        disable_eos_stop: bool,
        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
        backend_metadata: CacheBackendMetadata<'_>,
    ) -> Result<Duration, candle_core::Error> {
        match backend_metadata {
            CacheBackendMetadata::DefaultInstructions { pre_op, post_op } => {
                let inputs_iter = self.get_processor().inputs_processor().process_inputs(
                    self.tokenizer(),
                    input_seqs,
                    is_prompt,
                    self.get_metadata().is_xlora,
                    &self.device(),
                    self.get_metadata().no_kv_cache,
                    None,
                    return_raw_logits,
                    self.get_input_processor_config(),
                    None,
                    self.get_metadata().prompt_chunksize,
                    self.device_mapper(),
                );

                let mut logits = vec![None; input_seqs.len()];
                let prompt_chunksize = self
                    .get_metadata()
                    .prompt_chunksize
                    .map(NonZeroUsize::get)
                    .unwrap_or(1);
                // With prompt chunking, each sequence is processed over multiple
                // forward passes; size the raw-logits buffer for the largest chunk count.
                let len_inputs = input_seqs
                    .iter()
                    .map(|seq| seq.get_toks().len().div_ceil(prompt_chunksize))
                    .max()
                    .unwrap();
                let mut raw_out_logits = vec![vec![None; len_inputs]; input_seqs.len()];

                let mut exec_duration = Duration::ZERO;
                for (i, inputs) in inputs_iter.into_iter().enumerate() {
                    let InputProcessorOutput {
                        inputs,
                        seq_indices,
                    } = inputs.map_err(candle_core::Error::msg)?;
                    if i == 0 {
                        match pre_op {
                            CacheInstruction::In => self.clone_in_cache(input_seqs),
                            CacheInstruction::Nothing => (),
                            CacheInstruction::Reset {
                                load_preallocated_cache,
                                reset_non_granular,
                            } => self.set_none_cache(
                                input_seqs,
                                reset_non_granular,
                                false,
                                load_preallocated_cache,
                            ),
                            _ => unreachable!("Unreachable PRE cache op."),
                        }
                    }

                    let start = Instant::now();
                    let raw_logits = self.forward_inputs(inputs, return_raw_logits)?;
                    let end = Instant::now();
                    exec_duration += end.duration_since(start);

                    for (logit_idx, seq_idx) in seq_indices.into_iter().enumerate() {
                        if let ForwardInputsResult::RawLogits { logits } = &raw_logits {
                            raw_out_logits[seq_idx][i] =
                                Some(logits.i(logit_idx)?.to_device(&Device::Cpu)?);
                        } else {
                            logits[seq_idx] = Some(raw_logits.index_bs(logit_idx)?);
                        }
                    }
                }

                match post_op {
                    CacheInstruction::Out => self.clone_out_cache(input_seqs),
                    CacheInstruction::Nothing => (),
                    CacheInstruction::Reset {
                        load_preallocated_cache,
                        reset_non_granular,
                    } => self.set_none_cache(
                        input_seqs,
                        reset_non_granular,
                        false,
                        load_preallocated_cache,
                    ),
                    _ => unreachable!("Unreachable POST cache op."),
                }

                if raw_out_logits[0][0].is_some() {
                    let start = Instant::now();
                    response::send_raw_responses(
                        input_seqs,
                        raw_out_logits
                            .into_iter()
                            .map(|raw| raw.into_iter().flatten().collect::<Vec<_>>())
                            .collect(),
                    )
                    .await?;
                    let end = Instant::now();
                    exec_duration += end.duration_since(start);

                    return Ok(exec_duration);
                }

                let logits = logits
                    .into_iter()
                    .map(|l| {
                        l.expect("Did not get any inputs. This is shocking.")
                            .to_device(&Device::Cpu)
                    })
                    .collect::<candle_core::Result<Vec<_>>>()?;

                let start = Instant::now();
                match &logits[0] {
                    ForwardInputsResult::RawLogits { .. } => unreachable!(),
                    ForwardInputsResult::CausalGeneration { .. } => {
                        self.sample_causal_gen(
                            input_seqs,
                            logits
                                .into_iter()
                                .map(|r| {
                                    #[allow(irrefutable_let_patterns)]
                                    let ForwardInputsResult::CausalGeneration { logits } = r
                                    else {
                                        unreachable!(
                                            "All results must have same type, `CausalGeneration`"
                                        )
                                    };
                                    logits
                                })
                                .collect::<Vec<_>>(),
                            prefix_cacher,
                            disable_eos_stop,
                            rng,
                        )
                        .await?;
                    }
                    ForwardInputsResult::Image { .. } => {
                        response::send_image_responses(
                            input_seqs,
                            logits
                                .into_iter()
                                .map(|r| {
                                    #[allow(irrefutable_let_patterns)]
                                    let ForwardInputsResult::Image { images } = r
                                    else {
                                        unreachable!(
                                            "All results must have same type, `Image`"
                                        )
                                    };
                                    images
                                        .into_iter()
                                        .next()
                                        .expect("Must have at least 1 element.")
                                })
                                .collect::<Vec<_>>(),
                        )
                        .await?;
                    }
                }
                let end = Instant::now();
                exec_duration += end.duration_since(start);

                Ok(exec_duration)
            }
            CacheBackendMetadata::PagedAttention {
                metadata,
                blocks_to_copy,
                blocks_to_swap_in,
                blocks_to_swap_out,
            } => {
                // Cloning might be bad?
                self.get_metadata()
                    .cache_engine
                    .as_ref()
                    .expect("PagedAttention must have cache engines.")
                    .execute_scheduler_ops(
                        blocks_to_swap_in.clone(),
                        blocks_to_swap_out.clone(),
                        blocks_to_copy.clone(),
                    )?;

                let inputs_iter = self.get_processor().inputs_processor().process_inputs(
                    self.tokenizer(),
                    input_seqs,
                    is_prompt,
                    self.get_metadata().is_xlora,
                    &self.device(),
                    self.get_metadata().no_kv_cache,
                    None,
                    return_raw_logits,
                    self.get_input_processor_config(),
                    Some(metadata),
                    self.get_metadata().prompt_chunksize,
                    self.device_mapper(),
                );

                let mut logits = vec![None; input_seqs.len()];
                let prompt_chunksize = self
                    .get_metadata()
                    .prompt_chunksize
                    .map(NonZeroUsize::get)
                    .unwrap_or(1);
                let len_inputs = input_seqs
                    .iter()
                    .map(|seq| seq.get_toks().len().div_ceil(prompt_chunksize))
                    .max()
                    .unwrap();
                let mut raw_out_logits = vec![vec![None; len_inputs]; input_seqs.len()];

                let mut exec_duration = Duration::ZERO;
                for (i, inputs) in inputs_iter.into_iter().enumerate() {
                    let InputProcessorOutput {
                        inputs,
                        seq_indices,
                    } = inputs.map_err(candle_core::Error::msg)?;

                    let start = Instant::now();
                    let raw_logits = self.forward_inputs(inputs, return_raw_logits)?;
                    let end = Instant::now();
                    exec_duration += end.duration_since(start);

                    for (logit_idx, seq_idx) in seq_indices.into_iter().enumerate() {
                        if let ForwardInputsResult::RawLogits { logits } = &raw_logits {
                            raw_out_logits[seq_idx][i] =
                                Some(logits.i(logit_idx)?.to_device(&Device::Cpu)?);
                        } else {
                            logits[seq_idx] = Some(raw_logits.index_bs(logit_idx)?);
                        }
                    }
                }

                if raw_out_logits[0][0].is_some() {
                    let start = Instant::now();
                    response::send_raw_responses(
                        input_seqs,
                        raw_out_logits
                            .into_iter()
                            .map(|raw| raw.into_iter().flatten().collect::<Vec<_>>())
                            .collect(),
                    )
                    .await?;
                    let end = Instant::now();
                    exec_duration += end.duration_since(start);

                    return Ok(exec_duration);
                }

                let logits = logits
                    .into_iter()
                    .map(|l| {
                        l.expect("Did not get any inputs. This is shocking.")
                            .to_device(&Device::Cpu)
                    })
                    .collect::<candle_core::Result<Vec<_>>>()?;

                let start = Instant::now();
                match &logits[0] {
                    ForwardInputsResult::RawLogits { .. } => unreachable!(),
                    ForwardInputsResult::CausalGeneration { .. } => {
                        self.sample_causal_gen(
                            input_seqs,
                            logits
                                .into_iter()
                                .map(|r| {
                                    #[allow(irrefutable_let_patterns)]
                                    let ForwardInputsResult::CausalGeneration { logits } = r
                                    else {
                                        unreachable!("All results must have same type")
                                    };
                                    logits
                                })
                                .collect::<Vec<_>>(),
                            prefix_cacher,
                            disable_eos_stop,
                            rng,
                        )
                        .await?;
                    }
                    ForwardInputsResult::Image { .. } => {
                        response::send_image_responses(
                            input_seqs,
                            logits
                                .into_iter()
                                .map(|r| {
                                    #[allow(irrefutable_let_patterns)]
                                    let ForwardInputsResult::Image { images } = r
                                    else {
                                        unreachable!(
                                            "All results must have same type, `Image`"
                                        )
                                    };
                                    images
                                        .into_iter()
                                        .next()
                                        .expect("Must have at least 1 element.")
                                })
                                .collect::<Vec<_>>(),
                        )
                        .await?;
                    }
                }
                let end = Instant::now();
                exec_duration += end.duration_since(start);

                Ok(exec_duration)
            }
        }
    }

    async fn sample_causal_gen(
        &self,
        seqs: &mut [&mut Sequence],
        logits: Vec<Tensor>,
        prefix_cacher: &mut PrefixCacheManagerV2,
        disable_eos_stop: bool,
        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
    ) -> Result<(), candle_core::Error>;

    fn category(&self) -> ModelCategory;
}

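/// Gather, for each batch row of `logits`, the positions selected by
/// `context_lens`, concatenating the selections along the batch dimension.
///
/// A sketch under assumed shapes (sizes here are hypothetical):
///
/// ```ignore
/// // `logits`: [bs, seq_len, vocab], e.g. [2, 8, 32000].
/// // Keep only the last position's logits for each row:
/// let picked = extract_logits(&logits, vec![(7, 1), (7, 1)])?; // [2, 1, 32000]
/// ```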
pub(crate) fn extract_logits(
    logits: &Tensor,
    context_lens: Vec<(usize, usize)>,
) -> candle_core::Result<Tensor> {
    let mut toks = Vec::new();
    for (dim, (start, len)) in logits.chunk(logits.dims()[0], 0)?.iter().zip(context_lens) {
        toks.push(dim.narrow(1, start, len)?);
    }
    Tensor::cat(&toks, 0)
}

#[cfg(test)]
mod tests {
    use crate::MessageContent;
    use either::Either;
    use indexmap::IndexMap;
    use serde_json::Value;

    // Build an insertion-ordered `IndexMap<String, Value>` from `key => value`
    // string pairs (values are wrapped in `Value::String`).
    macro_rules! hashmap {
        (@single $($x:tt)*) => (());
        (@count $($rest:expr),*) => (<[()]>::len(&[$(hashmap!(@single $rest)),*]));

        ($($key:expr => $value:expr,)+) => { hashmap!($($key => $value),+) };
        ($($key:expr => $value:expr),*) => {
            {
                let _cap = hashmap!(@count $($key),*);
                let mut _map = ::indexmap::IndexMap::with_capacity(_cap);
                $(
                    let _ = _map.insert($key, Value::String($value));
                )*
                _map
            }
        };
    }

    #[cfg(test)]
    #[track_caller]
    fn test_with_inputs(
        templates: &[(bool, &str, &str, &str, &str)],
        expected_outputs: &[&str],
        inputs: Vec<IndexMap<String, MessageContent>>,
    ) {
        use crate::pipeline::chat_template::ChatTemplateValue;

        use super::chat_template::apply_chat_template_to;
        let mut failed = Vec::new();
        let n_templates = templates.len();
        for ((has_system, bos, eos, unk, template), expected) in
            templates.iter().zip(expected_outputs)
        {
            let output = match apply_chat_template_to(
                if !has_system {
                    inputs[1..].to_vec()
                } else {
                    inputs.clone()
                },
                true,
                &ChatTemplateValue(Either::Left(template.to_string())),
                Some(bos.to_string()),
                Some(eos.to_string()),
                Some(unk.to_string()),
                Vec::new(),
            ) {
                Ok(v) => v,
                Err(e) => {
                    failed.push(format!("Failed with {e}."));
                    continue;
                }
            };
            if output != *expected {
                failed.push(format!(
                    "Expected: `{}` \n\nGot:      `{}`",
                    expected.replace('\n', "\\n"),
                    output.replace('\n', "\\n")
                ));
            }
        }
        if !failed.is_empty() {
            for (i, line) in failed.iter().enumerate() {
                println!("------------ Template {i} ------------");
                println!("{line}");
            }
            println!("------------------------");
            panic!("{}/{n_templates} chat templates failed.", failed.len());
        }
    }

    #[test]
    /// Generating these cases:
    /// ```py
    /// >>> t=transformers.AutoTokenizer.from_pretrained(...)
    /// # If non-system prompt model
    /// >>> t.apply_chat_template([{"role":"user","content":"Hello"},{"role":"assistant","content":"Hi there"},{"role":"user","content":"Who are you"},{"role":"assistant","content":"   I am an assistant   "},{"role":"user","content":"Another question"}], add_generation_prompt=True, tokenize=False)
    /// # If system prompt model
    /// >>> t.apply_chat_template([{"role":"system","content":"You are a helpful assistant"},{"role":"user","content":"Hello"},{"role":"assistant","content":"Hi there"},{"role":"user","content":"Who are you"},{"role":"assistant","content":"   I am an assistant   "},{"role":"user","content":"Another question"}], add_generation_prompt=True, tokenize=False)
    /// ```
    fn test_chat_templates() {
        let templates = [
            // ChatML: https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B
            (true, "<s>", "</s>", "<unk>", "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"),
            // mistralai/Mistral-7B-Instruct-v0.1
            (false, "<s>", "</s>", "<unk>", "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"),
            // meta-llama/Llama-2-13b-chat-hf
            (true, "<s>", "</s>", "<unk>", "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' '  + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}"),
            // mistralai/Mixtral-8x7B-Instruct-v0.1
            (false, "<s>", "</s>", "<unk>", "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"),
            // google/gemma-7b-it
            (false, "<bos>", "<eos>", "<unk>", "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}"),
            // HuggingFaceM4/idefics2-8b-chatty
            // Note: `expected_outputs` has no entry for this template, so `zip`
            // leaves it unchecked here; it is exercised with image content in
            // `test_image_chat_templates` below.
            (true, "<s>", "</s>", "<unk>", "{% for message in messages %}{{message['role'].capitalize()}}{% if message['content'][0]['type'] == 'image' %}{{':'}}{% else %}{{': '}}{% endif %}{% for line in message['content'] %}{% if line['type'] == 'text' %}{{line['text']}}{% elif line['type'] == 'image' %}{{ '<image>' }}{% endif %}{% endfor %}<end_of_utterance>\n{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"),
        ];
        let expected_outputs = [
            // ChatML: https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B
            "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi there<|im_end|>\n<|im_start|>user\nWho are you<|im_end|>\n<|im_start|>assistant\n   I am an assistant   <|im_end|>\n<|im_start|>user\nAnother question<|im_end|>\n<|im_start|>assistant\n",
            // mistralai/Mistral-7B-Instruct-v0.1
            "<s>[INST] Hello [/INST]Hi there</s> [INST] Who are you [/INST]   I am an assistant   </s> [INST] Another question [/INST]",
            // meta-llama/Llama-2-13b-chat-hf
            "<s>[INST] <<SYS>>\nYou are a helpful assistant\n<</SYS>>\n\nHello [/INST] Hi there </s><s>[INST] Who are you [/INST] I am an assistant </s><s>[INST] Another question [/INST]",
            // mistralai/Mixtral-8x7B-Instruct-v0.1
            "<s>[INST] Hello [/INST]Hi there</s>[INST] Who are you [/INST]   I am an assistant   </s>[INST] Another question [/INST]",
            // google/gemma-7b-it
            "<bos><start_of_turn>user\nHello<end_of_turn>\n<start_of_turn>model\nHi there<end_of_turn>\n<start_of_turn>user\nWho are you<end_of_turn>\n<start_of_turn>model\nI am an assistant<end_of_turn>\n<start_of_turn>user\nAnother question<end_of_turn>\n<start_of_turn>model\n",
        ];
        let messages = [
            ["system", "You are a helpful assistant"],
            ["user", "Hello"],
            ["assistant", "Hi there"],
            ["user", "Who are you"],
            ["assistant", "   I am an assistant   "],
            ["user", "Another question"],
        ];
        let mut inputs = Vec::new();
        for [role, content] in messages {
            let mut message: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
                IndexMap::new();
            message.insert("role".to_string(), Either::Left(role.to_string()));
            message.insert("content".to_string(), Either::Left(content.to_string()));
            inputs.push(message);
        }
        test_with_inputs(&templates, &expected_outputs, inputs);
    }

    #[test]
    /// Generating these cases:
    /// ```py
    /// >>> processor=transformers.AutoProcessor.from_pretrained(...)
    /// >>> processor.apply_chat_template([
    ///         {"role":"system","content":[{"type":"text", "text": "You are a helpful assistant"}]},
    ///         {"role":"user","content":[{"type":"image"}, {"type":"text", "text": "Hello, please describe the above."}]},
    ///         {"role":"assistant","content":[{"type":"text", "text": "Hi there"}]},
    ///         {"role":"user","content":[{"type":"text", "text": "Who are you"}]},
    ///         {"role":"assistant","content":[{"type":"text", "text": "   I am an assistant   "}]},
    ///         {"role":"user","content":[{"type":"text", "text": "Another question"}]}
    ///     ], add_generation_prompt=True, tokenize=False)
    /// ```
    fn test_image_chat_templates() {
        let templates = [
            // HuggingFaceM4/idefics2-8b-chatty
            (true, "<s>", "</s>", "<unk>", "{% for message in messages %}{{message['role'].capitalize()}}{% if message['content'][0]['type'] == 'image' %}{{':'}}{% else %}{{': '}}{% endif %}{% for line in message['content'] %}{% if line['type'] == 'text' %}{{line['text']}}{% elif line['type'] == 'image' %}{{ '<image>' }}{% endif %}{% endfor %}<end_of_utterance>\n{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"),
        ];
        let expected_outputs = [
            // HuggingFaceM4/idefics2-8b-chatty
            "System: You are a helpful assistant<end_of_utterance>\nUser:<image>Hello, please describe the above.<end_of_utterance>\nAssistant: Hi there<end_of_utterance>\nUser:<image>This is me, who are you<end_of_utterance>\nAssistant:    I am an assistant   <end_of_utterance>\nUser:<image>Another question, what is this?<end_of_utterance>\nAssistant:",
        ];

        let mut inputs = Vec::new();

        let mut message: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
            IndexMap::new();
        message.insert("role".to_string(), Either::Left("system".to_string()));
        message.insert(
            "content".to_string(),
            Either::Right(vec![hashmap! {
                "type".to_string() => "text".to_string(),
                "text".to_string() => "You are a helpful assistant".to_string()
            }]),
        );
        inputs.push(message);

        let mut message: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
            IndexMap::new();
        message.insert("role".to_string(), Either::Left("user".to_string()));
        message.insert(
            "content".to_string(),
            Either::Right(vec![
                hashmap! {
                    "type".to_string() => "image".to_string()
                },
                hashmap! {
                    "type".to_string() => "text".to_string(),
                    "text".to_string() => "Hello, please describe the above.".to_string()
                },
            ]),
        );
        inputs.push(message);

        let mut message: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
            IndexMap::new();
        message.insert("role".to_string(), Either::Left("assistant".to_string()));
        message.insert(
            "content".to_string(),
            Either::Right(vec![hashmap! {
                "type".to_string() => "text".to_string(),
                "text".to_string() => "Hi there".to_string()
            }]),
        );
        inputs.push(message);

        let mut message: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
            IndexMap::new();
        message.insert("role".to_string(), Either::Left("user".to_string()));
        message.insert(
            "content".to_string(),
            Either::Right(vec![
                hashmap! {
                    "type".to_string() => "image".to_string()
                },
                hashmap! {
                    "type".to_string() => "text".to_string(),
                    "text".to_string() => "This is me, who are you".to_string()
                },
            ]),
        );
        inputs.push(message);

        let mut message: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
            IndexMap::new();
        message.insert("role".to_string(), Either::Left("assistant".to_string()));
        message.insert(
            "content".to_string(),
            Either::Right(vec![hashmap! {
                "type".to_string() => "text".to_string(),
                "text".to_string() => "   I am an assistant   ".to_string()
            }]),
        );
        inputs.push(message);

        let mut message: IndexMap<String, Either<String, Vec<IndexMap<String, Value>>>> =
            IndexMap::new();
        message.insert("role".to_string(), Either::Left("user".to_string()));
        message.insert(
            "content".to_string(),
            Either::Right(vec![
                hashmap! {
                    "type".to_string() => "image".to_string()
                },
                hashmap! {
                    "type".to_string() => "text".to_string(),
                    "text".to_string() => "Another question, what is this?".to_string()
                },
            ]),
        );
        inputs.push(message);

        test_with_inputs(&templates, &expected_outputs, inputs);
    }
887}