1mod amoe;
2mod cache_manager;
3pub mod chat_template;
4mod diffusion;
5mod ggml;
6mod gguf;
7mod inputs_processor;
8mod isq;
9pub(crate) mod llg;
10mod loaders;
11mod macros;
12mod normal;
13mod paths;
14mod processing;
15mod response;
16mod sampling;
17mod speculative;
18mod vision;
19
20pub use super::diffusion_models::DiffusionGenerationParams;
21use crate::amoe::{AnyMoeConfig, AnyMoeExpertType, AnyMoeTrainingInputs, AnyMoeTrainingResult};
22use crate::device_map::DeviceMapper;
23use crate::paged_attention::{CacheConfig, CacheEngine, ModelConfigLike};
24use crate::prefix_cacher::PrefixCacheManagerV2;
25pub use amoe::{AnyMoeLoader, AnyMoePipeline};
26use chat_template::ChatTemplate;
27pub use diffusion::{DiffusionLoader, DiffusionLoaderBuilder, DiffusionSpecificConfig};
28pub use ggml::{GGMLLoader, GGMLLoaderBuilder, GGMLSpecificConfig};
29pub use gguf::{GGUFLoader, GGUFLoaderBuilder, GGUFSpecificConfig};
30use image::DynamicImage;
31pub use inputs_processor::InputProcessorOutput;
32pub(crate) use isq::IsqModelLoader;
33pub use isq::{parse_isq_value, IsqModel, IsqOrganization, UQFF_MULTI_FILE_DELIMITER};
34use llguidance::toktrie::TokEnv;
35pub use loaders::{
36 AdapterKind, AutoDeviceMapParams, AutoLoader, DeepSeekV2Loader, DeepSeekV3Loader,
37 DeviceMappedModelLoader, DiffusionLoaderType, DiffusionModel, DiffusionModelLoader, FluxLoader,
38 Gemma2Loader, Gemma3Loader, GemmaLoader, Idefics2Loader, Idefics3Loader, LLaVALoader,
39 LLaVANextLoader, LlamaLoader, Loader, LocalModelPaths, MiniCpmOLoader, Mistral3Loader,
40 MistralLoader, MixtralLoader, ModelKind, ModelPaths, NormalLoaderType, NormalLoadingMetadata,
41 NormalModel, NormalModelLoader, Phi2Loader, Phi3Loader, Phi3VLoader, Phi3_5MoELoader,
42 Phi4MMLoader, PrettyName, QuantizationKind, Qwen2Loader, Qwen2VLLoader, Qwen2_5VLLoader,
43 Qwen3Loader, Qwen3MoELoader, Starcoder2Loader, TokenSource, VLlama4Loader, VLlamaLoader,
44 VisionLoaderType, VisionModel, VisionModelLoader,
45};
46use mistralrs_quant::IsqType;
47pub use normal::{NormalLoader, NormalLoaderBuilder, NormalSpecificConfig};
48pub(crate) use paths::{
49 get_chat_template, get_model_paths, get_xlora_paths, AdapterPaths, LoraAdapterPaths,
50};
51pub(crate) use processing::{
52 apply_chat_template, BasicProcessor, MessagesAction, Processor, ProcessorCreator,
53};
54use rand_isaac::Isaac64Rng;
55pub use speculative::{SpeculativeConfig, SpeculativeLoader, SpeculativePipeline};
56use std::any::Any;
57use std::collections::HashMap;
58use std::num::NonZeroUsize;
59use std::sync::Arc;
60use std::time::{Duration, Instant};
61use tokenizers::Tokenizer;
62pub use vision::{VisionLoader, VisionLoaderBuilder, VisionSpecificConfig};
63
64use anyhow::Result;
65use candle_core::{DType, Device, IndexOp, Tensor, Var};
66
67use crate::sequence::Sequence;
68
69pub use self::cache_manager::{
70 Cache, CacheManager, EitherCache, KvCache, LayerCaches, NormalCache, NormalCacheType,
71};
72pub use self::inputs_processor::{
73 text_models_inputs_processor, InputsProcessor, InputsProcessorType,
74};
75use self::text_models_inputs_processor::PagedAttentionMeta;
76
77pub struct GeneralMetadata {
78 pub max_seq_len: usize,
79 pub llg_factory: Option<Arc<llguidance::ParserFactory>>,
81 pub no_kv_cache: bool,
82 pub no_prefix_cache: bool,
83 pub num_hidden_layers: usize,
84 pub eos_tok: Vec<u32>,
85 pub kind: ModelKind,
86 pub is_xlora: bool,
88 pub activation_dtype: DType,
89 pub sliding_window: Option<usize>,
90 pub cache_config: Option<CacheConfig>,
92 pub cache_engine: Option<CacheEngine>,
93 pub prompt_chunksize: Option<NonZeroUsize>,
94 pub model_metadata: Option<Arc<dyn ModelConfigLike + Send + Sync>>,
95}
96
97impl GeneralMetadata {
98 pub fn tok_env(&self) -> Option<TokEnv> {
99 self.llg_factory.as_ref().map(|f| f.tok_env().clone())
100 }
101}
102
/// Cache operation that `Pipeline::step` performs before (`pre_op`) or after (`post_op`) the
/// forward passes when it is driven with `CacheBackendMetadata::DefaultInstructions`.
pub enum CacheInstruction {
    /// Calls `CacheManagerMixin::clone_in_cache`. Only valid as a pre-op.
    In,
    /// Calls `CacheManagerMixin::clone_out_cache`. Only valid as a post-op.
    Out,
    /// Calls `CacheManagerMixin::set_none_cache` with the two flags below
    /// (`modify_draft_cache` is always passed as `false` by `Pipeline::step`).
    Reset {
        /// Forwarded as the `load_preallocated_cache` argument.
        load_preallocated_cache: bool,
        /// Forwarded as the `reset_non_granular` argument.
        reset_non_granular: bool,
    },
    /// Leave the cache untouched.
    Nothing,
}
113
114pub trait PreProcessingMixin: MetadataMixin {
115 fn get_processor(&self) -> Arc<dyn Processor> {
116 Arc::new(BasicProcessor)
117 }
118 fn get_chat_template(&self) -> Option<Arc<ChatTemplate>>;
120 fn get_input_processor_config(&self) -> Option<Arc<dyn Any>>;
121}
122
123pub trait IsqPipelineMixin {
124 fn re_isq_model(&mut self, dtype: IsqType) -> Result<()>;
125}
126
127pub trait CacheManagerMixin {
128 fn clone_in_cache(&self, seqs: &mut [&mut Sequence]);
131 fn clone_out_cache(&self, seqs: &mut [&mut Sequence]);
134 fn set_none_cache(
138 &self,
139 seqs: &mut [&mut Sequence],
140 reset_non_granular: bool,
141 modify_draft_cache: bool,
142 load_preallocated_cache: bool,
143 );
144 fn cache(&self) -> &EitherCache;
145 fn do_preallocated_cache(&self) -> bool {
146 matches!(self.cache(), EitherCache::Normal(_))
147 }
148}
149
150pub trait MetadataMixin {
151 fn device(&self) -> Device;
152 fn tokenizer(&self) -> Option<Arc<Tokenizer>>;
154 fn name(&self) -> String;
155 fn reset_non_granular_state(&self);
156 fn get_metadata(&self) -> Arc<GeneralMetadata>;
157 fn device_mapper(&self) -> Option<&dyn DeviceMapper>;
158}
159
160pub trait AnyMoePipelineMixin {
162 fn amoe_layer_vars(&self) -> Vec<Vec<Var>> {
164 unreachable!()
165 }
166 fn amoe_finish_training(&mut self, _gate_model_id: Option<String>) -> candle_core::Result<()> {
167 unreachable!()
168 }
169 fn amoe_base_model_trainable_params(&self) -> usize {
170 unreachable!()
171 }
172 fn amoe_supported(&self) -> bool {
173 false
174 }
175 fn amoe_take_cached_gating_outputs(&mut self) -> Vec<Tensor> {
177 unreachable!()
178 }
179 #[allow(clippy::too_many_arguments)]
181 fn amoe_create_layers(
182 &mut self,
183 _model_ids: Vec<String>,
184 _token: &TokenSource,
185 _revision: Option<String>,
186 _match_regex: &str,
187 _config: AnyMoeConfig,
188 _dtype: DType,
189 _dev: &Device,
190 (_prefix, _mlp): (String, String),
191 _layers: Vec<usize>,
192 _expert_type: AnyMoeExpertType,
193 _silent: bool,
194 _gate_model_id: Option<String>,
195 ) -> candle_core::Result<()> {
196 unreachable!()
197 }
198 #[allow(clippy::too_many_arguments)]
200 fn amoe_pre_train(
201 &self,
202 _inputs: AnyMoeTrainingInputs,
203 (_prefix, _mlp): (String, String),
204 _model_ids: Vec<String>,
205 _token: TokenSource,
206 _revision: Option<String>,
207 _layers: Vec<usize>,
208 _silent: bool,
209 ) -> Result<Option<AnyMoeTrainingResult>, candle_core::Error> {
210 unreachable!()
211 }
212}
213
214#[derive(Clone)]
217pub enum ModelCategory {
218 Text,
219 Vision {
220 has_conv2d: bool,
221 prefixer: Arc<dyn VisionPromptPrefixer>,
222 },
223 Diffusion,
224}
225
226impl PartialEq for ModelCategory {
227 fn eq(&self, other: &Self) -> bool {
228 match (self, other) {
229 (Self::Text, Self::Text) => true,
230 (Self::Vision { .. }, Self::Vision { .. }) => true,
231 (Self::Diffusion, Self::Diffusion) => true,
232 (Self::Text, _) => false,
233 (Self::Vision { .. }, _) => false,
234 (Self::Diffusion, _) => false,
235 }
236 }
237}
238
/// Adds image markers to a prompt for a vision model.
pub trait VisionPromptPrefixer: Send + Sync {
    /// Returns `prompt` with whatever prefix the model needs for the images identified by
    /// `image_indees` (the indices of the images; the parameter name is kept as-is).
    fn prefix_image(&self, image_indees: Vec<usize>, prompt: &str) -> String;
}
244
245pub enum CacheBackendMetadata<'a> {
246 DefaultInstructions {
247 pre_op: CacheInstruction,
248 post_op: CacheInstruction,
249 },
250 PagedAttention {
251 metadata: PagedAttentionMeta<'a>,
252 blocks_to_swap_in: HashMap<usize, usize>,
253 blocks_to_swap_out: HashMap<usize, usize>,
254 blocks_to_copy: HashMap<usize, Vec<usize>>,
255 },
256}
257
258#[derive(Clone, Debug)]
259pub enum ForwardInputsResult {
260 RawLogits { logits: Tensor },
261 CausalGeneration { logits: Tensor },
262 Image { images: Vec<DynamicImage> },
263}
264
265impl ForwardInputsResult {
266 fn index_bs(&self, bs_idx: usize) -> candle_core::Result<Self> {
267 match self {
268 Self::CausalGeneration { logits } => Ok(Self::CausalGeneration {
269 logits: logits.i(bs_idx)?,
270 }),
271 Self::RawLogits { logits } => Ok(Self::RawLogits {
272 logits: logits.i(bs_idx)?,
273 }),
274 Self::Image { images } => Ok(Self::Image {
275 images: vec![images[bs_idx].clone()],
276 }),
277 }
278 }
279
280 fn to_device(&self, device: &Device) -> candle_core::Result<Self> {
281 match self {
282 Self::CausalGeneration { logits } => Ok(Self::CausalGeneration {
283 logits: logits.to_device(device)?,
284 }),
285 Self::RawLogits { logits } => Ok(Self::RawLogits {
286 logits: logits.to_device(device)?,
287 }),
288 Self::Image { .. } => Ok(self.clone()),
289 }
290 }
291}
292
293#[async_trait::async_trait]
294pub trait Pipeline:
295 Send
296 + Sync
297 + PreProcessingMixin
298 + IsqPipelineMixin
299 + CacheManagerMixin
300 + MetadataMixin
301 + AnyMoePipelineMixin
302{
303 fn forward_inputs(
304 &mut self,
305 inputs: Box<dyn Any>,
306 return_raw_logits: bool,
307 ) -> Result<ForwardInputsResult, candle_core::Error>;
308
309 #[allow(clippy::too_many_arguments)]
311 async fn step(
312 &mut self,
313 input_seqs: &mut [&mut Sequence],
314 is_prompt: bool,
315 return_raw_logits: bool,
316 prefix_cacher: &mut PrefixCacheManagerV2,
317 disable_eos_stop: bool,
318 rng: Arc<std::sync::Mutex<Isaac64Rng>>,
319 backend_metadata: CacheBackendMetadata<'_>,
320 ) -> Result<Duration, candle_core::Error> {
321 match backend_metadata {
322 CacheBackendMetadata::DefaultInstructions { pre_op, post_op } => {
323 let inputs_iter = self.get_processor().inputs_processor().process_inputs(
324 self.tokenizer(),
325 input_seqs,
326 is_prompt,
327 self.get_metadata().is_xlora,
328 &self.device(),
329 self.get_metadata().no_kv_cache,
330 None,
331 return_raw_logits,
332 self.get_input_processor_config(),
333 None,
334 self.get_metadata().prompt_chunksize,
335 self.device_mapper(),
336 );
337
338 let mut logits = vec![None; input_seqs.len()];
339 let prompt_chunksize = self
340 .get_metadata()
341 .prompt_chunksize
342 .map(NonZeroUsize::get)
343 .unwrap_or(1);
344 let len_inputs = input_seqs
345 .iter()
346 .map(|seq| seq.get_toks().len().div_ceil(prompt_chunksize))
347 .max()
348 .unwrap();
349 let mut raw_out_logits = vec![vec![None; len_inputs]; input_seqs.len()];
350
351 let mut exec_duration = Duration::ZERO;
352 for (i, inputs) in inputs_iter.into_iter().enumerate() {
353 let InputProcessorOutput {
354 inputs,
355 seq_indices,
356 } = inputs.map_err(candle_core::Error::msg)?;
357 if i == 0 {
358 match pre_op {
359 CacheInstruction::In => self.clone_in_cache(input_seqs),
360 CacheInstruction::Nothing => (),
361 CacheInstruction::Reset {
362 load_preallocated_cache,
363 reset_non_granular,
364 } => self.set_none_cache(
365 input_seqs,
366 reset_non_granular,
367 false,
368 load_preallocated_cache,
369 ),
370 _ => unreachable!("Unreachable PRE cache op."),
371 }
372 }
373
374 let start = Instant::now();
375 let raw_logits = self.forward_inputs(inputs, return_raw_logits)?;
376 let end = Instant::now();
377 exec_duration += end.duration_since(start);
378
379 for (logit_idx, seq_idx) in seq_indices.into_iter().enumerate() {
380 if let ForwardInputsResult::RawLogits { logits } = &raw_logits {
381 raw_out_logits[seq_idx][i] =
382 Some(logits.i(logit_idx)?.to_device(&Device::Cpu)?);
383 } else {
384 logits[seq_idx] = Some(raw_logits.index_bs(logit_idx)?);
385 }
386 }
387 }
388
389 match post_op {
390 CacheInstruction::Out => self.clone_out_cache(input_seqs),
391 CacheInstruction::Nothing => (),
392 CacheInstruction::Reset {
393 load_preallocated_cache,
394 reset_non_granular,
395 } => self.set_none_cache(
396 input_seqs,
397 reset_non_granular,
398 false,
399 load_preallocated_cache,
400 ),
401 _ => unreachable!("Unreachable POST cache op."),
402 }
403
404 if raw_out_logits[0][0].is_some() {
405 let start = Instant::now();
406 response::send_raw_responses(
407 input_seqs,
408 raw_out_logits
409 .into_iter()
410 .map(|raw| raw.into_iter().flatten().collect::<Vec<_>>())
411 .collect(),
412 )
413 .await?;
414 let end = Instant::now();
415 exec_duration += end.duration_since(start);
416
417 return Ok(exec_duration);
418 }
419
420 let start = Instant::now();
421 let logits = logits
422 .into_iter()
423 .map(|l| {
424 l.expect("Did not get any inputs. This is shocking.")
425 .to_device(&Device::Cpu)
426 })
427 .collect::<candle_core::Result<Vec<_>>>()?;
428
429 match &logits[0] {
430 ForwardInputsResult::RawLogits { .. } => unreachable!(),
431 ForwardInputsResult::CausalGeneration { .. } => {
432 self.sample_causal_gen(
433 input_seqs,
434 logits
435 .into_iter()
436 .map(|r| {
437 #[allow(irrefutable_let_patterns)]
438 let ForwardInputsResult::CausalGeneration { logits } = r
439 else {
440 unreachable!(
441 "All results must have same type, `CausalGeneration`"
442 )
443 };
444 logits
445 })
446 .collect::<Vec<_>>(),
447 prefix_cacher,
448 disable_eos_stop,
449 rng,
450 )
451 .await?;
452 }
453 ForwardInputsResult::Image { .. } => {
454 response::send_image_responses(
455 input_seqs,
456 logits
457 .into_iter()
458 .map(|r| {
459 #[allow(irrefutable_let_patterns)]
460 let ForwardInputsResult::Image { images } = r
461 else {
462 unreachable!(
463 "All results must have same type, `CausalGeneration`"
464 )
465 };
466 images
467 .into_iter()
468 .next()
469 .expect("Must have at least 1 element.")
470 })
471 .collect::<Vec<_>>(),
472 )
473 .await?;
474 }
475 }
476 let end = Instant::now();
477 exec_duration += end.duration_since(start);
478
479 Ok(exec_duration)
480 }
481 CacheBackendMetadata::PagedAttention {
482 metadata,
483 blocks_to_copy,
484 blocks_to_swap_in,
485 blocks_to_swap_out,
486 } => {
487 self.get_metadata()
489 .cache_engine
490 .as_ref()
491 .expect("PagedAttention must have cache engines.")
492 .execute_scheduler_ops(
493 &blocks_to_swap_in,
494 &blocks_to_swap_out,
495 &blocks_to_copy,
496 )?;
497
498 let inputs_iter = self.get_processor().inputs_processor().process_inputs(
499 self.tokenizer(),
500 input_seqs,
501 is_prompt,
502 self.get_metadata().is_xlora,
503 &self.device(),
504 self.get_metadata().no_kv_cache,
505 None,
506 return_raw_logits,
507 self.get_input_processor_config(),
508 Some(metadata),
509 self.get_metadata().prompt_chunksize,
510 self.device_mapper(),
511 );
512
513 let mut logits = vec![None; input_seqs.len()];
514 let prompt_chunksize = self
515 .get_metadata()
516 .prompt_chunksize
517 .map(NonZeroUsize::get)
518 .unwrap_or(1);
519 let len_inputs = input_seqs
520 .iter()
521 .map(|seq| seq.get_toks().len().div_ceil(prompt_chunksize))
522 .max()
523 .unwrap();
524 let mut raw_out_logits = vec![vec![None; len_inputs]; input_seqs.len()];
525
526 let mut exec_duration = Duration::ZERO;
527 for (i, inputs) in inputs_iter.into_iter().enumerate() {
528 let InputProcessorOutput {
529 inputs,
530 seq_indices,
531 } = inputs.map_err(candle_core::Error::msg)?;
532
533 let start = Instant::now();
534 let raw_logits = self.forward_inputs(inputs, return_raw_logits)?;
535 let end = Instant::now();
536 exec_duration += end.duration_since(start);
537
538 for (logit_idx, seq_idx) in seq_indices.into_iter().enumerate() {
539 if let ForwardInputsResult::RawLogits { logits } = &raw_logits {
540 raw_out_logits[seq_idx][i] =
541 Some(logits.i(logit_idx)?.to_device(&Device::Cpu)?);
542 } else {
543 logits[seq_idx] = Some(raw_logits.index_bs(logit_idx)?);
544 }
545 }
546 }
547
548 if raw_out_logits[0][0].is_some() {
549 let start = Instant::now();
550 response::send_raw_responses(
551 input_seqs,
552 raw_out_logits
553 .into_iter()
554 .map(|raw| raw.into_iter().flatten().collect::<Vec<_>>())
555 .collect(),
556 )
557 .await?;
558 let end = Instant::now();
559 exec_duration += end.duration_since(start);
560
561 return Ok(exec_duration);
562 }
563
564 let start = Instant::now();
565 let logits = logits
566 .into_iter()
567 .map(|l| {
568 l.expect("Did not get any inputs. This is shocking.")
569 .to_device(&Device::Cpu)
570 })
571 .collect::<candle_core::Result<Vec<_>>>()?;
572
573 match &logits[0] {
574 ForwardInputsResult::RawLogits { .. } => unreachable!(),
575 ForwardInputsResult::CausalGeneration { .. } => {
576 self.sample_causal_gen(
577 input_seqs,
578 logits
579 .into_iter()
580 .map(|r| {
581 #[allow(irrefutable_let_patterns)]
582 let ForwardInputsResult::CausalGeneration { logits } = r
583 else {
584 unreachable!("All results must have same type")
585 };
586 logits
587 })
588 .collect::<Vec<_>>(),
589 prefix_cacher,
590 disable_eos_stop,
591 rng,
592 )
593 .await?;
594 }
595 ForwardInputsResult::Image { .. } => {
596 response::send_image_responses(
597 input_seqs,
598 logits
599 .into_iter()
600 .map(|r| {
601 #[allow(irrefutable_let_patterns)]
602 let ForwardInputsResult::Image { images } = r
603 else {
604 unreachable!(
605 "All results must have same type, `CausalGeneration`"
606 )
607 };
608 images
609 .into_iter()
610 .next()
611 .expect("Must have at least 1 element.")
612 })
613 .collect::<Vec<_>>(),
614 )
615 .await?;
616 }
617 }
618 let end = Instant::now();
619 exec_duration += end.duration_since(start);
620
621 Ok(exec_duration)
622 }
623 }
624 }
625
626 async fn sample_causal_gen(
627 &self,
628 seqs: &mut [&mut Sequence],
629 logits: Vec<Tensor>,
630 prefix_cacher: &mut PrefixCacheManagerV2,
631 disable_eos_stop: bool,
632 rng: Arc<std::sync::Mutex<Isaac64Rng>>,
633 ) -> Result<(), candle_core::Error>;
634
635 fn category(&self) -> ModelCategory;
636}
637
638pub(crate) fn extract_logits(
639 logits: &Tensor,
640 context_lens: Vec<(usize, usize)>,
641) -> candle_core::Result<Tensor> {
642 let mut toks = Vec::new();
643 for (dim, (start, len)) in logits.chunk(logits.dims()[0], 0)?.iter().zip(context_lens) {
644 toks.push(dim.narrow(1, start, len)?);
645 }
646 Tensor::cat(&toks, 0)
647}
648
#[cfg(test)]
mod tests {
    use crate::MessageContent;
    use either::Either;
    use indexmap::IndexMap;
    use serde_json::Value;

