1use std::collections::HashMap;
2
3use anyhow::Result;
4use either::Either;
5use indexmap::IndexMap;
6use itertools::Itertools;
7use minijinja::{context, value::Kwargs, Environment, Error, ErrorKind, Value};
8use regex::Regex;
9use serde::{Deserialize, Serialize};
10use tokenizers::Tokenizer;
11use tracing::info;
12
13use crate::{MessageContent, Tool};
14
/// Token strings that `calculate_eos_tokens` treats as additional
/// end-of-sequence markers whenever they are present in the tokenizer's
/// vocabulary, regardless of which EOS token the chat template declares.
const SUPPORTED_ALTERNATE_EOS: &[&str] = &[
    "<|im_end|>",
    "<end_of_turn>",
    "<|end_of_text|>",
];
20
/// Structured description of a single special token, used when a token in the
/// tokenizer configuration is given as an object rather than a plain string.
///
/// Only `content` is read by this module (see `ChatTemplate::eos_tok` and its
/// siblings); the remaining fields exist so deserialization accepts the full
/// object.
// NOTE(review): the flag fields presumably mirror the Hugging Face
// `AddedToken` attributes — confirm against the upstream schema.
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
pub struct AddedTokensDecoder {
    __type: Option<String>,
    /// Literal text of the token.
    pub content: String,
    lstrip: bool,
    normalized: bool,
    rstrip: bool,
    single_word: bool,
    special: Option<bool>,
}
32
33fn raise_exception(msg: String) -> Result<String, minijinja::Error> {
34 Err(minijinja::Error::new(ErrorKind::InvalidOperation, msg))
35}
36
/// A special-token entry (used for the BOS, EOS, UNK and PAD fields of
/// `ChatTemplate`).
///
/// Deserialized untagged: the value is either a plain string (`Left`) or a
/// structured `AddedTokensDecoder` object (`Right`).
#[derive(Debug, Deserialize)]
pub struct BeginEndUnkPadTok(
    #[serde(with = "either::serde_untagged")] pub Either<String, AddedTokensDecoder>,
);
41
/// The chat template itself, deserialized untagged.
///
/// * `Left` — a single template source string.
/// * `Right` — a list of string maps holding several templates. When
///   rendering, `apply_chat_template_to` looks in each map either for a
///   `name`/`template` pair or for a direct `tool_use` / `default` key.
#[derive(Debug, Deserialize)]
pub struct ChatTemplateValue(
    #[serde(with = "either::serde_untagged")] pub Either<String, Vec<HashMap<String, String>>>,
);
46
/// Deserialized tokenizer/chat-template configuration.
///
/// Every field is optional, so a partial configuration deserializes and
/// `ChatTemplate::default()` yields an all-`None` value. Only the public
/// fields are consumed by this module; the private ones are kept so the full
/// document is accepted.
// NOTE(review): the field set looks like a Hugging Face
// `tokenizer_config.json` — confirm against the loader that builds this type.
#[allow(dead_code)]
#[derive(Debug, Deserialize, Default)]
pub struct ChatTemplate {
    add_bos_token: Option<bool>,
    add_eos_token: Option<bool>,
    added_tokens_decoder: Option<HashMap<String, AddedTokensDecoder>>,
    additional_special_tokens: Option<Vec<String>>,
    /// Beginning-of-sequence token, as a plain string or structured entry.
    pub bos_token: Option<BeginEndUnkPadTok>,

