mistralrs_core/pipeline/
chat_template.rs

1use std::collections::HashMap;
2
3use anyhow::Result;
4use either::Either;
5use indexmap::IndexMap;
6use itertools::Itertools;
7use minijinja::{context, value::Kwargs, Environment, Error, ErrorKind, Value};
8use regex::Regex;
9use serde::{Deserialize, Serialize};
10use tokenizers::Tokenizer;
11use tracing::info;
12
13use crate::{MessageContent, Tool};
14
/// Extra end-of-sequence token strings recognised in addition to the chat
/// template's own EOS token.
///
/// `calculate_eos_tokens` adds each entry that is present in the tokenizer's
/// vocabulary to the list of EOS tokens.
const SUPPORTED_ALTERNATE_EOS: &[&str] = &[
    "<|im_end|>",      // Handle ChatML case
    "<end_of_turn>",   // Handle Gemma2 chat case
    "<|end_of_text|>", // Hermes
];
20
/// Structured description of a single token, as stored in the values of
/// `ChatTemplate::added_tokens_decoder` and as the structured alternative to a
/// plain string in [`BeginEndUnkPadTok`].
///
/// Only `content` is public; the other fields are deserialized but not read
/// anywhere in this file (hence `allow(dead_code)`).
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
pub struct AddedTokensDecoder {
    // Optional type tag of the serialized entry.
    __type: Option<String>,
    /// The literal token text.
    pub content: String,
    // The four flags below are required keys: deserialization fails if any of
    // them is missing. NOTE(review): they look like the tokenizer's
    // added-token matching options - confirm against the config format.
    lstrip: bool,
    normalized: bool,
    rstrip: bool,
    single_word: bool,
    // Optional flag; presumably marks the token as a special token - confirm.
    special: Option<bool>,
}
32
33fn raise_exception(msg: String) -> Result<String, minijinja::Error> {
34    Err(minijinja::Error::new(ErrorKind::InvalidOperation, msg))
35}
36
/// A special token (BOS/EOS/UNK/PAD) as found in the chat template config.
///
/// Deserialized untagged: the value is either a plain string holding the token
/// text (`Left`) or a structured [`AddedTokensDecoder`] whose `content` field
/// holds the token text (`Right`).
#[derive(Debug, Deserialize)]
pub struct BeginEndUnkPadTok(
    #[serde(with = "either::serde_untagged")] pub Either<String, AddedTokensDecoder>,
);
41
/// The chat template itself, in one of two shapes (deserialized untagged).
///
/// * `Left`: a single Jinja template string.
/// * `Right`: a list of string maps describing several templates. As consumed
///   by `apply_chat_template_to`, an entry either has `name` and `template`
///   keys, or maps a `tool_use` / `default` key directly to the template text.
#[derive(Debug, Deserialize)]
pub struct ChatTemplateValue(
    #[serde(with = "either::serde_untagged")] pub Either<String, Vec<HashMap<String, String>>>,
);
46
/// Template for chat models including bos/eos/unk as well as the chat template.
///
/// Every field is optional, so a key that is absent from the input simply
/// deserializes to `None`. Only the public fields are read in this file; the
/// private ones are kept so the full document can be deserialized (hence
/// `allow(dead_code)`).
// NOTE(review): the field set looks like a Hugging Face `tokenizer_config.json`
// - confirm against the loader that deserializes this struct.
#[allow(dead_code)]
#[derive(Debug, Deserialize, Default)]
pub struct ChatTemplate {
    add_bos_token: Option<bool>,
    add_eos_token: Option<bool>,
    added_tokens_decoder: Option<HashMap<String, AddedTokensDecoder>>,
    additional_special_tokens: Option<Vec<String>>,
    /// Beginning-of-sequence token, if configured.
    pub bos_token: Option<BeginEndUnkPadTok>,

