1mod amoe;
2mod cache_manager;
3pub mod chat_template;
4mod diffusion;
5mod ggml;
6mod gguf;
7mod inputs_processor;
8mod isq;
9pub(crate) mod llg;
10mod loaders;
11mod macros;
12mod normal;
13mod paths;
14mod processing;
15mod response;
16mod sampling;
17mod speculative;
18mod vision;
19
20pub use super::diffusion_models::DiffusionGenerationParams;
21use crate::amoe::{AnyMoeConfig, AnyMoeExpertType, AnyMoeTrainingInputs, AnyMoeTrainingResult};
22use crate::device_map::DeviceMapper;
23use crate::paged_attention::{CacheConfig, CacheEngine, ModelConfigLike};
24use crate::prefix_cacher::PrefixCacheManagerV2;
25pub use amoe::{AnyMoeLoader, AnyMoePipeline};
26use chat_template::ChatTemplate;
27pub use diffusion::{DiffusionLoader, DiffusionLoaderBuilder, DiffusionSpecificConfig};
28pub use ggml::{GGMLLoader, GGMLLoaderBuilder, GGMLSpecificConfig};
29pub use gguf::{GGUFLoader, GGUFLoaderBuilder, GGUFSpecificConfig};
30use image::DynamicImage;
31pub use inputs_processor::InputProcessorOutput;
32pub(crate) use isq::IsqModelLoader;
33pub use isq::{parse_isq_value, IsqModel, IsqOrganization};
34pub use loaders::{
35 AdapterKind, AutoDeviceMapParams, AutoLoader, DeepSeekV2Loader, DeepSeekV3Loader,
36 DeviceMappedModelLoader, DiffusionLoaderType, DiffusionModel, DiffusionModelLoader, FluxLoader,
37 Gemma2Loader, Gemma3Loader, GemmaLoader, Idefics2Loader, Idefics3Loader, LLaVALoader,
38 LLaVANextLoader, LlamaLoader, Loader, LocalModelPaths, MiniCpmOLoader, Mistral3Loader,
39 MistralLoader, MixtralLoader, ModelKind, ModelPaths, NormalLoaderType, NormalLoadingMetadata,
40 NormalModel, NormalModelLoader, Phi2Loader, Phi3Loader, Phi3VLoader, Phi3_5MoELoader,
41 Phi4MMLoader, PrettyName, QuantizationKind, Qwen2Loader, Qwen2VLLoader, Qwen2_5VLLoader,
42 Starcoder2Loader, TokenSource, VLlamaLoader, VisionLoaderType, VisionModel, VisionModelLoader,
43};
44use mistralrs_quant::IsqType;
45pub use normal::{NormalLoader, NormalLoaderBuilder, NormalSpecificConfig};
46pub(crate) use paths::{
47 get_chat_template, get_model_paths, get_xlora_paths, AdapterPaths, LoraAdapterPaths,
48};
49pub(crate) use processing::{
50 apply_chat_template, BasicProcessor, MessagesAction, Processor, ProcessorCreator,
51};
52use rand_isaac::Isaac64Rng;
53pub use speculative::{SpeculativeConfig, SpeculativeLoader, SpeculativePipeline};
54use std::any::Any;
55use std::collections::HashMap;
56use std::num::NonZeroUsize;
57use std::sync::Arc;
58use std::time::{Duration, Instant};
59use tokenizers::Tokenizer;
60pub use vision::{VisionLoader, VisionLoaderBuilder, VisionSpecificConfig};
61
62use anyhow::Result;
63use candle_core::{DType, Device, IndexOp, Tensor, Var};
64
65use crate::sequence::Sequence;
66
67pub use self::cache_manager::{
68 Cache, CacheManager, EitherCache, KvCache, LayerCaches, NormalCache, NormalCacheType,
69 RotatingCache, SingleCache,
70};
71pub use self::inputs_processor::{
72 text_models_inputs_processor, InputsProcessor, InputsProcessorType,
73};
74use self::text_models_inputs_processor::PagedAttentionMeta;
75
/// Model-level metadata shared by pipelines and handed out through
/// [`MetadataMixin::get_metadata`].
pub struct GeneralMetadata {
    /// Maximum sequence length. NOTE(review): presumably measured in tokens — confirm.
    pub max_seq_len: usize,
    /// Optional llguidance tokenizer environment.
    /// NOTE(review): presumably only needed for constrained sampling — confirm.
    pub tok_env: Option<llguidance::toktrie::TokEnv>,
    /// Forwarded to the inputs processor by `Pipeline::step` as its "no KV cache" flag.
    pub no_kv_cache: bool,
    /// NOTE(review): presumably disables prefix caching — confirm against the engine.
    pub no_prefix_cache: bool,
    /// Number of hidden layers in the model.
    pub num_hidden_layers: usize,
    /// End-of-sequence token ids.
    pub eos_tok: Vec<u32>,
    /// The kind of model this pipeline wraps.
    pub kind: ModelKind,
    /// Forwarded to the inputs processor by `Pipeline::step` as its X-LoRA flag.
    pub is_xlora: bool,
    /// Data type used for activations.
    pub activation_dtype: DType,
    /// Sliding window size, if any.
    pub sliding_window: Option<usize>,
    /// Cache configuration. NOTE(review): presumably only set for PagedAttention — confirm.
    pub cache_config: Option<CacheConfig>,
    /// Cache engine. `Pipeline::step` panics if this is `None` when it is invoked with
    /// `CacheBackendMetadata::PagedAttention`.
    pub cache_engine: Option<CacheEngine>,
    /// Prompt chunk size. Forwarded to the inputs processor by `Pipeline::step`, which also
    /// divides each sequence's token count by it (rounding up, defaulting to 1 when `None`)
    /// to size its raw-logits buffers.
    pub prompt_chunksize: Option<NonZeroUsize>,
    /// Optional model configuration accessor.
    pub model_metadata: Option<Arc<dyn ModelConfigLike + Send + Sync>>,
}
95
/// Cache operation that `Pipeline::step` performs before the first forward pass (`pre_op`)
/// or after the last one (`post_op`) when using `CacheBackendMetadata::DefaultInstructions`.
pub enum CacheInstruction {
    /// Call `CacheManagerMixin::clone_in_cache`. Only valid as a pre-op; `step` treats it as
    /// unreachable when used as a post-op.
    In,
    /// Call `CacheManagerMixin::clone_out_cache`. Only valid as a post-op; `step` treats it as
    /// unreachable when used as a pre-op.
    Out,
    /// Call `CacheManagerMixin::set_none_cache` with the flags below
    /// (`modify_draft_cache` is always passed as `false` by `step`).
    Reset {
        /// Forwarded as the `load_preallocated_cache` argument of `set_none_cache`.
        load_preallocated_cache: bool,
        /// Forwarded as the `reset_non_granular` argument of `set_none_cache`.
        reset_non_granular: bool,
    },
    /// Do not touch the cache.
    Nothing,
}
106
/// Access to the pre-processing components of a pipeline: the message processor, the chat
/// template and the configuration handed to the inputs processor.
pub trait PreProcessingMixin: MetadataMixin {
    /// Returns the processor for this pipeline. The default is a [`BasicProcessor`].
    fn get_processor(&self) -> Arc<dyn Processor> {
        Arc::new(BasicProcessor)
    }
    /// Returns the chat template, or `None` if this pipeline has none.
    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>>;
    /// Returns the type-erased configuration that `Pipeline::step` forwards to the inputs
    /// processor, or `None` if there is none.
    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>>;
}
115
/// Support for re-applying in-situ quantization (ISQ) to a pipeline's model.
pub trait IsqPipelineMixin {
    /// Re-quantizes the model to `dtype`.
    /// NOTE(review): what happens to the previous weights on failure is not visible here — confirm
    /// against the implementors.
    fn re_isq_model(&mut self, dtype: IsqType) -> Result<()>;
}
119
/// Cache management hooks that `Pipeline::step` invokes according to its
/// [`CacheInstruction`]s.
pub trait CacheManagerMixin {
    /// Invoked by `step` for `CacheInstruction::In`.
    /// NOTE(review): presumably copies each sequence's cache into the model cache — confirm.
    fn clone_in_cache(&self, seqs: &mut [&mut Sequence]);
    /// Invoked by `step` for `CacheInstruction::Out`.
    /// NOTE(review): presumably copies the model cache back into each sequence — confirm.
    fn clone_out_cache(&self, seqs: &mut [&mut Sequence]);
    /// Invoked by `step` for `CacheInstruction::Reset`; `step` always passes
    /// `modify_draft_cache = false`.
    /// NOTE(review): presumably clears the model cache — confirm the meaning of each flag
    /// against the implementors.
    fn set_none_cache(
        &self,
        seqs: &mut [&mut Sequence],
        reset_non_granular: bool,
        modify_draft_cache: bool,
        load_preallocated_cache: bool,
    );
    /// Returns the cache owned by this pipeline.
    fn cache(&self) -> &EitherCache;
    /// Returns `true` exactly when [`Self::cache`] is an `EitherCache::Normal`.
    fn do_preallocated_cache(&self) -> bool {
        matches!(self.cache(), EitherCache::Normal(_))
    }
}
142
/// Read access to a pipeline's device, tokenizer, name and general metadata.
pub trait MetadataMixin {
    /// Device of the pipeline; `Pipeline::step` passes it to the inputs processor.
    fn device(&self) -> Device;
    /// Tokenizer, if this pipeline has one; `Pipeline::step` passes it to the inputs processor.
    fn tokenizer(&self) -> Option<Arc<Tokenizer>>;
    /// Name of the pipeline. NOTE(review): presumably the model id — confirm.
    fn name(&self) -> String;
    /// Resets non-granular state.
    /// NOTE(review): what counts as "non-granular" is defined by the implementors — confirm.
    fn reset_non_granular_state(&self);
    /// Returns the shared [`GeneralMetadata`] of this pipeline.
    fn get_metadata(&self) -> Arc<GeneralMetadata>;
    /// Returns the device mapper, if any; `Pipeline::step` passes it to the inputs processor.
    fn device_mapper(&self) -> Option<&dyn DeviceMapper>;
}
152
/// AnyMoE hooks for a pipeline.
///
/// Every default implementation except [`Self::amoe_supported`] panics with `unreachable!()`,
/// and `amoe_supported` defaults to `false`. Implementors that return `true` from
/// `amoe_supported` are therefore expected to override the remaining methods.
pub trait AnyMoePipelineMixin {
    /// Returns the trainable variables, grouped in an outer `Vec`.
    /// NOTE(review): presumably one inner `Vec` per AnyMoE layer — confirm.
    ///
    /// # Panics
    /// The default implementation always panics.
    fn amoe_layer_vars(&self) -> Vec<Vec<Var>> {
        unreachable!()
    }
    /// Finishes AnyMoE training.
    /// NOTE(review): `_gate_model_id` presumably names where the gating model is stored — confirm.
    ///
    /// # Panics
    /// The default implementation always panics.
    fn amoe_finish_training(&mut self, _gate_model_id: Option<String>) -> candle_core::Result<()> {
        unreachable!()
    }
    /// Returns the number of trainable parameters of the base model.
    ///
    /// # Panics
    /// The default implementation always panics.
    fn amoe_base_model_trainable_params(&self) -> usize {
        unreachable!()
    }
    /// Whether this pipeline supports AnyMoE. Defaults to `false`.
    fn amoe_supported(&self) -> bool {
        false
    }
    /// Takes the cached gating outputs out of the pipeline.
    ///
    /// # Panics
    /// The default implementation always panics.
    fn amoe_take_cached_gating_outputs(&mut self) -> Vec<Tensor> {
        unreachable!()
    }
    /// Creates the AnyMoE layers from the given expert model ids.
    /// NOTE(review): the exact meaning of `_match_regex`, `(_prefix, _mlp)` and `_layers` is
    /// defined by the implementors — confirm there.
    ///
    /// # Panics
    /// The default implementation always panics.
    #[allow(clippy::too_many_arguments)]
    fn amoe_create_layers(
        &mut self,
        _model_ids: Vec<String>,
        _token: &TokenSource,
        _revision: Option<String>,
        _match_regex: &str,
        _config: AnyMoeConfig,
        _dtype: DType,
        _dev: &Device,
        (_prefix, _mlp): (String, String),
        _layers: Vec<usize>,
        _expert_type: AnyMoeExpertType,
        _silent: bool,
        _gate_model_id: Option<String>,
    ) -> candle_core::Result<()> {
        unreachable!()
    }
    /// Runs the AnyMoE pre-training step on `_inputs`, optionally yielding a training result.
    ///
    /// # Panics
    /// The default implementation always panics.
    #[allow(clippy::too_many_arguments)]
    fn amoe_pre_train(
        &self,
        _inputs: AnyMoeTrainingInputs,
        (_prefix, _mlp): (String, String),
        _model_ids: Vec<String>,
        _token: TokenSource,
        _revision: Option<String>,
        _layers: Vec<usize>,
        _silent: bool,
    ) -> Result<Option<AnyMoeTrainingResult>, candle_core::Error> {
        unreachable!()
    }
}
206
/// Broad category of model served by a pipeline, as reported by `Pipeline::category`.
///
/// Equality (see the `PartialEq` impl) only compares the variant, never the fields.
#[derive(Clone)]
pub enum ModelCategory {
    /// Text model.
    Text,
    /// Vision model.
    Vision {
        /// NOTE(review): presumably whether the model contains conv2d layers — confirm.
        has_conv2d: bool,
        /// Produces the image prefix for a prompt; see [`VisionPromptPrefixer`].
        prefixer: Arc<dyn VisionPromptPrefixer>,
    },
    /// Diffusion model.
    Diffusion,
}
218
219impl PartialEq for ModelCategory {
220 fn eq(&self, other: &Self) -> bool {
221 match (self, other) {
222 (Self::Text, Self::Text) => true,
223 (Self::Vision { .. }, Self::Vision { .. }) => true,
224 (Self::Diffusion, Self::Diffusion) => true,
225 (Self::Text, _) => false,
226 (Self::Vision { .. }, _) => false,
227 (Self::Diffusion, _) => false,
228 }
229 }
230}
231
/// Builds the image-aware form of a prompt for a vision model.
pub trait VisionPromptPrefixer: Send + Sync {
    /// Returns `prompt` adapted for the image at `image_index`.
    /// NOTE(review): presumably prepends a model-specific image token to `prompt` — confirm
    /// against the implementors.
    fn prefix_image(&self, image_index: usize, prompt: &str) -> String;
}
237
/// Cache backend selection and per-step data passed to `Pipeline::step`.
pub enum CacheBackendMetadata<'a> {
    /// Default cache handling driven by explicit instructions.
    DefaultInstructions {
        /// Applied by `step` just before the first forward pass.
        pre_op: CacheInstruction,
        /// Applied by `step` after all forward passes.
        post_op: CacheInstruction,
    },
    /// PagedAttention cache handling.
    PagedAttention {
        /// Forwarded by `step` to the inputs processor.
        metadata: PagedAttentionMeta<'a>,
        /// Forwarded by `step` to the cache engine's `execute_scheduler_ops`.
        /// NOTE(review): presumably source block -> destination block — confirm.
        blocks_to_swap_in: HashMap<usize, usize>,
        /// Forwarded by `step` to the cache engine's `execute_scheduler_ops`.
        /// NOTE(review): presumably source block -> destination block — confirm.
        blocks_to_swap_out: HashMap<usize, usize>,
        /// Forwarded by `step` to the cache engine's `execute_scheduler_ops`.
        /// NOTE(review): presumably source block -> destination blocks — confirm.
        blocks_to_copy: HashMap<usize, Vec<usize>>,
    },
}
250
/// Output of `Pipeline::forward_inputs`.
#[derive(Clone, Debug)]
pub enum ForwardInputsResult {
    /// Raw logits; `Pipeline::step` collects these per sequence and sends them back unsampled.
    RawLogits { logits: Tensor },
    /// Logits for causal generation; `Pipeline::step` hands these to
    /// `Pipeline::sample_causal_gen`.
    CausalGeneration { logits: Tensor },
    /// Generated images; `Pipeline::step` sends one image per sequence as the response.
    Image { images: Vec<DynamicImage> },
}
257
258impl ForwardInputsResult {
259 fn index_bs(&self, bs_idx: usize) -> candle_core::Result<Self> {
260 match self {
261 Self::CausalGeneration { logits } => Ok(Self::CausalGeneration {
262 logits: logits.i(bs_idx)?,
263 }),
264 Self::RawLogits { logits } => Ok(Self::RawLogits {
265 logits: logits.i(bs_idx)?,
266 }),
267 Self::Image { images } => Ok(Self::Image {
268 images: vec![images[bs_idx].clone()],
269 }),
270 }
271 }
272
273 fn to_device(&self, device: &Device) -> candle_core::Result<Self> {
274 match self {
275 Self::CausalGeneration { logits } => Ok(Self::CausalGeneration {
276 logits: logits.to_device(device)?,
277 }),
278 Self::RawLogits { logits } => Ok(Self::RawLogits {
279 logits: logits.to_device(device)?,
280 }),
281 Self::Image { .. } => Ok(self.clone()),
282 }
283 }
284}
285
286#[async_trait::async_trait]
287pub trait Pipeline:
288 Send
289 + Sync
290 + PreProcessingMixin
291 + IsqPipelineMixin
292 + CacheManagerMixin
293 + MetadataMixin
294 + AnyMoePipelineMixin
295{
296 fn forward_inputs(
297 &mut self,
298 inputs: Box<dyn Any>,
299 return_raw_logits: bool,
300 ) -> Result<ForwardInputsResult, candle_core::Error>;
301
302 #[allow(clippy::too_many_arguments)]
304 async fn step(
305 &mut self,
306 input_seqs: &mut [&mut Sequence],
307 is_prompt: bool,
308 return_raw_logits: bool,
309 prefix_cacher: &mut PrefixCacheManagerV2,
310 disable_eos_stop: bool,
311 rng: Arc<std::sync::Mutex<Isaac64Rng>>,
312 backend_metadata: CacheBackendMetadata<'_>,
313 ) -> Result<Duration, candle_core::Error> {
314 match backend_metadata {
315 CacheBackendMetadata::DefaultInstructions { pre_op, post_op } => {
316 let inputs_iter = self.get_processor().inputs_processor().process_inputs(
317 self.tokenizer(),
318 input_seqs,
319 is_prompt,
320 self.get_metadata().is_xlora,
321 &self.device(),
322 self.get_metadata().no_kv_cache,
323 None,
324 return_raw_logits,
325 self.get_input_processor_config(),
326 None,
327 self.get_metadata().prompt_chunksize,
328 self.device_mapper(),
329 );
330
331 let mut logits = vec![None; input_seqs.len()];
332 let prompt_chunksize = self
333 .get_metadata()
334 .prompt_chunksize
335 .map(NonZeroUsize::get)
336 .unwrap_or(1);
337 let len_inputs = input_seqs
338 .iter()
339 .map(|seq| seq.get_toks().len().div_ceil(prompt_chunksize))
340 .max()
341 .unwrap();
342 let mut raw_out_logits = vec![vec![None; len_inputs]; input_seqs.len()];
343
344 let mut exec_duration = Duration::ZERO;
345 for (i, inputs) in inputs_iter.into_iter().enumerate() {
346 let InputProcessorOutput {
347 inputs,
348 seq_indices,
349 } = inputs.map_err(candle_core::Error::msg)?;
350 if i == 0 {
351 match pre_op {
352 CacheInstruction::In => self.clone_in_cache(input_seqs),
353 CacheInstruction::Nothing => (),
354 CacheInstruction::Reset {
355 load_preallocated_cache,
356 reset_non_granular,
357 } => self.set_none_cache(
358 input_seqs,
359 reset_non_granular,
360 false,
361 load_preallocated_cache,
362 ),
363 _ => unreachable!("Unreachable PRE cache op."),
364 }
365 }
366
367 let start = Instant::now();
368 let raw_logits = self.forward_inputs(inputs, return_raw_logits)?;
369 let end = Instant::now();
370 exec_duration += end.duration_since(start);
371
372 for (logit_idx, seq_idx) in seq_indices.into_iter().enumerate() {
373 if let ForwardInputsResult::RawLogits { logits } = &raw_logits {
374 raw_out_logits[seq_idx][i] =
375 Some(logits.i(logit_idx)?.to_device(&Device::Cpu)?);
376 } else {
377 logits[seq_idx] = Some(raw_logits.index_bs(logit_idx)?);
378 }
379 }
380 }
381
382 match post_op {
383 CacheInstruction::Out => self.clone_out_cache(input_seqs),
384 CacheInstruction::Nothing => (),
385 CacheInstruction::Reset {
386 load_preallocated_cache,
387 reset_non_granular,
388 } => self.set_none_cache(
389 input_seqs,
390 reset_non_granular,
391 false,
392 load_preallocated_cache,
393 ),
394 _ => unreachable!("Unreachable POST cache op."),
395 }
396
397 if raw_out_logits[0][0].is_some() {
398 let start = Instant::now();
399 response::send_raw_responses(
400 input_seqs,
401 raw_out_logits
402 .into_iter()
403 .map(|raw| raw.into_iter().flatten().collect::<Vec<_>>())
404 .collect(),
405 )
406 .await?;
407 let end = Instant::now();
408 exec_duration += end.duration_since(start);
409
410 return Ok(exec_duration);
411 }
412
413 let logits = logits
414 .into_iter()
415 .map(|l| {
416 l.expect("Did not get any inputs. This is shocking.")
417 .to_device(&Device::Cpu)
418 })
419 .collect::<candle_core::Result<Vec<_>>>()?;
420
421 let start = Instant::now();
422 match &logits[0] {
423 ForwardInputsResult::RawLogits { .. } => unreachable!(),
424 ForwardInputsResult::CausalGeneration { .. } => {
425 self.sample_causal_gen(
426 input_seqs,
427 logits
428 .into_iter()
429 .map(|r| {
430 #[allow(irrefutable_let_patterns)]
431 let ForwardInputsResult::CausalGeneration { logits } = r
432 else {
433 unreachable!(
434 "All results must have same type, `CausalGeneration`"
435 )
436 };
437 logits
438 })
439 .collect::<Vec<_>>(),
440 prefix_cacher,
441 disable_eos_stop,
442 rng,
443 )
444 .await?;
445 }
446 ForwardInputsResult::Image { .. } => {
447 response::send_image_responses(
448 input_seqs,
449 logits
450 .into_iter()
451 .map(|r| {
452 #[allow(irrefutable_let_patterns)]
453 let ForwardInputsResult::Image { images } = r
454 else {
455 unreachable!(
456 "All results must have same type, `CausalGeneration`"
457 )
458 };
459 images
460 .into_iter()
461 .next()
462 .expect("Must have at least 1 element.")
463 })
464 .collect::<Vec<_>>(),
465 )
466 .await?;
467 }
468 }
469 let end = Instant::now();
470 exec_duration += end.duration_since(start);
471
472 Ok(exec_duration)
473 }
474 CacheBackendMetadata::PagedAttention {
475 metadata,
476 blocks_to_copy,
477 blocks_to_swap_in,
478 blocks_to_swap_out,
479 } => {
480 self.get_metadata()
482 .cache_engine
483 .as_ref()
484 .expect("PagedAttention must have cache engines.")
485 .execute_scheduler_ops(
486 blocks_to_swap_in.clone(),
487 blocks_to_swap_out.clone(),
488 blocks_to_copy.clone(),
489 )?;
490
491 let inputs_iter = self.get_processor().inputs_processor().process_inputs(
492 self.tokenizer(),
493 input_seqs,
494 is_prompt,
495 self.get_metadata().is_xlora,
496 &self.device(),
497 self.get_metadata().no_kv_cache,
498 None,
499 return_raw_logits,
500 self.get_input_processor_config(),
501 Some(metadata),
502 self.get_metadata().prompt_chunksize,
503 self.device_mapper(),
504 );
505
506 let mut logits = vec![None; input_seqs.len()];
507 let prompt_chunksize = self
508 .get_metadata()
509 .prompt_chunksize
510 .map(NonZeroUsize::get)
511 .unwrap_or(1);
512 let len_inputs = input_seqs
513 .iter()
514 .map(|seq| seq.get_toks().len().div_ceil(prompt_chunksize))
515 .max()
516 .unwrap();
517 let mut raw_out_logits = vec![vec![None; len_inputs]; input_seqs.len()];
518
519 let mut exec_duration = Duration::ZERO;
520 for (i, inputs) in inputs_iter.into_iter().enumerate() {
521 let InputProcessorOutput {
522 inputs,
523 seq_indices,
524 } = inputs.map_err(candle_core::Error::msg)?;
525
526 let start = Instant::now();
527 let raw_logits = self.forward_inputs(inputs, return_raw_logits)?;
528 let end = Instant::now();
529 exec_duration += end.duration_since(start);
530
531 for (logit_idx, seq_idx) in seq_indices.into_iter().enumerate() {
532 if let ForwardInputsResult::RawLogits { logits } = &raw_logits {
533 raw_out_logits[seq_idx][i] =
534 Some(logits.i(logit_idx)?.to_device(&Device::Cpu)?);
535 } else {
536 logits[seq_idx] = Some(raw_logits.index_bs(logit_idx)?);
537 }
538 }
539 }
540
541 if raw_out_logits[0][0].is_some() {
542 let start = Instant::now();
543 response::send_raw_responses(
544 input_seqs,
545 raw_out_logits
546 .into_iter()
547 .map(|raw| raw.into_iter().flatten().collect::<Vec<_>>())
548 .collect(),
549 )
550 .await?;
551 let end = Instant::now();
552 exec_duration += end.duration_since(start);
553
554 return Ok(exec_duration);
555 }
556
557 let logits = logits
558 .into_iter()
559 .map(|l| {
560 l.expect("Did not get any inputs. This is shocking.")
561 .to_device(&Device::Cpu)
562 })
563 .collect::<candle_core::Result<Vec<_>>>()?;
564
565 let start = Instant::now();
566 match &logits[0] {
567 ForwardInputsResult::RawLogits { .. } => unreachable!(),
568 ForwardInputsResult::CausalGeneration { .. } => {
569 self.sample_causal_gen(
570 input_seqs,
571 logits
572 .into_iter()
573 .map(|r| {
574 #[allow(irrefutable_let_patterns)]
575 let ForwardInputsResult::CausalGeneration { logits } = r
576 else {
577 unreachable!("All results must have same type")
578 };
579 logits
580 })
581 .collect::<Vec<_>>(),
582 prefix_cacher,
583 disable_eos_stop,
584 rng,
585 )
586 .await?;
587 }
588 ForwardInputsResult::Image { .. } => {
589 response::send_image_responses(
590 input_seqs,
591 logits
592 .into_iter()
593 .map(|r| {
594 #[allow(irrefutable_let_patterns)]
595 let ForwardInputsResult::Image { images } = r
596 else {
597 unreachable!(
598 "All results must have same type, `CausalGeneration`"
599 )
600 };
601 images
602 .into_iter()
603 .next()
604 .expect("Must have at least 1 element.")
605 })
606 .collect::<Vec<_>>(),
607 )
608 .await?;
609 }
610 }
611 let end = Instant::now();
612 exec_duration += end.duration_since(start);
613
614 Ok(exec_duration)
615 }
616 }
617 }
618
619 async fn sample_causal_gen(
620 &self,
621 seqs: &mut [&mut Sequence],
622 logits: Vec<Tensor>,
623 prefix_cacher: &mut PrefixCacheManagerV2,
624 disable_eos_stop: bool,
625 rng: Arc<std::sync::Mutex<Isaac64Rng>>,
626 ) -> Result<(), candle_core::Error>;
627
628 fn category(&self) -> ModelCategory;
629}
630
631pub(crate) fn extract_logits(
632 logits: &Tensor,
633 context_lens: Vec<(usize, usize)>,
634) -> candle_core::Result<Tensor> {
635 let mut toks = Vec::new();
636 for (dim, (start, len)) in logits.chunk(logits.dims()[0], 0)?.iter().zip(context_lens) {
637 toks.push(dim.narrow(1, start, len)?);
638 }
639 Tensor::cat(&toks, 0)
640}
641
#[cfg(test)]
mod tests {
    use crate::MessageContent;
    use either::Either;
    use indexmap::IndexMap;
    use serde_json::Value;

