1use std::collections::HashMap;
2
3use anyhow::Result;
4use either::Either;
5use indexmap::IndexMap;
6use itertools::Itertools;
7use minijinja::{context, value::Kwargs, Environment, Error, ErrorKind, Value};
8use regex::Regex;
9use serde::{Deserialize, Serialize};
10use tokenizers::Tokenizer;
11use tracing::info;
12
13use crate::{MessageContent, Tool};
14
/// Additional end-of-sequence marker strings.
///
/// `calculate_eos_tokens` appends each of these to the EOS list whenever the string exists
/// in the tokenizer vocabulary (added tokens included).
const SUPPORTED_ALTERNATE_EOS: &[&str] = &[
    "<|im_end|>",
    "<end_of_turn>",
    "<|end_of_text|>",
];
20
/// One added/special token in its object form.
///
/// This type appears as the values of `ChatTemplate::added_tokens_decoder` and as the
/// object alternative of `BeginEndUnkPadTok` (e.g. for `bos_token` / `eos_token`).
///
/// Only `content` is public and read by the `*_tok` accessors on `ChatTemplate`. The private
/// fields are never read in this module, hence `#[allow(dead_code)]`.
///
/// Note: the non-`Option` fields (`lstrip`, `normalized`, `rstrip`, `single_word`) have no
/// `#[serde(default)]`, so serde requires them to be present when deserializing.
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
pub struct AddedTokensDecoder {
    // Optional type tag carried by some configs; not interpreted here.
    __type: Option<String>,
    /// The literal token text.
    pub content: String,
    lstrip: bool,
    normalized: bool,
    rstrip: bool,
    single_word: bool,
    special: Option<bool>,
}
32
33fn raise_exception(msg: String) -> Result<String, minijinja::Error> {
34 Err(minijinja::Error::new(ErrorKind::InvalidOperation, msg))
35}
36
/// A special token (BOS / EOS / UNK / PAD) as written in the tokenizer config.
///
/// It is either a plain string or an [`AddedTokensDecoder`] object whose `content` holds the
/// token text. It is deserialized untagged, so whichever of the two shapes matches the input
/// is accepted.
#[derive(Debug, Deserialize)]
pub struct BeginEndUnkPadTok(
    #[serde(with = "either::serde_untagged")] pub Either<String, AddedTokensDecoder>,
);
41
/// The `chat_template` config value.
///
/// It is either a single template source string, or a list of string maps describing
/// several templates. `apply_chat_template_to` accepts entries shaped either as
/// `{"name": ..., "template": ...}` or keyed directly by `tool_use` / `default`.
/// It is deserialized untagged.
#[derive(Debug, Deserialize)]
pub struct ChatTemplateValue(
    #[serde(with = "either::serde_untagged")] pub Either<String, Vec<HashMap<String, String>>>,
);
46
/// Deserialized tokenizer configuration holding the chat template and special tokens.
///
/// NOTE(review): presumably parsed from a `tokenizer_config.json`; confirm against the loader.
///
/// Most private fields exist only for deserialization and are not read in this module, hence
/// `#[allow(dead_code)]`. Every field is an `Option`, so `Default` yields an all-`None` config.
#[allow(dead_code)]
#[derive(Debug, Deserialize, Default)]
pub struct ChatTemplate {
    add_bos_token: Option<bool>,
    add_eos_token: Option<bool>,
    added_tokens_decoder: Option<HashMap<String, AddedTokensDecoder>>,
    additional_special_tokens: Option<Vec<String>>,
    /// Beginning-of-sequence token; its text is read via `bos_tok`.
    pub bos_token: Option<BeginEndUnkPadTok>,