    /// Builds an `IndexMap` whose values are wrapped in `Value::String`.
    macro_rules! hashmap {
        (@single $($x:tt)*) => (());
        (@count $($rest:expr),*) => (<[()]>::len(&[$(hashmap!(@single $rest)),*]));

        ($($key:expr => $value:expr,)+) => { hashmap!($($key => $value),+) };
        ($($key:expr => $value:expr),*) => {
            {
                let _cap = hashmap!(@count $($key),*);
                let mut _map = ::indexmap::IndexMap::with_capacity(_cap);
                $(
                    let _ = _map.insert($key, Value::String($value));
                )*
                _map
            }
        };
    }

    /// A message whose content is a plain string.
    fn text_message(role: &str, content: &str) -> IndexMap<String, MessageContent> {
        let mut message: IndexMap<String, MessageContent> = IndexMap::new();
        message.insert("role".to_string(), Either::Left(role.to_string()));
        message.insert("content".to_string(), Either::Left(content.to_string()));
        message
    }

    /// A message whose content is a list of typed parts.
    fn structured_message(
        role: &str,
        parts: Vec<IndexMap<String, Value>>,
    ) -> IndexMap<String, MessageContent> {
        let mut message: IndexMap<String, MessageContent> = IndexMap::new();
        message.insert("role".to_string(), Either::Left(role.to_string()));
        message.insert("content".to_string(), Either::Right(parts));
        message
    }