    /// Builds one structured content part from `(key, value)` string pairs, keeping the
    /// given key order.
    fn content_part(fields: &[(&str, &str)]) -> IndexMap<String, Value> {
        fields
            .iter()
            .map(|(key, value)| (key.to_string(), Value::String(value.to_string())))
            .collect()
    }

    /// A `{"type": "text", "text": <text>}` content part.
    fn text_part(text: &str) -> IndexMap<String, Value> {
        content_part(&[("type", "text"), ("text", text)])
    }

    /// A `{"type": "image"}` content part.
    fn image_part() -> IndexMap<String, Value> {
        content_part(&[("type", "image")])
    }

    /// Builds a chat message with the given role and content.
    fn message(role: &str, content: MessageContent) -> IndexMap<String, MessageContent> {
        let mut message = IndexMap::new();
        message.insert("role".to_string(), Either::Left(role.to_string()));
        message.insert("content".to_string(), content);
        message
    }

    /// Renders `inputs` with every template and compares against the expected output at the
    /// same position.
    ///
    /// Each template is `(has_system, bos, eos, unk, template)`. When `has_system` is false,
    /// the first message (the system message) is dropped before rendering.
    ///
    /// # Panics
    /// Panics if `templates` and `expected_outputs` differ in length, or if any template
    /// fails to render or renders something other than the expected output.
    #[track_caller]
    fn test_with_inputs(
        templates: &[(bool, &str, &str, &str, &str)],
        expected_outputs: &[&str],
        inputs: Vec<IndexMap<String, MessageContent>>,
    ) {
        use super::chat_template::apply_chat_template_to;
        use crate::pipeline::chat_template::ChatTemplateValue;

        // `zip` would silently skip unmatched entries, so insist on a 1:1 pairing.
        assert_eq!(
            templates.len(),
            expected_outputs.len(),
            "Every chat template needs exactly one expected output."
        );

        let n_templates = templates.len();
        let mut failed = Vec::new();
        for (template_idx, ((has_system, bos, eos, unk, template), expected)) in
            templates.iter().zip(expected_outputs).enumerate()
        {
            let messages = if *has_system {
                inputs.clone()
            } else {
                inputs[1..].to_vec()
            };
            let rendered = apply_chat_template_to(
                messages,
                true,
                &ChatTemplateValue(Either::Left(template.to_string())),
                Some(bos.to_string()),
                Some(eos.to_string()),
                Some(unk.to_string()),
                Vec::new(),
            );
            match rendered {
                Ok(output) if output == *expected => {}
                Ok(output) => failed.push((
                    template_idx,
                    format!(
                        "Expected: `{}` \n\nGot: `{}`",
                        expected.replace('\n', "\\n"),
                        output.replace('\n', "\\n")
                    ),
                )),
                Err(e) => failed.push((template_idx, format!("Failed with {e}."))),
            }
        }

        if !failed.is_empty() {
            // Report the index of the template itself so the failing entry is easy to find.
            for (template_idx, line) in &failed {
                println!("------------ Template {template_idx} ------------");
                println!("{line}");
            }
            println!("------------------------");
            panic!("{}/{n_templates} chat templates failed.", failed.len());
        }
    }