    /// Chat template source(s); `has_chat_template` reports whether it is present.
    pub chat_template: Option<ChatTemplateValue>,
    clean_up_tokenization_spaces: Option<bool>,
    device_map: Option<String>,
    /// End-of-sequence token; its text is read via `eos_tok`.
    pub eos_token: Option<BeginEndUnkPadTok>,
    legacy: Option<bool>,
    // NOTE(review): `f64` rather than an integer, presumably because some configs store very
    // large sentinel values here; confirm.
    model_max_length: Option<f64>,
    /// Padding token.
    pub pad_token: Option<BeginEndUnkPadTok>,
    sp_model_kwargs: Option<HashMap<String, String>>,
    spaces_between_special_tokens: Option<bool>,
    tokenizer_class: Option<String>,
    truncation_size: Option<String>,
    /// Unknown-token placeholder; its text is read via `unk_tok`.
    pub unk_token: Option<BeginEndUnkPadTok>,
    use_default_system_prompt: Option<bool>,
}
74
75impl ChatTemplate {
76 pub fn has_chat_template(&self) -> bool {
77 self.chat_template.is_some()
78 }
79
80 pub fn eos_tok(&self) -> Option<String> {
81 match self.eos_token.as_ref()?.0 {
82 Either::Left(ref lit) => Some(lit.clone()),
83 Either::Right(ref added) => Some(added.content.clone()),
84 }
85 }
86
87 pub fn bos_tok(&self) -> Option<String> {
88 match self.bos_token.as_ref()?.0 {
89 Either::Left(ref lit) => Some(lit.clone()),
90 Either::Right(ref added) => Some(added.content.clone()),
91 }
92 }
93
94 pub fn unk_tok(&self) -> Option<String> {
95 match self.unk_token.as_ref()?.0 {
96 Either::Left(ref lit) => Some(lit.clone()),
97 Either::Right(ref added) => Some(added.content.clone()),
98 }
99 }
100}
101
102pub fn calculate_eos_tokens(
103 chat_template: &ChatTemplate,
104 gen_conf: Option<GenerationConfig>,
105 tokenizer: &Tokenizer,
106) -> Vec<u32> {
107 let mut eos_tok_ids = chat_template.eos_tok().map(|x| vec![x]).unwrap_or_default();
108 let mut bos_tok_ids = chat_template.bos_tok().map(|b| vec![b]).unwrap_or_default();
109
110 for alternate in SUPPORTED_ALTERNATE_EOS {
111 if tokenizer.get_vocab(true).contains_key(*alternate) {
112 eos_tok_ids.push(alternate.to_string())
113 }
114 }
115
116 if let Some(gen_conf) = gen_conf {
117 if let Some(eos_field) = gen_conf.eos_token_id {
118 let ids = match eos_field {
119 Either::Left(id) => vec![id],
120 Either::Right(ids) => ids,
121 };
122 for id in ids {
123 let s = tokenizer
124 .decode(&[id], false)
125 .unwrap_or_else(|_| panic!("Unable to decode id {id})"));
126 if !eos_tok_ids.contains(&s) {
127 eos_tok_ids.push(s);
128 }
129 }
130 }
131
132 if let Some(bos_field) = gen_conf.bos_token_id {
133 let ids = match bos_field {
134 Either::Left(id) => vec![id],
135 Either::Right(ids) => ids,
136 };
137 for id in ids {
138 let s = tokenizer
139 .decode(&[id], false)
140 .unwrap_or_else(|_| panic!("Unable to decode id {id})"));
141 if !bos_tok_ids.contains(&s) {
142 bos_tok_ids.push(s);
143 }
144 }
145 }
146 }
147
148 eos_tok_ids = eos_tok_ids.into_iter().dedup().collect::<Vec<_>>();
149 bos_tok_ids = bos_tok_ids.into_iter().dedup().collect::<Vec<_>>();
150
151 let bos_render = bos_tok_ids
152 .iter()
153 .map(|val| format!("{val:?}"))
154 .collect::<Vec<String>>()
155 .join(", ");
156 let eos_render = eos_tok_ids
157 .iter()
158 .map(|val| format!("{val:?}"))
159 .collect::<Vec<String>>()
160 .join(", ");
161
162 info!(
163 "bos_toks = {bos_render}, eos_toks = {eos_render}, unk_tok = {}",
164 chat_template.unk_tok().unwrap_or("`None`".to_string()),
165 );
166
167 let mut eos_toks = Vec::new();
168 for eos_tok in eos_tok_ids {
169 eos_toks.push(
170 tokenizer
171 .get_vocab(true)
172 .get(&eos_tok)
173 .copied()
174 .unwrap_or_else(|| panic!("Unable to extract `{eos_tok}` EOS token.")),
175 )
176 }
177 eos_toks
178}
179
/// Subset of a generation config consumed by [`calculate_eos_tokens`].
///
/// Each token-id field may be absent (`#[serde(default)]` yields `None`). When present, it
/// holds either a single id or a list of ids, deserialized untagged.
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
pub struct GenerationConfig {
    #[serde(default)]
    #[serde(with = "either::serde_untagged_optional")]
    bos_token_id: Option<Either<u32, Vec<u32>>>,
    #[serde(default)]
    #[serde(with = "either::serde_untagged_optional")]
    eos_token_id: Option<Either<u32, Vec<u32>>>,
}
190
191fn tojson(value: Value, kwargs: Kwargs) -> Result<Value, Error> {
192 if let Ok(indent) = kwargs.get("indent") {
193 let mut buf = Vec::new();
194 let repeat = b" ".repeat(indent);
195 let formatter = serde_json::ser::PrettyFormatter::with_indent(&repeat);
196 let mut ser = serde_json::Serializer::with_formatter(&mut buf, formatter);
197 value.serialize(&mut ser).unwrap();
198 String::from_utf8(buf).map_err(|err| {
199 Error::new(ErrorKind::BadSerialization, "cannot serialize to JSON").with_source(err)
200 })
201 } else {
202 serde_json::to_string(&value).map_err(|err| {
203 Error::new(ErrorKind::BadSerialization, "cannot serialize to JSON").with_source(err)
204 })
205 }
206 .map_err(|err| {
207 Error::new(ErrorKind::InvalidOperation, "cannot serialize to JSON").with_source(err)
208 })
209 .map(|s| {
210 let mut rv = String::with_capacity(s.len());
212 for c in s.chars() {
213 match c {
214 '<' => rv.push_str("\\u003c"),
215 '>' => rv.push_str("\\u003e"),
216 '&' => rv.push_str("\\u0026"),
217 '\'' => rv.push_str("\\u0027"),
218 _ => rv.push(c),
219 }
220 }
221 Value::from_safe_string(rv)
222 })
223}
224
225fn strftime_now(fmt: String) -> Result<String, minijinja::Error> {
226 let date = chrono::Utc::now();
227 let date_string = date.format(&fmt).to_string();
228 Ok(date_string)
229}
230
231#[allow(clippy::too_many_arguments)]
232pub fn apply_chat_template_to(
233 messages: Vec<IndexMap<String, MessageContent>>,
234 add_generation_prompt: bool,
235 enable_thinking: Option<bool>,
236 template: &ChatTemplateValue,
237 bos_tok: Option<String>,
238 eos_tok: Option<String>,
239 unk_tok: Option<String>,
240 tools: Vec<Tool>,
241) -> Result<String> {
242 let mut env = Environment::new();
243
244 env.set_unknown_method_callback(minijinja_contrib::pycompat::unknown_method_callback);
246
247 env.set_lstrip_blocks(true);
249 env.set_trim_blocks(true);
250
251 #[derive(Serialize, Deserialize)]
252 struct UntaggedContent(#[serde(with = "either::serde_untagged")] MessageContent);
253 let mut new_messages = Vec::new();
254 for message in messages {
255 let mut new_message = IndexMap::new();
256 for (k, v) in message {
257 new_message.insert(k, UntaggedContent(v));
258 }
259 new_messages.push(new_message);
260 }
261
262 let template = match &template.0 {
263 Either::Left(x) => x.clone(),
264 Either::Right(map) => {
265 let mut template = None;
266 let has_tool_use = map.iter().any(|t| {
267 t.get("name").is_some_and(|name| name == "tool_use") || t.contains_key("tool_use")
268 });
269 let must_use_tool_template = !tools.is_empty();
270
271 if must_use_tool_template && !has_tool_use {
272 anyhow::bail!(
273 "Tools were provided but this chat template does not handle tool usage"
274 );
275 }
276
277 for t in map {
278 let name = t.get("name");
279 if let Some(name) = name {
280 template = Some(t["template"].clone());
281 #[allow(clippy::if_same_then_else)]
282 if name == "tool_use" && !tools.is_empty() {
283 break;
284 } else if name == "default" && !must_use_tool_template {
285 break;
286 }
287 } else if t.contains_key("tool_use") && !tools.is_empty() {
288 template = Some(t["tool_use"].clone());
289 break;
290 } else if t.contains_key("default") && !must_use_tool_template {
291 template = Some(t["default"].clone());
292 break;
293 }
294 }
295
296 let Some(template) = template else {
297 anyhow::bail!("Chat template does not contain a `tool_use` or `default` key. Please ensure it contains at least a `default` key, although `tool_use` should be specified for using tools.");
298 };
299 template
300 }
301 };
302 let mut template = template.replace("[::-1]", "|reverse");
303 let re = Regex::new(r"range\((?P<expr>[^,]+),\s*-1,\s*-1\)").unwrap();
307 template = re
308 .replace_all(&template, |caps: ®ex::Captures| {
309 format!("range({})|reverse", &caps["expr"])
310 })
311 .into_owned();
312
313 if template.contains("{{ meta }}") {
314 template = template.replace("{%- set meta = message.get(\"metadata\", \"\") %}", "");
316 template = template.replace("{{ meta }}", "");
317 }
318 if template.contains("{% generation %}") && template.contains("{% endgeneration %}") {
319 template = template.replace("{% generation %}", "");
321 template = template.replace("{% endgeneration %}", "");
322 }
323
324 env.add_template("chat_template", &template)?;
325 env.add_function("raise_exception", raise_exception);
326 env.add_filter("tojson", tojson);
327 env.add_function("strftime_now", strftime_now);
328 let tmpl = env.get_template("chat_template").unwrap();
329
330 let date = chrono::Utc::now();
331 let date_string = date.format("%d, %B, %Y").to_string();
332
333 if tools.is_empty() {
334 Ok(tmpl.render(context! {
335 messages => new_messages,
336 add_generation_prompt => add_generation_prompt,
337 bos_token => bos_tok,
338 eos_token => eos_tok,
339 unk_token => unk_tok,
340 date_string => date_string,
341 enable_thinking => enable_thinking.unwrap_or(true),
342 })?)
343 } else {
344 Ok(tmpl.render(context! {
345 messages => new_messages,
346 add_generation_prompt => add_generation_prompt,
347 bos_token => bos_tok,
348 eos_token => eos_tok,
349 unk_token => unk_tok,
350 xml_tools => tools.clone(), tools => tools,
352 date_string => date_string,
353 enable_thinking => enable_thinking.unwrap_or(true),
354 })?)
355 }
356}