    /// Template source (single string or a list of named templates).
    pub chat_template: Option<ChatTemplateValue>,
    clean_up_tokenization_spaces: Option<bool>,
    device_map: Option<String>,
    /// End-of-sequence token, as a plain string or structured entry.
    pub eos_token: Option<BeginEndUnkPadTok>,
    legacy: Option<bool>,
    model_max_length: Option<f64>,
    /// Padding token, as a plain string or structured entry.
    pub pad_token: Option<BeginEndUnkPadTok>,
    sp_model_kwargs: Option<HashMap<String, String>>,
    spaces_between_special_tokens: Option<bool>,
    tokenizer_class: Option<String>,
    truncation_size: Option<String>,
    /// Unknown-token placeholder, as a plain string or structured entry.
    pub unk_token: Option<BeginEndUnkPadTok>,
    use_default_system_prompt: Option<bool>,
}
74
75impl ChatTemplate {
76 pub fn has_chat_template(&self) -> bool {
77 self.chat_template.is_some()
78 }
79
80 pub fn is_harmony_format(&self) -> bool {
82 if let Some(ref template_value) = self.chat_template {
83 let template_str = match &template_value.0 {
84 Either::Left(s) => s.as_str(),
85 Either::Right(vec) => {
86 return vec
88 .iter()
89 .any(|t| t.values().any(|v| crate::harmony::is_harmony_template(v)));
90 }
91 };
92 crate::harmony::is_harmony_template(template_str)
93 } else {
94 false
95 }
96 }
97
98 pub fn eos_tok(&self) -> Option<String> {
99 match self.eos_token.as_ref()?.0 {
100 Either::Left(ref lit) => Some(lit.clone()),
101 Either::Right(ref added) => Some(added.content.clone()),
102 }
103 }
104
105 pub fn bos_tok(&self) -> Option<String> {
106 match self.bos_token.as_ref()?.0 {
107 Either::Left(ref lit) => Some(lit.clone()),
108 Either::Right(ref added) => Some(added.content.clone()),
109 }
110 }
111
112 pub fn unk_tok(&self) -> Option<String> {
113 match self.unk_token.as_ref()?.0 {
114 Either::Left(ref lit) => Some(lit.clone()),
115 Either::Right(ref added) => Some(added.content.clone()),
116 }
117 }
118}
119
120pub fn calculate_eos_tokens(
121 chat_template: &ChatTemplate,
122 gen_conf: Option<GenerationConfig>,
123 tokenizer: &Tokenizer,
124) -> Vec<u32> {
125 let mut eos_tok_ids = chat_template.eos_tok().map(|x| vec![x]).unwrap_or_default();
126 let mut bos_tok_ids = chat_template.bos_tok().map(|b| vec![b]).unwrap_or_default();
127
128 for alternate in SUPPORTED_ALTERNATE_EOS {
129 if tokenizer.get_vocab(true).contains_key(*alternate) {
130 eos_tok_ids.push(alternate.to_string())
131 }
132 }
133
134 if let Some(gen_conf) = gen_conf {
135 if let Some(eos_field) = gen_conf.eos_token_id {
136 let ids = match eos_field {
137 Either::Left(id) => vec![id],
138 Either::Right(ids) => ids,
139 };
140 for id in ids {
141 let s = tokenizer
142 .decode(&[id], false)
143 .unwrap_or_else(|_| panic!("Unable to decode id {id})"));
144 if !eos_tok_ids.contains(&s) {
145 eos_tok_ids.push(s);
146 }
147 }
148 }
149
150 if let Some(bos_field) = gen_conf.bos_token_id {
151 let ids = match bos_field {
152 Either::Left(id) => vec![id],
153 Either::Right(ids) => ids,
154 };
155 for id in ids {
156 let s = tokenizer
157 .decode(&[id], false)
158 .unwrap_or_else(|_| panic!("Unable to decode id {id})"));
159 if !bos_tok_ids.contains(&s) {
160 bos_tok_ids.push(s);
161 }
162 }
163 }
164 }
165
166 eos_tok_ids = eos_tok_ids.into_iter().dedup().collect::<Vec<_>>();
167 bos_tok_ids = bos_tok_ids.into_iter().dedup().collect::<Vec<_>>();
168
169 let bos_render = bos_tok_ids
170 .iter()
171 .map(|val| format!("{val:?}"))
172 .collect::<Vec<String>>()
173 .join(", ");
174 let eos_render = eos_tok_ids
175 .iter()
176 .map(|val| format!("{val:?}"))
177 .collect::<Vec<String>>()
178 .join(", ");
179
180 info!(
181 "bos_toks = {bos_render}, eos_toks = {eos_render}, unk_tok = {}",
182 chat_template.unk_tok().unwrap_or("`None`".to_string()),
183 );
184
185 let mut eos_toks = Vec::new();
186 for eos_tok in eos_tok_ids {
187 eos_toks.push(
188 tokenizer
189 .get_vocab(true)
190 .get(&eos_tok)
191 .copied()
192 .unwrap_or_else(|| panic!("Unable to extract `{eos_tok}` EOS token.")),
193 )
194 }
195 eos_toks
196}
197
/// Subset of a generation configuration read by `calculate_eos_tokens`.
///
/// Both fields accept either a single token id or a list of ids (untagged),
/// and default to `None` when the key is absent.
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
pub struct GenerationConfig {
    /// Beginning-of-sequence token id(s).
    #[serde(default)]
    #[serde(with = "either::serde_untagged_optional")]
    bos_token_id: Option<Either<u32, Vec<u32>>>,
    /// End-of-sequence token id(s).
    #[serde(default)]
    #[serde(with = "either::serde_untagged_optional")]
    eos_token_id: Option<Either<u32, Vec<u32>>>,
}
208
209fn tojson(value: Value, kwargs: Kwargs) -> Result<Value, Error> {
210 if let Ok(indent) = kwargs.get("indent") {
211 let mut buf = Vec::new();
212 let repeat = b" ".repeat(indent);
213 let formatter = serde_json::ser::PrettyFormatter::with_indent(&repeat);
214 let mut ser = serde_json::Serializer::with_formatter(&mut buf, formatter);
215 value.serialize(&mut ser).unwrap();
216 String::from_utf8(buf).map_err(|err| {
217 Error::new(ErrorKind::BadSerialization, "cannot serialize to JSON").with_source(err)
218 })
219 } else {
220 serde_json::to_string(&value).map_err(|err| {
221 Error::new(ErrorKind::BadSerialization, "cannot serialize to JSON").with_source(err)
222 })
223 }
224 .map_err(|err| {
225 Error::new(ErrorKind::InvalidOperation, "cannot serialize to JSON").with_source(err)
226 })
227 .map(|s| {
228 let mut rv = String::with_capacity(s.len());
230 for c in s.chars() {
231 match c {
232 '<' => rv.push_str("\\u003c"),
233 '>' => rv.push_str("\\u003e"),
234 '&' => rv.push_str("\\u0026"),
235 '\'' => rv.push_str("\\u0027"),
236 _ => rv.push(c),
237 }
238 }
239 Value::from_safe_string(rv)
240 })
241}
242
243fn strftime_now(fmt: String) -> Result<String, minijinja::Error> {
244 let date = chrono::Utc::now();
245 let date_string = date.format(&fmt).to_string();
246 Ok(date_string)
247}
248
249use crate::request::ReasoningEffort;
250
251#[allow(clippy::too_many_arguments)]
252pub fn apply_chat_template_to(
253 messages: Vec<IndexMap<String, MessageContent>>,
254 add_generation_prompt: bool,
255 enable_thinking: Option<bool>,
256 reasoning_effort: Option<ReasoningEffort>,
257 template: &ChatTemplateValue,
258 bos_tok: Option<String>,
259 eos_tok: Option<String>,
260 unk_tok: Option<String>,
261 tools: Vec<Tool>,
262) -> Result<String> {
263 let mut env = Environment::new();
264
265 env.set_unknown_method_callback(minijinja_contrib::pycompat::unknown_method_callback);
267
268 env.set_lstrip_blocks(true);
270 env.set_trim_blocks(true);
271
272 #[derive(Serialize, Deserialize)]
273 struct UntaggedContent(#[serde(with = "either::serde_untagged")] MessageContent);
274 let mut new_messages = Vec::new();
275 for message in messages {
276 let mut new_message = IndexMap::new();
277 for (k, v) in message {
278 new_message.insert(k, UntaggedContent(v));
279 }
280 new_messages.push(new_message);
281 }
282
283 let template = match &template.0 {
284 Either::Left(x) => x.clone(),
285 Either::Right(map) => {
286 let mut template = None;
287 let has_tool_use = map.iter().any(|t| {
288 t.get("name").is_some_and(|name| name == "tool_use") || t.contains_key("tool_use")
289 });
290 let must_use_tool_template = !tools.is_empty();
291
292 if must_use_tool_template && !has_tool_use {
293 anyhow::bail!(
294 "Tools were provided but this chat template does not handle tool usage"
295 );
296 }
297
298 for t in map {
299 let name = t.get("name");
300 if let Some(name) = name {
301 template = Some(t["template"].clone());
302 #[allow(clippy::if_same_then_else)]
303 if name == "tool_use" && !tools.is_empty() {
304 break;
305 } else if name == "default" && !must_use_tool_template {
306 break;
307 }
308 } else if t.contains_key("tool_use") && !tools.is_empty() {
309 template = Some(t["tool_use"].clone());
310 break;
311 } else if t.contains_key("default") && !must_use_tool_template {
312 template = Some(t["default"].clone());
313 break;
314 }
315 }
316
317 let Some(template) = template else {
318 anyhow::bail!("Chat template does not contain a `tool_use` or `default` key. Please ensure it contains at least a `default` key, although `tool_use` should be specified for using tools.");
319 };
320 template
321 }
322 };
323 let mut template = template.replace("[::-1]", "|reverse");
324 let re = Regex::new(r"range\((?P<expr>[^,]+),\s*-1,\s*-1\)").unwrap();
328 template = re
329 .replace_all(&template, |caps: ®ex::Captures| {
330 format!("range({})|reverse", &caps["expr"])
331 })
332 .into_owned();
333
334 if template.contains("{{ meta }}") {
335 template = template.replace("{%- set meta = message.get(\"metadata\", \"\") %}", "");
337 template = template.replace("{{ meta }}", "");
338 }
339 if template.contains("{% generation %}") && template.contains("{% endgeneration %}") {
340 template = template.replace("{% generation %}", "");
342 template = template.replace("{% endgeneration %}", "");
343 }
344
345 env.add_template("chat_template", &template)?;
346 env.add_function("raise_exception", raise_exception);
347 env.add_filter("tojson", tojson);
348 env.add_function("strftime_now", strftime_now);
349 let tmpl = env.get_template("chat_template").unwrap();
350
351 let date = chrono::Utc::now();
352 let date_string = date.format("%d, %B, %Y").to_string();
353
354 let reasoning_effort_str = reasoning_effort.map(|r| r.as_str()).unwrap_or("medium");
356
357 let builtin_tool_names = [
361 "browser",
362 "python",
363 "code_interpreter",
364 "web_search",
365 "brave_search",
366 "wolfram_alpha",
367 ];
368 let builtin_tools: Vec<&str> = tools
369 .iter()
370 .filter_map(|t| {
371 let name = t.function.name.as_str();
372 if builtin_tool_names.contains(&name) {
373 Some(name)
374 } else {
375 None
376 }
377 })
378 .collect();
379
380 if tools.is_empty() {
381 Ok(tmpl.render(context! {
382 messages => new_messages,
383 add_generation_prompt => add_generation_prompt,
384 bos_token => bos_tok,
385 eos_token => eos_tok,
386 unk_token => unk_tok,
387 date_string => date_string,
388 enable_thinking => enable_thinking.unwrap_or(true),
389 reasoning_effort => reasoning_effort_str,
390 })?)
391 } else {
392 Ok(tmpl.render(context! {
393 messages => new_messages,
394 add_generation_prompt => add_generation_prompt,
395 bos_token => bos_tok,
396 eos_token => eos_tok,
397 unk_token => unk_tok,
398 xml_tools => tools.clone(), tools => tools,
400 builtin_tools => builtin_tools,
401 date_string => date_string,
402 enable_thinking => enable_thinking.unwrap_or(true),
403 reasoning_effort => reasoning_effort_str,
404 })?)
405 }
406}