    /// Checks text-only chat templates against their expected renderings.
    #[test]
    fn test_chat_templates() {
        let templates = [
            // `<|im_start|>` / `<|im_end|>` style.
            (true, "<s>", "</s>", "<unk>", "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"),
            // `[INST]` style, a space after every assistant turn.
            (false, "<s>", "</s>", "<unk>", "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"),
            // `[INST]` style with a `<<SYS>>` system block.
            (true, "<s>", "</s>", "<unk>", "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}"),
            // `[INST]` style, no space after assistant turns.
            (false, "<s>", "</s>", "<unk>", "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"),
            // `<start_of_turn>` / `<end_of_turn>` style.
            (false, "<bos>", "<eos>", "<unk>", "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}"),
            // The `<end_of_utterance>` template expects structured (image/text) content and
            // is exercised by `test_image_chat_templates`.
        ];
        let expected_outputs = [
            "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi there<|im_end|>\n<|im_start|>user\nWho are you<|im_end|>\n<|im_start|>assistant\n I am an assistant <|im_end|>\n<|im_start|>user\nAnother question<|im_end|>\n<|im_start|>assistant\n",
            "<s>[INST] Hello [/INST]Hi there</s> [INST] Who are you [/INST] I am an assistant </s> [INST] Another question [/INST]",
            "<s>[INST] <<SYS>>\nYou are a helpful assistant\n<</SYS>>\n\nHello [/INST] Hi there </s><s>[INST] Who are you [/INST] I am an assistant </s><s>[INST] Another question [/INST]",
            "<s>[INST] Hello [/INST]Hi there</s>[INST] Who are you [/INST] I am an assistant </s>[INST] Another question [/INST]",
            "<bos><start_of_turn>user\nHello<end_of_turn>\n<start_of_turn>model\nHi there<end_of_turn>\n<start_of_turn>user\nWho are you<end_of_turn>\n<start_of_turn>model\nI am an assistant<end_of_turn>\n<start_of_turn>user\nAnother question<end_of_turn>\n<start_of_turn>model\n",
        ];
        let messages = [
            ["system", "You are a helpful assistant"],
            ["user", "Hello"],
            ["assistant", "Hi there"],
            ["user", "Who are you"],
            ["assistant", " I am an assistant "],
            ["user", "Another question"],
        ];
        let inputs = messages
            .into_iter()
            .map(|[role, content]| message(role, Either::Left(content.to_string())))
            .collect::<Vec<_>>();
        test_with_inputs(&templates, &expected_outputs, inputs);
    }