    /// A content part of type `image`.
    fn image_part() -> IndexMap<String, Value> {
        hashmap! {
            "type".to_string() => "image".to_string()
        }
    }

    /// A content part of type `text`.
    fn text_part(text: &str) -> IndexMap<String, Value> {
        hashmap! {
            "type".to_string() => "text".to_string(),
            "text".to_string() => text.to_string()
        }
    }

    /// Renders every template against `inputs` and compares it with the expected output at the
    /// same position. Each template is `(has_system, bos, eos, unk, template)`; when
    /// `has_system` is false the first (system) message is dropped before rendering.
    ///
    /// Panics after printing all mismatches, or if the number of templates and expected outputs
    /// differ (otherwise surplus templates would silently go untested).
    #[track_caller]
    fn test_with_inputs(
        templates: &[(bool, &str, &str, &str, &str)],
        expected_outputs: &[&str],
        inputs: Vec<IndexMap<String, MessageContent>>,
    ) {
        use crate::pipeline::chat_template::ChatTemplateValue;

        use super::chat_template::apply_chat_template_to;

        assert_eq!(
            templates.len(),
            expected_outputs.len(),
            "Every chat template needs exactly one expected output."
        );

        let n_templates = templates.len();
        // (template index, description of the failure)
        let mut failed: Vec<(usize, String)> = Vec::new();
        for (idx, ((has_system, bos, eos, unk, template), expected)) in
            templates.iter().zip(expected_outputs).enumerate()
        {
            let messages = if *has_system {
                inputs.clone()
            } else {
                inputs[1..].to_vec()
            };
            let rendered = apply_chat_template_to(
                messages,
                true,
                None,
                &ChatTemplateValue(Either::Left(template.to_string())),
                Some(bos.to_string()),
                Some(eos.to_string()),
                Some(unk.to_string()),
                Vec::new(),
            );
            match rendered {
                Ok(output) if output == *expected => {}
                Ok(output) => failed.push((
                    idx,
                    format!(
                        "Expected: `{}` \n\nGot: `{}`",
                        expected.replace('\n', "\\n"),
                        output.replace('\n', "\\n")
                    ),
                )),
                Err(e) => failed.push((idx, format!("Failed with {e}."))),
            }
        }

        if !failed.is_empty() {
            for (idx, line) in &failed {
                println!("------------ Template {idx} ------------");
                println!("{line}");
            }
            println!("------------------------");
            panic!("{}/{n_templates} chat templates failed.", failed.len());
        }
    }

