mistralrs_core/pipeline/
chat_template.rs

1use std::collections::HashMap;
2
3use anyhow::Result;
4use either::Either;
5use indexmap::IndexMap;
6use itertools::Itertools;
7use minijinja::{context, value::Kwargs, Environment, Error, ErrorKind, Value};
8use regex::Regex;
9use serde::{Deserialize, Serialize};
10use tokenizers::Tokenizer;
11use tracing::info;
12
13use crate::{MessageContent, Tool};
14
/// Token strings that [`calculate_eos_tokens`] treats as additional
/// end-of-sequence markers whenever they are present in the tokenizer's
/// vocabulary, on top of the EOS token declared by the chat template itself.
const SUPPORTED_ALTERNATE_EOS: &[&str] = &[
    "<|im_end|>",      // Handle ChatML case
    "<end_of_turn>",   // Handle Gemma2 chat case
    "<|end_of_text|>", // Hermes
];
20
/// Structured description of a single token.
///
/// Used as the value type of the `added_tokens_decoder` map in
/// [`ChatTemplate`], and accepted by [`BeginEndUnkPadTok`] wherever a special
/// token is given as an object instead of a plain string. Only `content` (the
/// token text) is read by this module; the remaining fields are deserialized
/// but otherwise unused here.
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
pub struct AddedTokensDecoder {
    __type: Option<String>,
    /// The literal text of the token.
    pub content: String,
    lstrip: bool,
    normalized: bool,
    rstrip: bool,
    single_word: bool,
    special: Option<bool>,
}
32
33fn raise_exception(msg: String) -> Result<String, minijinja::Error> {
34    Err(minijinja::Error::new(ErrorKind::InvalidOperation, msg))
35}
36
/// A special token (bos/eos/unk/pad) as it appears in [`ChatTemplate`].
///
/// Deserialized untagged: the value is either a plain string holding the
/// token text (`Either::Left`) or a full [`AddedTokensDecoder`] object
/// (`Either::Right`), whose `content` field holds the token text.
#[derive(Debug, Deserialize)]
pub struct BeginEndUnkPadTok(
    #[serde(with = "either::serde_untagged")] pub Either<String, AddedTokensDecoder>,
);
41
/// The `chat_template` value of a [`ChatTemplate`].
///
/// Deserialized untagged: either a single Jinja template string
/// (`Either::Left`) or a list of string-to-string maps (`Either::Right`).
/// [`apply_chat_template_to`] understands two shapes for those maps: entries
/// with `name`/`template` keys, or entries keyed directly by `tool_use` /
/// `default`.
#[derive(Debug, Deserialize)]
pub struct ChatTemplateValue(
    #[serde(with = "either::serde_untagged")] pub Either<String, Vec<HashMap<String, String>>>,
);
46
#[allow(dead_code)]
#[derive(Debug, Deserialize, Default)]
/// Template for chat models including bos/eos/unk as well as the chat template.
///
/// Every field is optional, so a missing key deserializes to `None` and
/// `ChatTemplate::default()` is an entirely empty configuration. Only the
/// `pub` fields are exposed to other modules; the private ones are
/// deserialized but not read in this file.
pub struct ChatTemplate {
    add_bos_token: Option<bool>,
    add_eos_token: Option<bool>,
    added_tokens_decoder: Option<HashMap<String, AddedTokensDecoder>>,
    additional_special_tokens: Option<Vec<String>>,
    /// Beginning-of-sequence token; see [`ChatTemplate::bos_tok`] for its text.
    pub bos_token: Option<BeginEndUnkPadTok>,

