1use std::collections::HashMap;
2
3use anyhow::Result;
4use either::Either;
5use indexmap::IndexMap;
6use itertools::Itertools;
7use minijinja::{context, value::Kwargs, Environment, Error, ErrorKind, Value};
8use serde::{Deserialize, Serialize};
9use tokenizers::Tokenizer;
10use tracing::info;
11
12use crate::{MessageContent, Tool};
13
/// Token strings that `calculate_eos_tokens` additionally treats as
/// end-of-sequence markers whenever they are present in the tokenizer's
/// vocabulary, regardless of what the chat template declares.
const SUPPORTED_ALTERNATE_EOS: &[&str] = &[
    "<|im_end|>",
    "<end_of_turn>",
    "<|end_of_text|>",
];
19
/// Object form of a special-token entry, as an alternative to a bare string
/// (see `BeginEndUnkPadTok`).
///
/// Only `content` (the token text) is public and read by the code in this
/// file. The `bool` fields are not optional, so they must be present for
/// deserialization to succeed.
// The private fields are deserialized but never read here, hence `dead_code`.
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
pub struct AddedTokensDecoder {
    __type: Option<String>,
    /// The literal token text.
    pub content: String,
    lstrip: bool,
    normalized: bool,
    rstrip: bool,
    single_word: bool,
    special: Option<bool>,
}
31
32fn raise_exception(msg: String) -> Result<String, minijinja::Error> {
33 Err(minijinja::Error::new(ErrorKind::InvalidOperation, msg))
34}
35
/// A special-token entry (used for the BOS, EOS, UNK and PAD fields of
/// `ChatTemplate`) that deserializes, untagged, from either a plain string
/// (`Left`) or an `AddedTokensDecoder` object (`Right`).
#[derive(Debug, Deserialize)]
pub struct BeginEndUnkPadTok(
    #[serde(with = "either::serde_untagged")] pub Either<String, AddedTokensDecoder>,
);
40
/// The `chat_template` value, deserialized untagged as either a single
/// template string (`Left`) or a list of string-to-string maps (`Right`).
///
/// `apply_chat_template_to` understands two shapes for each map in the list:
/// `{"name": ..., "template": ...}`, or a map keyed directly by `"tool_use"` /
/// `"default"`.
#[derive(Debug, Deserialize)]
pub struct ChatTemplateValue(
    #[serde(with = "either::serde_untagged")] pub Either<String, Vec<HashMap<String, String>>>,
);
45
/// Deserialized tokenizer/chat-template configuration.
///
/// Every field is optional, and the struct derives `Default`, so an empty
/// configuration is representable. Only the public fields are read by the
/// code in this file.
// NOTE(review): presumably mirrors a Hugging Face `tokenizer_config.json` — confirm.
// The private fields are deserialized but never read here, hence `dead_code`.
#[allow(dead_code)]
#[derive(Debug, Deserialize, Default)]
pub struct ChatTemplate {
    add_bos_token: Option<bool>,
    add_eos_token: Option<bool>,
    added_tokens_decoder: Option<HashMap<String, AddedTokensDecoder>>,
    additional_special_tokens: Option<Vec<String>>,
    /// Beginning-of-sequence token; see `ChatTemplate::bos_tok`.
    pub bos_token: Option<BeginEndUnkPadTok>,