    #[test]
    fn test_chat_templates() {
        // One template per expected output below. The template that needs structured (image)
        // content is exercised by `test_image_chat_templates`.
        let templates = [
            (true, "<s>", "</s>", "<unk>", "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"),
            (false, "<s>", "</s>", "<unk>", "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"),
            (true, "<s>", "</s>", "<unk>", "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}"),
            (false, "<s>", "</s>", "<unk>", "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"),
            (false, "<bos>", "<eos>", "<unk>", "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}"),
        ];
        let expected_outputs = [
            "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi there<|im_end|>\n<|im_start|>user\nWho are you<|im_end|>\n<|im_start|>assistant\n I am an assistant <|im_end|>\n<|im_start|>user\nAnother question<|im_end|>\n<|im_start|>assistant\n",
            "<s>[INST] Hello [/INST]Hi there</s> [INST] Who are you [/INST] I am an assistant </s> [INST] Another question [/INST]",
            "<s>[INST] <<SYS>>\nYou are a helpful assistant\n<</SYS>>\n\nHello [/INST] Hi there </s><s>[INST] Who are you [/INST] I am an assistant </s><s>[INST] Another question [/INST]",
            "<s>[INST] Hello [/INST]Hi there</s>[INST] Who are you [/INST] I am an assistant </s>[INST] Another question [/INST]",
            "<bos><start_of_turn>user\nHello<end_of_turn>\n<start_of_turn>model\nHi there<end_of_turn>\n<start_of_turn>user\nWho are you<end_of_turn>\n<start_of_turn>model\nI am an assistant<end_of_turn>\n<start_of_turn>user\nAnother question<end_of_turn>\n<start_of_turn>model\n",
        ];
        let messages = [
            ["system", "You are a helpful assistant"],
            ["user", "Hello"],
            ["assistant", "Hi there"],
            ["user", "Who are you"],
            ["assistant", " I am an assistant "],
            ["user", "Another question"],
        ];
        let mut inputs = Vec::new();
        for [role, content] in messages {
            inputs.push(text_message(role, content));
        }
        test_with_inputs(&templates, &expected_outputs, inputs);
    }