    /// Jinja format [chat templating] for chat completion.
    ///
    /// [chat templating]: https://huggingface.co/docs/transformers/chat_templating
    pub chat_template: Option<ChatTemplateValue>,
    clean_up_tokenization_spaces: Option<bool>,
    device_map: Option<String>,
    /// End-of-sequence token; see [`ChatTemplate::eos_tok`] for its text.
    pub eos_token: Option<BeginEndUnkPadTok>,
    legacy: Option<bool>,
    model_max_length: Option<f64>,
    /// Padding token, as a plain string or a structured token entry.
    pub pad_token: Option<BeginEndUnkPadTok>,
    sp_model_kwargs: Option<HashMap<String, String>>,
    spaces_between_special_tokens: Option<bool>,
    tokenizer_class: Option<String>,
    truncation_size: Option<String>,
    /// Unknown token; see [`ChatTemplate::unk_tok`] for its text.
    pub unk_token: Option<BeginEndUnkPadTok>,
    use_default_system_prompt: Option<bool>,
}
74
75impl ChatTemplate {
76    pub fn has_chat_template(&self) -> bool {
77        self.chat_template.is_some()
78    }
79
80    pub fn eos_tok(&self) -> Option<String> {
81        match self.eos_token.as_ref()?.0 {
82            Either::Left(ref lit) => Some(lit.clone()),
83            Either::Right(ref added) => Some(added.content.clone()),
84        }
85    }
86
87    pub fn bos_tok(&self) -> Option<String> {
88        match self.bos_token.as_ref()?.0 {
89            Either::Left(ref lit) => Some(lit.clone()),
90            Either::Right(ref added) => Some(added.content.clone()),
91        }
92    }
93
94    pub fn unk_tok(&self) -> Option<String> {
95        match self.unk_token.as_ref()?.0 {
96            Either::Left(ref lit) => Some(lit.clone()),
97            Either::Right(ref added) => Some(added.content.clone()),
98        }
99    }
100}
101
102pub fn calculate_eos_tokens(
103    chat_template: &ChatTemplate,
104    gen_conf: Option<GenerationConfig>,
105    tokenizer: &Tokenizer,
106) -> Vec<u32> {
107    let mut eos_tok_ids = chat_template.eos_tok().map(|x| vec![x]).unwrap_or_default();
108    let mut bos_tok_ids = chat_template.bos_tok().map(|b| vec![b]).unwrap_or_default();
109
110    for alternate in SUPPORTED_ALTERNATE_EOS {
111        if tokenizer.get_vocab(true).contains_key(*alternate) {
112            eos_tok_ids.push(alternate.to_string())
113        }
114    }
115
116    if let Some(gen_conf) = gen_conf {
117        if let Some(eos_field) = gen_conf.eos_token_id {
118            let ids = match eos_field {
119                Either::Left(id) => vec![id],
120                Either::Right(ids) => ids,
121            };
122            for id in ids {
123                let s = tokenizer
124                    .decode(&[id], false)
125                    .unwrap_or_else(|_| panic!("Unable to decode id {id})"));
126                if !eos_tok_ids.contains(&s) {
127                    eos_tok_ids.push(s);
128                }
129            }
130        }
131
132        if let Some(bos_field) = gen_conf.bos_token_id {
133            let ids = match bos_field {
134                Either::Left(id) => vec![id],
135                Either::Right(ids) => ids,
136            };
137            for id in ids {
138                let s = tokenizer
139                    .decode(&[id], false)
140                    .unwrap_or_else(|_| panic!("Unable to decode id {id})"));
141                if !bos_tok_ids.contains(&s) {
142                    bos_tok_ids.push(s);
143                }
144            }
145        }
146    }
147
148    eos_tok_ids = eos_tok_ids.into_iter().dedup().collect::<Vec<_>>();
149    bos_tok_ids = bos_tok_ids.into_iter().dedup().collect::<Vec<_>>();
150
151    let bos_render = bos_tok_ids
152        .iter()
153        .map(|val| format!("{val:?}"))
154        .collect::<Vec<String>>()
155        .join(", ");
156    let eos_render = eos_tok_ids
157        .iter()
158        .map(|val| format!("{val:?}"))
159        .collect::<Vec<String>>()
160        .join(", ");
161
162    info!(
163        "bos_toks = {bos_render}, eos_toks = {eos_render}, unk_tok = {}",
164        chat_template.unk_tok().unwrap_or("`None`".to_string()),
165    );
166
167    let mut eos_toks = Vec::new();
168    for eos_tok in eos_tok_ids {
169        eos_toks.push(
170            tokenizer
171                .get_vocab(true)
172                .get(&eos_tok)
173                .copied()
174                .unwrap_or_else(|| panic!("Unable to extract `{eos_tok}` EOS token.")),
175        )
176    }
177    eos_toks
178}
179
/// Token ids supplied alongside the chat template and consumed by
/// [`calculate_eos_tokens`].
///
/// Each field accepts either a single id or a list of ids (untagged), and
/// falls back to `None` when the key is absent (`#[serde(default)]`).
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
pub struct GenerationConfig {
    /// Beginning-of-sequence token id(s).
    #[serde(default)]
    #[serde(with = "either::serde_untagged_optional")]
    bos_token_id: Option<Either<u32, Vec<u32>>>,
    /// End-of-sequence token id(s).
    #[serde(default)]
    #[serde(with = "either::serde_untagged_optional")]
    eos_token_id: Option<Either<u32, Vec<u32>>>,
}
190
191fn tojson(value: Value, kwargs: Kwargs) -> Result<Value, Error> {
192    if let Ok(indent) = kwargs.get("indent") {
193        let mut buf = Vec::new();
194        let repeat = b" ".repeat(indent);
195        let formatter = serde_json::ser::PrettyFormatter::with_indent(&repeat);
196        let mut ser = serde_json::Serializer::with_formatter(&mut buf, formatter);
197        value.serialize(&mut ser).unwrap();
198        String::from_utf8(buf).map_err(|err| {
199            Error::new(ErrorKind::BadSerialization, "cannot serialize to JSON").with_source(err)
200        })
201    } else {
202        serde_json::to_string(&value).map_err(|err| {
203            Error::new(ErrorKind::BadSerialization, "cannot serialize to JSON").with_source(err)
204        })
205    }
206    .map_err(|err| {
207        Error::new(ErrorKind::InvalidOperation, "cannot serialize to JSON").with_source(err)
208    })
209    .map(|s| {
210        // When this filter is used the return value is safe for both HTML and JSON
211        let mut rv = String::with_capacity(s.len());
212        for c in s.chars() {
213            match c {
214                '<' => rv.push_str("\\u003c"),
215                '>' => rv.push_str("\\u003e"),
216                '&' => rv.push_str("\\u0026"),
217                '\'' => rv.push_str("\\u0027"),
218                _ => rv.push(c),
219            }
220        }
221        Value::from_safe_string(rv)
222    })
223}
224
225fn strftime_now(fmt: String) -> Result<String, minijinja::Error> {
226    let date = chrono::Utc::now();
227    let date_string = date.format(&fmt).to_string();
228    Ok(date_string)
229}
230
231#[allow(clippy::too_many_arguments)]
232pub fn apply_chat_template_to(
233    messages: Vec<IndexMap<String, MessageContent>>,
234    add_generation_prompt: bool,
235    enable_thinking: Option<bool>,
236    template: &ChatTemplateValue,
237    bos_tok: Option<String>,
238    eos_tok: Option<String>,
239    unk_tok: Option<String>,
240    tools: Vec<Tool>,
241) -> Result<String> {
242    let mut env = Environment::new();
243
244    // enable python methods such as .strip()
245    env.set_unknown_method_callback(minijinja_contrib::pycompat::unknown_method_callback);
246
247    // https://github.com/huggingface/transformers/blob/76a33a10923ccc1074917f6b6a1e719e626b7dc9/src/transformers/tokenization_utils_base.py#L1842
248    env.set_lstrip_blocks(true);
249    env.set_trim_blocks(true);
250
251    #[derive(Serialize, Deserialize)]
252    struct UntaggedContent(#[serde(with = "either::serde_untagged")] MessageContent);
253    let mut new_messages = Vec::new();
254    for message in messages {
255        let mut new_message = IndexMap::new();
256        for (k, v) in message {
257            new_message.insert(k, UntaggedContent(v));
258        }
259        new_messages.push(new_message);
260    }
261
262    let template = match &template.0 {
263        Either::Left(x) => x.clone(),
264        Either::Right(map) => {
265            let mut template = None;
266            let has_tool_use = map.iter().any(|t| {
267                t.get("name").is_some_and(|name| name == "tool_use") || t.contains_key("tool_use")
268            });
269            let must_use_tool_template = !tools.is_empty();
270
271            if must_use_tool_template && !has_tool_use {
272                anyhow::bail!(
273                    "Tools were provided but this chat template does not handle tool usage"
274                );
275            }
276
277            for t in map {
278                let name = t.get("name");
279                if let Some(name) = name {
280                    template = Some(t["template"].clone());
281                    #[allow(clippy::if_same_then_else)]
282                    if name == "tool_use" && !tools.is_empty() {
283                        break;
284                    } else if name == "default" && !must_use_tool_template {
285                        break;
286                    }
287                } else if t.contains_key("tool_use") && !tools.is_empty() {
288                    template = Some(t["tool_use"].clone());
289                    break;
290                } else if t.contains_key("default") && !must_use_tool_template {
291                    template = Some(t["default"].clone());
292                    break;
293                }
294            }
295
296            let Some(template) = template else {
297                anyhow::bail!("Chat template does not contain a `tool_use` or `default` key. Please ensure it contains at least a `default` key, although `tool_use` should be specified for using tools.");
298            };
299            template
300        }
301    };
302    let mut template = template.replace("[::-1]", "|reverse");
303    // Convert Python‑style descending ranges `range(..., -1, -1)` to a forward
304    // range followed by Jinja’s `|reverse` filter so it works even when
305    // negative‑step ranges aren’t supported.
306    let re = Regex::new(r"range\((?P<expr>[^,]+),\s*-1,\s*-1\)").unwrap();
307    template = re
308        .replace_all(&template, |caps: &regex::Captures| {
309            format!("range({})|reverse", &caps["expr"])
310        })
311        .into_owned();
312
313    if template.contains("{{ meta }}") {
314        // Fix for GLM4 models
315        template = template.replace("{%- set meta = message.get(\"metadata\", \"\") %}", "");
316        template = template.replace("{{ meta }}", "");
317    }
318    if template.contains("{% generation %}") && template.contains("{% endgeneration %}") {
319        // Strip for smollm3 models
320        template = template.replace("{% generation %}", "");
321        template = template.replace("{% endgeneration %}", "");
322    }
323
324    env.add_template("chat_template", &template)?;
325    env.add_function("raise_exception", raise_exception);
326    env.add_filter("tojson", tojson);
327    env.add_function("strftime_now", strftime_now);
328    let tmpl = env.get_template("chat_template").unwrap();
329
330    let date = chrono::Utc::now();
331    let date_string = date.format("%d, %B, %Y").to_string();
332
333    if tools.is_empty() {
334        Ok(tmpl.render(context! {
335            messages => new_messages,
336            add_generation_prompt => add_generation_prompt,
337            bos_token => bos_tok,
338            eos_token => eos_tok,
339            unk_token => unk_tok,
340            date_string => date_string,
341            enable_thinking => enable_thinking.unwrap_or(true),
342        })?)
343    } else {
344        Ok(tmpl.render(context! {
345            messages => new_messages,
346            add_generation_prompt => add_generation_prompt,
347            bos_token => bos_tok,
348            eos_token => eos_tok,
349            unk_token => unk_tok,
350            xml_tools => tools.clone(), // SmolLM3
351            tools => tools,
352            date_string => date_string,
353            enable_thinking => enable_thinking.unwrap_or(true),
354        })?)
355    }
356}