    /// Jinja format [chat templating] for chat completion.
    ///
    /// [chat templating]: https://huggingface.co/docs/transformers/chat_templating
    pub chat_template: Option<ChatTemplateValue>,
    clean_up_tokenization_spaces: Option<bool>,
    device_map: Option<String>,
    /// End-of-sequence token, if configured.
    pub eos_token: Option<BeginEndUnkPadTok>,
    legacy: Option<bool>,
    model_max_length: Option<f64>,
    /// Padding token, if configured.
    pub pad_token: Option<BeginEndUnkPadTok>,
    sp_model_kwargs: Option<HashMap<String, String>>,
    spaces_between_special_tokens: Option<bool>,
    tokenizer_class: Option<String>,
    truncation_size: Option<String>,
    /// Unknown token, if configured.
    pub unk_token: Option<BeginEndUnkPadTok>,
    use_default_system_prompt: Option<bool>,
}
74
75impl ChatTemplate {
76    pub fn has_chat_template(&self) -> bool {
77        self.chat_template.is_some()
78    }
79
80    /// Check if this chat template uses OpenAI Harmony format.
81    pub fn is_harmony_format(&self) -> bool {
82        if let Some(ref template_value) = self.chat_template {
83            let template_str = match &template_value.0 {
84                Either::Left(s) => s.as_str(),
85                Either::Right(vec) => {
86                    // For multi-template format, check if any template contains Harmony markers
87                    return vec
88                        .iter()
89                        .any(|t| t.values().any(|v| crate::harmony::is_harmony_template(v)));
90                }
91            };
92            crate::harmony::is_harmony_template(template_str)
93        } else {
94            false
95        }
96    }
97
98    pub fn eos_tok(&self) -> Option<String> {
99        match self.eos_token.as_ref()?.0 {
100            Either::Left(ref lit) => Some(lit.clone()),
101            Either::Right(ref added) => Some(added.content.clone()),
102        }
103    }
104
105    pub fn bos_tok(&self) -> Option<String> {
106        match self.bos_token.as_ref()?.0 {
107            Either::Left(ref lit) => Some(lit.clone()),
108            Either::Right(ref added) => Some(added.content.clone()),
109        }
110    }
111
112    pub fn unk_tok(&self) -> Option<String> {
113        match self.unk_token.as_ref()?.0 {
114            Either::Left(ref lit) => Some(lit.clone()),
115            Either::Right(ref added) => Some(added.content.clone()),
116        }
117    }
118}
119
120pub fn calculate_eos_tokens(
121    chat_template: &ChatTemplate,
122    gen_conf: Option<GenerationConfig>,
123    tokenizer: &Tokenizer,
124) -> Vec<u32> {
125    let mut eos_tok_ids = chat_template.eos_tok().map(|x| vec![x]).unwrap_or_default();
126    let mut bos_tok_ids = chat_template.bos_tok().map(|b| vec![b]).unwrap_or_default();
127
128    for alternate in SUPPORTED_ALTERNATE_EOS {
129        if tokenizer.get_vocab(true).contains_key(*alternate) {
130            eos_tok_ids.push(alternate.to_string())
131        }
132    }
133
134    if let Some(gen_conf) = gen_conf {
135        if let Some(eos_field) = gen_conf.eos_token_id {
136            let ids = match eos_field {
137                Either::Left(id) => vec![id],
138                Either::Right(ids) => ids,
139            };
140            for id in ids {
141                let s = tokenizer
142                    .decode(&[id], false)
143                    .unwrap_or_else(|_| panic!("Unable to decode id {id})"));
144                if !eos_tok_ids.contains(&s) {
145                    eos_tok_ids.push(s);
146                }
147            }
148        }
149
150        if let Some(bos_field) = gen_conf.bos_token_id {
151            let ids = match bos_field {
152                Either::Left(id) => vec![id],
153                Either::Right(ids) => ids,
154            };
155            for id in ids {
156                let s = tokenizer
157                    .decode(&[id], false)
158                    .unwrap_or_else(|_| panic!("Unable to decode id {id})"));
159                if !bos_tok_ids.contains(&s) {
160                    bos_tok_ids.push(s);
161                }
162            }
163        }
164    }
165
166    eos_tok_ids = eos_tok_ids.into_iter().dedup().collect::<Vec<_>>();
167    bos_tok_ids = bos_tok_ids.into_iter().dedup().collect::<Vec<_>>();
168
169    let bos_render = bos_tok_ids
170        .iter()
171        .map(|val| format!("{val:?}"))
172        .collect::<Vec<String>>()
173        .join(", ");
174    let eos_render = eos_tok_ids
175        .iter()
176        .map(|val| format!("{val:?}"))
177        .collect::<Vec<String>>()
178        .join(", ");
179
180    info!(
181        "bos_toks = {bos_render}, eos_toks = {eos_render}, unk_tok = {}",
182        chat_template.unk_tok().unwrap_or("`None`".to_string()),
183    );
184
185    let mut eos_toks = Vec::new();
186    for eos_tok in eos_tok_ids {
187        eos_toks.push(
188            tokenizer
189                .get_vocab(true)
190                .get(&eos_tok)
191                .copied()
192                .unwrap_or_else(|| panic!("Unable to extract `{eos_tok}` EOS token.")),
193        )
194    }
195    eos_toks
196}
197
/// BOS/EOS token ids taken from a generation config.
///
/// Each field may be absent (`None`), a single id (`Left`) or a list of ids
/// (`Right`); the untagged deserializer accepts either shape.
// NOTE(review): presumably deserialized from a Hugging Face
// `generation_config.json` - confirm against the loader.
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
pub struct GenerationConfig {
    /// Beginning-of-sequence token id(s).
    #[serde(default)]
    #[serde(with = "either::serde_untagged_optional")]
    bos_token_id: Option<Either<u32, Vec<u32>>>,
    /// End-of-sequence token id(s).
    #[serde(default)]
    #[serde(with = "either::serde_untagged_optional")]
    eos_token_id: Option<Either<u32, Vec<u32>>>,
}
208
209fn tojson(value: Value, kwargs: Kwargs) -> Result<Value, Error> {
210    if let Ok(indent) = kwargs.get("indent") {
211        let mut buf = Vec::new();
212        let repeat = b" ".repeat(indent);
213        let formatter = serde_json::ser::PrettyFormatter::with_indent(&repeat);
214        let mut ser = serde_json::Serializer::with_formatter(&mut buf, formatter);
215        value.serialize(&mut ser).unwrap();
216        String::from_utf8(buf).map_err(|err| {
217            Error::new(ErrorKind::BadSerialization, "cannot serialize to JSON").with_source(err)
218        })
219    } else {
220        serde_json::to_string(&value).map_err(|err| {
221            Error::new(ErrorKind::BadSerialization, "cannot serialize to JSON").with_source(err)
222        })
223    }
224    .map_err(|err| {
225        Error::new(ErrorKind::InvalidOperation, "cannot serialize to JSON").with_source(err)
226    })
227    .map(|s| {
228        // When this filter is used the return value is safe for both HTML and JSON
229        let mut rv = String::with_capacity(s.len());
230        for c in s.chars() {
231            match c {
232                '<' => rv.push_str("\\u003c"),
233                '>' => rv.push_str("\\u003e"),
234                '&' => rv.push_str("\\u0026"),
235                '\'' => rv.push_str("\\u0027"),
236                _ => rv.push(c),
237            }
238        }
239        Value::from_safe_string(rv)
240    })
241}
242
243fn strftime_now(fmt: String) -> Result<String, minijinja::Error> {
244    let date = chrono::Utc::now();
245    let date_string = date.format(&fmt).to_string();
246    Ok(date_string)
247}
248
249use crate::request::ReasoningEffort;
250
251#[allow(clippy::too_many_arguments)]
252pub fn apply_chat_template_to(
253    messages: Vec<IndexMap<String, MessageContent>>,
254    add_generation_prompt: bool,
255    enable_thinking: Option<bool>,
256    reasoning_effort: Option<ReasoningEffort>,
257    template: &ChatTemplateValue,
258    bos_tok: Option<String>,
259    eos_tok: Option<String>,
260    unk_tok: Option<String>,
261    tools: Vec<Tool>,
262) -> Result<String> {
263    let mut env = Environment::new();
264
265    // enable python methods such as .strip()
266    env.set_unknown_method_callback(minijinja_contrib::pycompat::unknown_method_callback);
267
268    // https://github.com/huggingface/transformers/blob/76a33a10923ccc1074917f6b6a1e719e626b7dc9/src/transformers/tokenization_utils_base.py#L1842
269    env.set_lstrip_blocks(true);
270    env.set_trim_blocks(true);
271
272    #[derive(Serialize, Deserialize)]
273    struct UntaggedContent(#[serde(with = "either::serde_untagged")] MessageContent);
274    let mut new_messages = Vec::new();
275    for message in messages {
276        let mut new_message = IndexMap::new();
277        for (k, v) in message {
278            new_message.insert(k, UntaggedContent(v));
279        }
280        new_messages.push(new_message);
281    }
282
283    let template = match &template.0 {
284        Either::Left(x) => x.clone(),
285        Either::Right(map) => {
286            let mut template = None;
287            let has_tool_use = map.iter().any(|t| {
288                t.get("name").is_some_and(|name| name == "tool_use") || t.contains_key("tool_use")
289            });
290            let must_use_tool_template = !tools.is_empty();
291
292            if must_use_tool_template && !has_tool_use {
293                anyhow::bail!(
294                    "Tools were provided but this chat template does not handle tool usage"
295                );
296            }
297
298            for t in map {
299                let name = t.get("name");
300                if let Some(name) = name {
301                    template = Some(t["template"].clone());
302                    #[allow(clippy::if_same_then_else)]
303                    if name == "tool_use" && !tools.is_empty() {
304                        break;
305                    } else if name == "default" && !must_use_tool_template {
306                        break;
307                    }
308                } else if t.contains_key("tool_use") && !tools.is_empty() {
309                    template = Some(t["tool_use"].clone());
310                    break;
311                } else if t.contains_key("default") && !must_use_tool_template {
312                    template = Some(t["default"].clone());
313                    break;
314                }
315            }
316
317            let Some(template) = template else {
318                anyhow::bail!("Chat template does not contain a `tool_use` or `default` key. Please ensure it contains at least a `default` key, although `tool_use` should be specified for using tools.");
319            };
320            template
321        }
322    };
323    let mut template = template.replace("[::-1]", "|reverse");
324    // Convert Python‑style descending ranges `range(..., -1, -1)` to a forward
325    // range followed by Jinja’s `|reverse` filter so it works even when
326    // negative‑step ranges aren’t supported.
327    let re = Regex::new(r"range\((?P<expr>[^,]+),\s*-1,\s*-1\)").unwrap();
328    template = re
329        .replace_all(&template, |caps: &regex::Captures| {
330            format!("range({})|reverse", &caps["expr"])
331        })
332        .into_owned();
333
334    if template.contains("{{ meta }}") {
335        // Fix for GLM4 models
336        template = template.replace("{%- set meta = message.get(\"metadata\", \"\") %}", "");
337        template = template.replace("{{ meta }}", "");
338    }
339    if template.contains("{% generation %}") && template.contains("{% endgeneration %}") {
340        // Strip for smollm3 models
341        template = template.replace("{% generation %}", "");
342        template = template.replace("{% endgeneration %}", "");
343    }
344
345    env.add_template("chat_template", &template)?;
346    env.add_function("raise_exception", raise_exception);
347    env.add_filter("tojson", tojson);
348    env.add_function("strftime_now", strftime_now);
349    let tmpl = env.get_template("chat_template").unwrap();
350
351    let date = chrono::Utc::now();
352    let date_string = date.format("%d, %B, %Y").to_string();
353
354    // Convert reasoning effort to string for template
355    let reasoning_effort_str = reasoning_effort.map(|r| r.as_str()).unwrap_or("medium");
356
357    // Detect builtin tools from the tools list
358    // Known builtin tools for GPT-OSS/Harmony format: "browser", "python"
359    // Known builtin tools for Llama 3.x: "wolfram_alpha", "web_search", "brave_search", "python", "code_interpreter"
360    let builtin_tool_names = [
361        "browser",
362        "python",
363        "code_interpreter",
364        "web_search",
365        "brave_search",
366        "wolfram_alpha",
367    ];
368    let builtin_tools: Vec<&str> = tools
369        .iter()
370        .filter_map(|t| {
371            let name = t.function.name.as_str();
372            if builtin_tool_names.contains(&name) {
373                Some(name)
374            } else {
375                None
376            }
377        })
378        .collect();
379
380    if tools.is_empty() {
381        Ok(tmpl.render(context! {
382            messages => new_messages,
383            add_generation_prompt => add_generation_prompt,
384            bos_token => bos_tok,
385            eos_token => eos_tok,
386            unk_token => unk_tok,
387            date_string => date_string,
388            enable_thinking => enable_thinking.unwrap_or(true),
389            reasoning_effort => reasoning_effort_str,
390        })?)
391    } else {
392        Ok(tmpl.render(context! {
393            messages => new_messages,
394            add_generation_prompt => add_generation_prompt,
395            bos_token => bos_tok,
396            eos_token => eos_tok,
397            unk_token => unk_tok,
398            xml_tools => tools.clone(), // SmolLM3
399            tools => tools,
400            builtin_tools => builtin_tools,
401            date_string => date_string,
402            enable_thinking => enable_thinking.unwrap_or(true),
403            reasoning_effort => reasoning_effort_str,
404        })?)
405    }
406}