    #[test]
    fn test_image_chat_templates() {
        let templates = [
            (true, "<s>", "</s>", "<unk>", "{% for message in messages %}{{message['role'].capitalize()}}{% if message['content'][0]['type'] == 'image' %}{{':'}}{% else %}{{': '}}{% endif %}{% for line in message['content'] %}{% if line['type'] == 'text' %}{{line['text']}}{% elif line['type'] == 'image' %}{{ '<image>' }}{% endif %}{% endfor %}<end_of_utterance>\n{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"),
        ];
        // The second assistant text starts with a space and the template emits ": " before it,
        // hence the two spaces after that `Assistant:`.
        let expected_outputs = [
            "System: You are a helpful assistant<end_of_utterance>\nUser:<image>Hello, please describe the above.<end_of_utterance>\nAssistant: Hi there<end_of_utterance>\nUser:<image>This is me, who are you<end_of_utterance>\nAssistant:  I am an assistant <end_of_utterance>\nUser:<image>Another question, what is this?<end_of_utterance>\nAssistant:",
        ];

        let inputs = vec![
            structured_message("system", vec![text_part("You are a helpful assistant")]),
            structured_message(
                "user",
                vec![image_part(), text_part("Hello, please describe the above.")],
            ),
            structured_message("assistant", vec![text_part("Hi there")]),
            structured_message(
                "user",
                vec![image_part(), text_part("This is me, who are you")],
            ),
            structured_message("assistant", vec![text_part(" I am an assistant ")]),
            structured_message(
                "user",
                vec![image_part(), text_part("Another question, what is this?")],
            ),
        ];

        test_with_inputs(&templates, &expected_outputs, inputs);
    }
}