    /// Checks a chat template that consumes structured image/text content.
    #[test]
    fn test_image_chat_templates() {
        let templates = [
            (true, "<s>", "</s>", "<unk>", "{% for message in messages %}{{message['role'].capitalize()}}{% if message['content'][0]['type'] == 'image' %}{{':'}}{% else %}{{': '}}{% endif %}{% for line in message['content'] %}{% if line['type'] == 'text' %}{{line['text']}}{% elif line['type'] == 'image' %}{{ '<image>' }}{% endif %}{% endfor %}<end_of_utterance>\n{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"),
        ];
        let expected_outputs = [
            "System: You are a helpful assistant<end_of_utterance>\nUser:<image>Hello, please describe the above.<end_of_utterance>\nAssistant: Hi there<end_of_utterance>\nUser:<image>This is me, who are you<end_of_utterance>\nAssistant: I am an assistant <end_of_utterance>\nUser:<image>Another question, what is this?<end_of_utterance>\nAssistant:",
        ];

        let inputs = vec![
            message(
                "system",
                Either::Right(vec![text_part("You are a helpful assistant")]),
            ),
            message(
                "user",
                Either::Right(vec![
                    image_part(),
                    text_part("Hello, please describe the above."),
                ]),
            ),
            message("assistant", Either::Right(vec![text_part("Hi there")])),
            message(
                "user",
                Either::Right(vec![image_part(), text_part("This is me, who are you")]),
            ),
            message(
                "assistant",
                Either::Right(vec![text_part(" I am an assistant ")]),
            ),
            message(
                "user",
                Either::Right(vec![
                    image_part(),
                    text_part("Another question, what is this?"),
                ]),
            ),
        ];

        test_with_inputs(&templates, &expected_outputs, inputs);
    }
}