1use std::collections::HashMap;
2
3use anyhow::Result;
4use either::Either;
5use indexmap::IndexMap;
6use itertools::Itertools;
7use minijinja::{context, value::Kwargs, Environment, Error, ErrorKind, Value};
8use regex::Regex;
9use serde::{Deserialize, Serialize};
10use tokenizers::Tokenizer;
11use tracing::info;
12
13use crate::{MessageContent, Tool};
14
/// Token strings that are accepted as additional end-of-sequence markers.
///
/// Each entry only takes effect when the tokenizer's vocabulary actually
/// contains it.
const SUPPORTED_ALTERNATE_EOS: &[&str] = &["<|im_end|>", "<end_of_turn>", "<|end_of_text|>"];
20
21#[allow(dead_code)]
22#[derive(Debug, Deserialize)]
23pub struct AddedTokensDecoder {
24 __type: Option<String>,
25 pub content: String,
26 lstrip: bool,
27 normalized: bool,
28 rstrip: bool,
29 single_word: bool,
30 special: Option<bool>,
31}
32
33fn raise_exception(msg: String) -> Result<String, minijinja::Error> {
34 Err(minijinja::Error::new(ErrorKind::InvalidOperation, msg))
35}
36
37#[derive(Debug, Deserialize)]
38pub struct BeginEndUnkPadTok(
39 #[serde(with = "either::serde_untagged")] pub Either<String, AddedTokensDecoder>,
40);
41
42#[derive(Debug, Deserialize)]
43pub struct ChatTemplateValue(
44 #[serde(with = "either::serde_untagged")] pub Either<String, Vec<HashMap<String, String>>>,
45);
46
47#[allow(dead_code)]
48#[derive(Debug, Deserialize, Default)]
49pub struct ChatTemplate {
51 add_bos_token: Option<bool>,
52 add_eos_token: Option<bool>,
53 added_tokens_decoder: Option<HashMap<String, AddedTokensDecoder>>,
54 additional_special_tokens: Option<Vec<String>>,
55 pub bos_token: Option<BeginEndUnkPadTok>,
56
57 pub chat_template: Option<ChatTemplateValue>,
61 clean_up_tokenization_spaces: Option<bool>,
62 device_map: Option<String>,
63 pub eos_token: Option<BeginEndUnkPadTok>,
64 legacy: Option<bool>,
65 model_max_length: Option<f64>,
66 pub pad_token: Option<BeginEndUnkPadTok>,
67 sp_model_kwargs: Option<HashMap<String, String>>,
68 spaces_between_special_tokens: Option<bool>,
69 tokenizer_class: Option<String>,
70 truncation_size: Option<String>,
71 pub unk_token: Option<BeginEndUnkPadTok>,
72 use_default_system_prompt: Option<bool>,
73}
74
75impl ChatTemplate {
76 pub fn has_chat_template(&self) -> bool {
77 self.chat_template.is_some()
78 }
79
80 pub fn eos_tok(&self) -> Option<String> {
81 match self.eos_token.as_ref()?.0 {
82 Either::Left(ref lit) => Some(lit.clone()),
83 Either::Right(ref added) => Some(added.content.clone()),
84 }
85 }
86
87 pub fn bos_tok(&self) -> Option<String> {
88 match self.bos_token.as_ref()?.0 {
89 Either::Left(ref lit) => Some(lit.clone()),
90 Either::Right(ref added) => Some(added.content.clone()),
91 }
92 }
93
94 pub fn unk_tok(&self) -> Option<String> {
95 match self.unk_token.as_ref()?.0 {
96 Either::Left(ref lit) => Some(lit.clone()),
97 Either::Right(ref added) => Some(added.content.clone()),
98 }
99 }
100}
101
102pub fn calculate_eos_tokens(
103 chat_template: &ChatTemplate,
104 gen_conf: Option<GenerationConfig>,
105 tokenizer: &Tokenizer,
106) -> Vec<u32> {
107 let mut eos_tok_ids = chat_template.eos_tok().map(|x| vec![x]).unwrap_or_default();
108 let mut bos_tok_ids = chat_template.bos_tok().map(|b| vec![b]).unwrap_or_default();
109
110 for alternate in SUPPORTED_ALTERNATE_EOS {
111 if tokenizer.get_vocab(true).contains_key(*alternate) {
112 eos_tok_ids.push(alternate.to_string())
113 }
114 }
115
116 if let Some(gen_conf) = gen_conf {
117 if let Some(eos_field) = gen_conf.eos_token_id {
118 let ids = match eos_field {
119 Either::Left(id) => vec![id],
120 Either::Right(ids) => ids,
121 };
122 for id in ids {
123 let s = tokenizer
124 .decode(&[id], false)
125 .unwrap_or_else(|_| panic!("Unable to decode id {id})"));
126 if !eos_tok_ids.contains(&s) {
127 eos_tok_ids.push(s);
128 }
129 }
130 }
131
132 if let Some(bos_field) = gen_conf.bos_token_id {
133 let ids = match bos_field {
134 Either::Left(id) => vec![id],
135 Either::Right(ids) => ids,
136 };
137 for id in ids {
138 let s = tokenizer
139 .decode(&[id], false)
140 .unwrap_or_else(|_| panic!("Unable to decode id {id})"));
141 if !bos_tok_ids.contains(&s) {
142 bos_tok_ids.push(s);
143 }
144 }
145 }
146 }
147
148 eos_tok_ids = eos_tok_ids.into_iter().dedup().collect::<Vec<_>>();
149 bos_tok_ids = bos_tok_ids.into_iter().dedup().collect::<Vec<_>>();
150
151 let bos_render = bos_tok_ids
152 .iter()
153 .map(|val| format!("{val:?}"))
154 .collect::<Vec<String>>()
155 .join(", ");
156 let eos_render = eos_tok_ids
157 .iter()
158 .map(|val| format!("{val:?}"))
159 .collect::<Vec<String>>()
160 .join(", ");
161
162 info!(
163 "bos_toks = {bos_render}, eos_toks = {eos_render}, unk_tok = {}",
164 chat_template.unk_tok().unwrap_or("`None`".to_string()),
165 );
166
167 let mut eos_toks = Vec::new();
168 for eos_tok in eos_tok_ids {
169 eos_toks.push(
170 tokenizer
171 .get_vocab(true)
172 .get(&eos_tok)
173 .copied()
174 .unwrap_or_else(|| panic!("Unable to extract `{eos_tok}` EOS token.")),
175 )
176 }
177 eos_toks
178}
179
180#[allow(dead_code)]
181#[derive(Debug, Deserialize)]
182pub struct GenerationConfig {
183 #[serde(with = "either::serde_untagged_optional")]
184 bos_token_id: Option<Either<u32, Vec<u32>>>,
185 #[serde(with = "either::serde_untagged_optional")]
186 eos_token_id: Option<Either<u32, Vec<u32>>>,
187}
188
189fn tojson(value: Value, kwargs: Kwargs) -> Result<Value, Error> {
190 if let Ok(indent) = kwargs.get("indent") {
191 let mut buf = Vec::new();
192 let repeat = b" ".repeat(indent);
193 let formatter = serde_json::ser::PrettyFormatter::with_indent(&repeat);
194 let mut ser = serde_json::Serializer::with_formatter(&mut buf, formatter);
195 value.serialize(&mut ser).unwrap();
196 String::from_utf8(buf).map_err(|err| {
197 Error::new(ErrorKind::BadSerialization, "cannot serialize to JSON").with_source(err)
198 })
199 } else {
200 serde_json::to_string(&value).map_err(|err| {
201 Error::new(ErrorKind::BadSerialization, "cannot serialize to JSON").with_source(err)
202 })
203 }
204 .map_err(|err| {
205 Error::new(ErrorKind::InvalidOperation, "cannot serialize to JSON").with_source(err)
206 })
207 .map(|s| {
208 let mut rv = String::with_capacity(s.len());
210 for c in s.chars() {
211 match c {
212 '<' => rv.push_str("\\u003c"),
213 '>' => rv.push_str("\\u003e"),
214 '&' => rv.push_str("\\u0026"),
215 '\'' => rv.push_str("\\u0027"),
216 _ => rv.push(c),
217 }
218 }
219 Value::from_safe_string(rv)
220 })
221}
222
223fn strftime_now(fmt: String) -> Result<String, minijinja::Error> {
224 let date = chrono::Utc::now();
225 let date_string = date.format(&fmt).to_string();
226 Ok(date_string)
227}
228
229#[allow(clippy::too_many_arguments)]
230pub fn apply_chat_template_to(
231 messages: Vec<IndexMap<String, MessageContent>>,
232 add_generation_prompt: bool,
233 enable_thinking: Option<bool>,
234 template: &ChatTemplateValue,
235 bos_tok: Option<String>,
236 eos_tok: Option<String>,
237 unk_tok: Option<String>,
238 tools: Vec<Tool>,
239) -> Result<String> {
240 let mut env = Environment::new();
241
242 env.set_unknown_method_callback(minijinja_contrib::pycompat::unknown_method_callback);
244
245 env.set_lstrip_blocks(true);
247 env.set_trim_blocks(true);
248
249 #[derive(Serialize, Deserialize)]
250 struct UntaggedContent(#[serde(with = "either::serde_untagged")] MessageContent);
251 let mut new_messages = Vec::new();
252 for message in messages {
253 let mut new_message = IndexMap::new();
254 for (k, v) in message {
255 new_message.insert(k, UntaggedContent(v));
256 }
257 new_messages.push(new_message);
258 }
259
260 let template = match &template.0 {
261 Either::Left(x) => x.clone(),
262 Either::Right(map) => {
263 let mut template = None;
264 let has_tool_use = map.iter().any(|t| {
265 t.get("name").is_some_and(|name| name == "tool_use") || t.contains_key("tool_use")
266 });
267 let must_use_tool_template = !tools.is_empty();
268
269 if must_use_tool_template && !has_tool_use {
270 anyhow::bail!(
271 "Tools were provided but this chat template does not handle tool usage"
272 );
273 }
274
275 for t in map {
276 let name = t.get("name");
277 if let Some(name) = name {
278 template = Some(t["template"].clone());
279 #[allow(clippy::if_same_then_else)]
280 if name == "tool_use" && !tools.is_empty() {
281 break;
282 } else if name == "default" && !must_use_tool_template {
283 break;
284 }
285 } else if t.contains_key("tool_use") && !tools.is_empty() {
286 template = Some(t["tool_use"].clone());
287 break;
288 } else if t.contains_key("default") && !must_use_tool_template {
289 template = Some(t["default"].clone());
290 break;
291 }
292 }
293
294 let Some(template) = template else {
295 anyhow::bail!("Chat template does not contain a `tool_use` or `default` key. Please ensure it contains at least a `default` key, although `tool_use` should be specified for using tools.");
296 };
297 template
298 }
299 };
300 let mut template = template.replace("[::-1]", "|reverse");
301 let re = Regex::new(r"range\((?P<expr>[^,]+),\s*-1,\s*-1\)").unwrap();
305 template = re
306 .replace_all(&template, |caps: ®ex::Captures| {
307 format!("range({})|reverse", &caps["expr"])
308 })
309 .into_owned();
310
311 if template.contains("{{ meta }}") {
312 template = template.replace("{%- set meta = message.get(\"metadata\", \"\") %}", "");
314 template = template.replace("{{ meta }}", "");
315 }
316
317 env.add_template("chat_template", &template)?;
318 env.add_function("raise_exception", raise_exception);
319 env.add_filter("tojson", tojson);
320 env.add_function("strftime_now", strftime_now);
321 let tmpl = env.get_template("chat_template").unwrap();
322
323 let date = chrono::Utc::now();
324 let date_string = date.format("%d, %B, %Y").to_string();
325
326 if tools.is_empty() {
327 Ok(tmpl.render(context! {
328 messages => new_messages,
329 add_generation_prompt => add_generation_prompt,
330 bos_token => bos_tok,
331 eos_token => eos_tok,
332 unk_token => unk_tok,
333 date_string => date_string,
334 enable_thinking => enable_thinking,
335 })?)
336 } else {
337 Ok(tmpl.render(context! {
338 messages => new_messages,
339 add_generation_prompt => add_generation_prompt,
340 bos_token => bos_tok,
341 eos_token => eos_tok,
342 unk_token => unk_tok,
343 tools => tools,
344 date_string => date_string,
345 enable_thinking => enable_thinking,
346 })?)
347 }
348}