    /// The chat template itself: a single string or a list of named templates.
    pub chat_template: Option<ChatTemplateValue>,
    clean_up_tokenization_spaces: Option<bool>,
    device_map: Option<String>,
    /// End-of-sequence token; see `ChatTemplate::eos_tok`.
    pub eos_token: Option<BeginEndUnkPadTok>,
    legacy: Option<bool>,
    model_max_length: Option<f64>,
    /// Padding token (not read by the code in this file).
    pub pad_token: Option<BeginEndUnkPadTok>,
    sp_model_kwargs: Option<HashMap<String, String>>,
    spaces_between_special_tokens: Option<bool>,
    tokenizer_class: Option<String>,
    truncation_size: Option<String>,
    /// Unknown token; see `ChatTemplate::unk_tok`.
    pub unk_token: Option<BeginEndUnkPadTok>,
    use_default_system_prompt: Option<bool>,
}
73
74impl ChatTemplate {
75 pub fn has_chat_template(&self) -> bool {
76 self.chat_template.is_some()
77 }
78
79 pub fn eos_tok(&self) -> Option<String> {
80 match self.eos_token.as_ref()?.0 {
81 Either::Left(ref lit) => Some(lit.clone()),
82 Either::Right(ref added) => Some(added.content.clone()),
83 }
84 }
85
86 pub fn bos_tok(&self) -> Option<String> {
87 match self.bos_token.as_ref()?.0 {
88 Either::Left(ref lit) => Some(lit.clone()),
89 Either::Right(ref added) => Some(added.content.clone()),
90 }
91 }
92
93 pub fn unk_tok(&self) -> Option<String> {
94 match self.unk_token.as_ref()?.0 {
95 Either::Left(ref lit) => Some(lit.clone()),
96 Either::Right(ref added) => Some(added.content.clone()),
97 }
98 }
99}
100
101pub fn calculate_eos_tokens(
102 chat_template: &ChatTemplate,
103 gen_conf: Option<GenerationConfig>,
104 tokenizer: &Tokenizer,
105) -> Vec<u32> {
106 let mut eos_tok_ids = chat_template.eos_tok().map(|x| vec![x]).unwrap_or_default();
107 let mut bos_tok_ids = chat_template.bos_tok().map(|b| vec![b]).unwrap_or_default();
108
109 for alternate in SUPPORTED_ALTERNATE_EOS {
110 if tokenizer.get_vocab(true).contains_key(*alternate) {
111 eos_tok_ids.push(alternate.to_string())
112 }
113 }
114
115 if let Some(gen_conf) = gen_conf {
116 let ids = match gen_conf.eos_token_id {
117 Either::Left(id) => vec![id],
118 Either::Right(ids) => ids,
119 };
120 for id in ids {
121 let s = tokenizer
122 .decode(&[id], false)
123 .unwrap_or_else(|_| panic!("Unable to decode id {id})"));
124 if !eos_tok_ids.contains(&s) {
125 eos_tok_ids.push(s);
126 }
127 }
128
129 let ids = match gen_conf.bos_token_id {
130 Either::Left(id) => vec![id],
131 Either::Right(ids) => ids,
132 };
133 for id in ids {
134 let s = tokenizer
135 .decode(&[id], false)
136 .unwrap_or_else(|_| panic!("Unable to decode id {id})"));
137 if !bos_tok_ids.contains(&s) {
138 bos_tok_ids.push(s);
139 }
140 }
141 }
142
143 eos_tok_ids = eos_tok_ids.into_iter().dedup().collect::<Vec<_>>();
144 bos_tok_ids = bos_tok_ids.into_iter().dedup().collect::<Vec<_>>();
145
146 let bos_render = bos_tok_ids
147 .iter()
148 .map(|val| format!("{:?}", val))
149 .collect::<Vec<String>>()
150 .join(", ");
151 let eos_render = eos_tok_ids
152 .iter()
153 .map(|val| format!("{:?}", val))
154 .collect::<Vec<String>>()
155 .join(", ");
156
157 info!(
158 "bos_toks = {bos_render}, eos_toks = {eos_render}, unk_tok = {}",
159 chat_template.unk_tok().unwrap_or("`None`".to_string()),
160 );
161
162 let mut eos_toks = Vec::new();
163 for eos_tok in eos_tok_ids {
164 eos_toks.push(
165 tokenizer
166 .get_vocab(true)
167 .get(&eos_tok)
168 .copied()
169 .unwrap_or_else(|| panic!("Unable to extract `{eos_tok}` EOS token.")),
170 )
171 }
172 eos_toks
173}
174
/// BOS/EOS token ids from a generation configuration.
///
/// Each field deserializes, untagged, from either a single id (`Left`) or a
/// list of ids (`Right`). Both are consumed by `calculate_eos_tokens`.
// NOTE(review): neither field is optional or defaulted, so deserialization
// presumably fails when either key is absent — confirm this is intended.
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
pub struct GenerationConfig {
    #[serde(with = "either::serde_untagged")]
    bos_token_id: Either<u32, Vec<u32>>,
    #[serde(with = "either::serde_untagged")]
    eos_token_id: Either<u32, Vec<u32>>,
}
183
184fn tojson(value: Value, kwargs: Kwargs) -> Result<Value, Error> {
185 if let Ok(indent) = kwargs.get("indent") {
186 let mut buf = Vec::new();
187 let repeat = b" ".repeat(indent);
188 let formatter = serde_json::ser::PrettyFormatter::with_indent(&repeat);
189 let mut ser = serde_json::Serializer::with_formatter(&mut buf, formatter);
190 value.serialize(&mut ser).unwrap();
191 String::from_utf8(buf).map_err(|err| {
192 Error::new(ErrorKind::BadSerialization, "cannot serialize to JSON").with_source(err)
193 })
194 } else {
195 serde_json::to_string(&value).map_err(|err| {
196 Error::new(ErrorKind::BadSerialization, "cannot serialize to JSON").with_source(err)
197 })
198 }
199 .map_err(|err| {
200 Error::new(ErrorKind::InvalidOperation, "cannot serialize to JSON").with_source(err)
201 })
202 .map(|s| {
203 let mut rv = String::with_capacity(s.len());
205 for c in s.chars() {
206 match c {
207 '<' => rv.push_str("\\u003c"),
208 '>' => rv.push_str("\\u003e"),
209 '&' => rv.push_str("\\u0026"),
210 '\'' => rv.push_str("\\u0027"),
211 _ => rv.push(c),
212 }
213 }
214 Value::from_safe_string(rv)
215 })
216}
217
218fn strftime_now(fmt: String) -> Result<String, minijinja::Error> {
219 let date = chrono::Utc::now();
220 let date_string = date.format(&fmt).to_string();
221 Ok(date_string)
222}
223
224pub fn apply_chat_template_to(
225 messages: Vec<IndexMap<String, MessageContent>>,
226 add_generation_prompt: bool,
227 template: &ChatTemplateValue,
228 bos_tok: Option<String>,
229 eos_tok: Option<String>,
230 unk_tok: Option<String>,
231 tools: Vec<Tool>,
232) -> Result<String> {
233 let mut env = Environment::new();
234
235 env.set_unknown_method_callback(minijinja_contrib::pycompat::unknown_method_callback);
237
238 env.set_lstrip_blocks(true);
240 env.set_trim_blocks(true);
241
242 #[derive(Serialize, Deserialize)]
243 struct UntaggedContent(#[serde(with = "either::serde_untagged")] MessageContent);
244 let mut new_messages = Vec::new();
245 for message in messages {
246 let mut new_message = IndexMap::new();
247 for (k, v) in message {
248 new_message.insert(k, UntaggedContent(v));
249 }
250 new_messages.push(new_message);
251 }
252
253 let template = match &template.0 {
254 Either::Left(x) => x.clone(),
255 Either::Right(map) => {
256 let mut template = None;
257 let has_tool_use = map.iter().any(|t| {
258 t.get("name").is_some_and(|name| name == "tool_use") || t.contains_key("tool_use")
259 });
260 let must_use_tool_template = !tools.is_empty();
261
262 if must_use_tool_template && !has_tool_use {
263 anyhow::bail!(
264 "Tools were provided but this chat template does not handle tool usage"
265 );
266 }
267
268 for t in map {
269 let name = t.get("name");
270 if let Some(name) = name {
271 template = Some(t["template"].clone());
272 #[allow(clippy::if_same_then_else)]
273 if name == "tool_use" && !tools.is_empty() {
274 break;
275 } else if name == "default" && !must_use_tool_template {
276 break;
277 }
278 } else if t.contains_key("tool_use") && !tools.is_empty() {
279 template = Some(t["tool_use"].clone());
280 break;
281 } else if t.contains_key("default") && !must_use_tool_template {
282 template = Some(t["default"].clone());
283 break;
284 }
285 }
286
287 let Some(template) = template else {
288 anyhow::bail!("Chat template does not contain a `tool_use` or `default` key. Please ensure it contains at least a `default` key, although `tool_use` should be specified for using tools.");
289 };
290 template
291 }
292 };
293
294 env.add_template("chat_template", &template)?;
295 env.add_function("raise_exception", raise_exception);
296 env.add_filter("tojson", tojson);
297 env.add_function("strftime_now", strftime_now);
298 let tmpl = env.get_template("chat_template").unwrap();
299
300 let date = chrono::Utc::now();
301 let date_string = date.format("%d, %B, %Y").to_string();
302
303 if tools.is_empty() {
304 Ok(tmpl.render(context! {
305 messages => new_messages,
306 add_generation_prompt => add_generation_prompt,
307 bos_token => bos_tok,
308 eos_token => eos_tok,
309 unk_token => unk_tok,
310 date_string => date_string,
311 })?)
312 } else {
313 Ok(tmpl.render(context! {
314 messages => new_messages,
315 add_generation_prompt => add_generation_prompt,
316 bos_token => bos_tok,
317 eos_token => eos_tok,
318 unk_token => unk_tok,
319 tools => tools,
320 date_string => date_string,
321 })?)
322 }
323}