// mistralrs_core/pipeline/processing.rs
use std::sync::Arc;
2
3use anyhow::{Context, Result};
4use either::Either;
5use indexmap::IndexMap;
6
7use crate::{
8 vision_models::{preprocessor_config::PreProcessorConfig, processor_config::ProcessorConfig},
9 MessageContent, Pipeline, Tool,
10};
11
12use super::{chat_template::apply_chat_template_to, text_models_inputs_processor, InputsProcessor};
13
14pub trait ProcessorCreator {
16 fn new_processor(
17 _: Option<ProcessorConfig>,
18 _: PreProcessorConfig,
19 ) -> Arc<dyn Processor + Send + Sync>;
20}
21
/// How chat messages are transformed before being rendered through the
/// chat template (see `apply_chat_template`).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum MessagesAction {
    /// Pass messages through untouched.
    Keep,
    /// Collapse a structured (`Either::Right`) "content" value down to its
    /// first "text" entry so plain-text templates can consume it.
    FlattenOnlyText,
}
28
29pub trait Processor {
33 fn process(
35 &self,
36 pipeline: &dyn Pipeline,
37 messages: Vec<IndexMap<String, MessageContent>>,
38 add_generation_prompt: bool,
39 add_special_tokens: bool,
40 tools: Vec<Tool>,
41 ) -> Result<(Vec<u32>, String)> {
42 let prompt = apply_chat_template(
53 pipeline,
54 messages,
55 add_generation_prompt,
56 self.template_action(),
57 tools,
58 )?;
59 let encoding = pipeline
60 .tokenizer()
61 .with_context(|| {
62 "Default `Processor::process` requires the model to have a tokenizer."
63 })?
64 .encode_fast(prompt.clone(), add_special_tokens)
65 .map_err(anyhow::Error::msg)?;
66 Ok((encoding.get_ids().to_vec(), prompt))
67 }
68 fn inputs_processor(&self) -> Arc<dyn InputsProcessor>;
69 fn get_special_tokens(&self) -> &[&'static str];
70 fn template_action(&self) -> MessagesAction;
71}
72
73pub(crate) fn apply_chat_template(
74 pipeline: &dyn Pipeline,
75 messages: Vec<IndexMap<String, MessageContent>>,
76 add_generation_prompt: bool,
77 action: MessagesAction,
78 tools: Vec<Tool>,
79) -> Result<String> {
80 let messages = match action {
81 MessagesAction::Keep => messages,
82 MessagesAction::FlattenOnlyText => {
83 let mut new_messages = Vec::new();
86 for message in messages {
87 let mut new_message = IndexMap::new();
88 for (k, v) in message {
89 if k == "content" {
90 match v {
91 Either::Left(lv) => {
92 new_message.insert(k, Either::Left(lv));
93 }
94 Either::Right(rv) => {
95 'outer: for content_row in rv {
96 for (content_k, content_v) in content_row {
97 if content_k == "text" {
98 if let Some(content_str) = content_v.as_str() {
99 new_message.insert(
100 k,
101 Either::Left(content_str.to_string()),
102 );
103 break 'outer;
104 }
105 }
106 }
107 }
108 }
109 }
110 } else {
111 new_message.insert(k, Either::Left(v.left().unwrap()));
112 }
113 }
114 new_messages.push(new_message)
115 }
116 new_messages
117 }
118 };
119 let chat_template = pipeline
120 .get_chat_template()
121 .with_context(|| "`apply_chat_template` expects the pipeline to have a chat template.")?;
122 let template = chat_template.chat_template.as_ref().unwrap();
123 let bos_tok = if let Some(ref bos) = chat_template.bos_token {
124 match bos.0 {
125 Either::Left(ref lit) => Some(lit.to_string()),
126 Either::Right(ref added) => Some(added.content.to_string()),
127 }
128 } else {
129 None
130 };
131 let eos_tok = if let Some(ref eos) = chat_template.eos_token {
132 match eos.0 {
133 Either::Left(ref lit) => Some(lit.to_string()),
134 Either::Right(ref added) => Some(added.content.to_string()),
135 }
136 } else {
137 None
138 };
139 let unk_tok = if let Some(ref unk) = chat_template.unk_token {
140 match unk.0 {
141 Either::Left(ref lit) => Some(lit.to_string()),
142 Either::Right(ref added) => Some(added.content.to_string()),
143 }
144 } else {
145 None
146 };
147 apply_chat_template_to(
148 messages,
149 add_generation_prompt,
150 template,
151 bos_tok,
152 eos_tok,
153 unk_tok,
154 tools,
155 )
156}
157
/// A minimal text-only [`Processor`]: no special tokens, messages kept as-is.
pub struct BasicProcessor;
160impl Processor for BasicProcessor {
161 fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
162 Arc::new(text_models_inputs_processor::TextInputsProcessor)
163 }
164 fn get_special_tokens(&self) -> &[&'static str] {
165 &[]
166 }
167 fn template_action(&self) -> MessagesAction {
168 MessagesAction::Keep
169 }
170}