mistralrs_core/pipeline/chat_template.rs

1use std::collections::HashMap;
2
3use anyhow::Result;
4use either::Either;
5use indexmap::IndexMap;
6use itertools::Itertools;
7use minijinja::{context, value::Kwargs, Environment, Error, ErrorKind, Value};
8use regex::Regex;
9use serde::{Deserialize, Serialize};
10use tokenizers::Tokenizer;
11use tracing::info;
12
13use crate::{MessageContent, Tool};
14
/// End-of-turn markers that are additionally treated as EOS tokens whenever the
/// tokenizer vocabulary contains them (see `calculate_eos_tokens`).
const SUPPORTED_ALTERNATE_EOS: &[&str] = &[
    // ChatML
    "<|im_end|>",
    // Gemma 2 chat
    "<end_of_turn>",
    // Hermes
    "<|end_of_text|>",
];
20
21#[allow(dead_code)]
22#[derive(Debug, Deserialize)]
23pub struct AddedTokensDecoder {
24    __type: Option<String>,
25    pub content: String,
26    lstrip: bool,
27    normalized: bool,
28    rstrip: bool,
29    single_word: bool,
30    special: Option<bool>,
31}
32
33fn raise_exception(msg: String) -> Result<String, minijinja::Error> {
34    Err(minijinja::Error::new(ErrorKind::InvalidOperation, msg))
35}
36
37#[derive(Debug, Deserialize)]
38pub struct BeginEndUnkPadTok(
39    #[serde(with = "either::serde_untagged")] pub Either<String, AddedTokensDecoder>,
40);
41
42#[derive(Debug, Deserialize)]
43pub struct ChatTemplateValue(
44    #[serde(with = "either::serde_untagged")] pub Either<String, Vec<HashMap<String, String>>>,
45);
46
47#[allow(dead_code)]
48#[derive(Debug, Deserialize, Default)]
49/// Template for chat models including bos/eos/unk as well as the chat template.
50pub struct ChatTemplate {
51    add_bos_token: Option<bool>,
52    add_eos_token: Option<bool>,
53    added_tokens_decoder: Option<HashMap<String, AddedTokensDecoder>>,
54    additional_special_tokens: Option<Vec<String>>,
55    pub bos_token: Option<BeginEndUnkPadTok>,
56
57    /// Jinja format [chat templating] for chat completion.
58    ///
59    /// [chat templating]: https://huggingface.co/docs/transformers/chat_templating
60    pub chat_template: Option<ChatTemplateValue>,
61    clean_up_tokenization_spaces: Option<bool>,
62    device_map: Option<String>,
63    pub eos_token: Option<BeginEndUnkPadTok>,
64    legacy: Option<bool>,
65    model_max_length: Option<f64>,
66    pub pad_token: Option<BeginEndUnkPadTok>,
67    sp_model_kwargs: Option<HashMap<String, String>>,
68    spaces_between_special_tokens: Option<bool>,
69    tokenizer_class: Option<String>,
70    truncation_size: Option<String>,
71    pub unk_token: Option<BeginEndUnkPadTok>,
72    use_default_system_prompt: Option<bool>,
73}
74
75impl ChatTemplate {
76    pub fn has_chat_template(&self) -> bool {
77        self.chat_template.is_some()
78    }
79
80    pub fn eos_tok(&self) -> Option<String> {
81        match self.eos_token.as_ref()?.0 {
82            Either::Left(ref lit) => Some(lit.clone()),
83            Either::Right(ref added) => Some(added.content.clone()),
84        }
85    }
86
87    pub fn bos_tok(&self) -> Option<String> {
88        match self.bos_token.as_ref()?.0 {
89            Either::Left(ref lit) => Some(lit.clone()),
90            Either::Right(ref added) => Some(added.content.clone()),
91        }
92    }
93
94    pub fn unk_tok(&self) -> Option<String> {
95        match self.unk_token.as_ref()?.0 {
96            Either::Left(ref lit) => Some(lit.clone()),
97            Either::Right(ref added) => Some(added.content.clone()),
98        }
99    }
100}
101
102pub fn calculate_eos_tokens(
103    chat_template: &ChatTemplate,
104    gen_conf: Option<GenerationConfig>,
105    tokenizer: &Tokenizer,
106) -> Vec<u32> {
107    let mut eos_tok_ids = chat_template.eos_tok().map(|x| vec![x]).unwrap_or_default();
108    let mut bos_tok_ids = chat_template.bos_tok().map(|b| vec![b]).unwrap_or_default();
109
110    for alternate in SUPPORTED_ALTERNATE_EOS {
111        if tokenizer.get_vocab(true).contains_key(*alternate) {
112            eos_tok_ids.push(alternate.to_string())
113        }
114    }
115
116    if let Some(gen_conf) = gen_conf {
117        if let Some(eos_field) = gen_conf.eos_token_id {
118            let ids = match eos_field {
119                Either::Left(id) => vec![id],
120                Either::Right(ids) => ids,
121            };
122            for id in ids {
123                let s = tokenizer
124                    .decode(&[id], false)
125                    .unwrap_or_else(|_| panic!("Unable to decode id {id})"));
126                if !eos_tok_ids.contains(&s) {
127                    eos_tok_ids.push(s);
128                }
129            }
130        }
131
132        if let Some(bos_field) = gen_conf.bos_token_id {
133            let ids = match bos_field {
134                Either::Left(id) => vec![id],
135                Either::Right(ids) => ids,
136            };
137            for id in ids {
138                let s = tokenizer
139                    .decode(&[id], false)
140                    .unwrap_or_else(|_| panic!("Unable to decode id {id})"));
141                if !bos_tok_ids.contains(&s) {
142                    bos_tok_ids.push(s);
143                }
144            }
145        }
146    }
147
148    eos_tok_ids = eos_tok_ids.into_iter().dedup().collect::<Vec<_>>();
149    bos_tok_ids = bos_tok_ids.into_iter().dedup().collect::<Vec<_>>();
150
151    let bos_render = bos_tok_ids
152        .iter()
153        .map(|val| format!("{val:?}"))
154        .collect::<Vec<String>>()
155        .join(", ");
156    let eos_render = eos_tok_ids
157        .iter()
158        .map(|val| format!("{val:?}"))
159        .collect::<Vec<String>>()
160        .join(", ");
161
162    info!(
163        "bos_toks = {bos_render}, eos_toks = {eos_render}, unk_tok = {}",
164        chat_template.unk_tok().unwrap_or("`None`".to_string()),
165    );
166
167    let mut eos_toks = Vec::new();
168    for eos_tok in eos_tok_ids {
169        eos_toks.push(
170            tokenizer
171                .get_vocab(true)
172                .get(&eos_tok)
173                .copied()
174                .unwrap_or_else(|| panic!("Unable to extract `{eos_tok}` EOS token.")),
175        )
176    }
177    eos_toks
178}
179
180#[allow(dead_code)]
181#[derive(Debug, Deserialize)]
182pub struct GenerationConfig {
183    #[serde(with = "either::serde_untagged_optional")]
184    bos_token_id: Option<Either<u32, Vec<u32>>>,
185    #[serde(with = "either::serde_untagged_optional")]
186    eos_token_id: Option<Either<u32, Vec<u32>>>,
187}
188
189fn tojson(value: Value, kwargs: Kwargs) -> Result<Value, Error> {
190    if let Ok(indent) = kwargs.get("indent") {
191        let mut buf = Vec::new();
192        let repeat = b" ".repeat(indent);
193        let formatter = serde_json::ser::PrettyFormatter::with_indent(&repeat);
194        let mut ser = serde_json::Serializer::with_formatter(&mut buf, formatter);
195        value.serialize(&mut ser).unwrap();
196        String::from_utf8(buf).map_err(|err| {
197            Error::new(ErrorKind::BadSerialization, "cannot serialize to JSON").with_source(err)
198        })
199    } else {
200        serde_json::to_string(&value).map_err(|err| {
201            Error::new(ErrorKind::BadSerialization, "cannot serialize to JSON").with_source(err)
202        })
203    }
204    .map_err(|err| {
205        Error::new(ErrorKind::InvalidOperation, "cannot serialize to JSON").with_source(err)
206    })
207    .map(|s| {
208        // When this filter is used the return value is safe for both HTML and JSON
209        let mut rv = String::with_capacity(s.len());
210        for c in s.chars() {
211            match c {
212                '<' => rv.push_str("\\u003c"),
213                '>' => rv.push_str("\\u003e"),
214                '&' => rv.push_str("\\u0026"),
215                '\'' => rv.push_str("\\u0027"),
216                _ => rv.push(c),
217            }
218        }
219        Value::from_safe_string(rv)
220    })
221}
222
223fn strftime_now(fmt: String) -> Result<String, minijinja::Error> {
224    let date = chrono::Utc::now();
225    let date_string = date.format(&fmt).to_string();
226    Ok(date_string)
227}
228
229#[allow(clippy::too_many_arguments)]
230pub fn apply_chat_template_to(
231    messages: Vec<IndexMap<String, MessageContent>>,
232    add_generation_prompt: bool,
233    enable_thinking: Option<bool>,
234    template: &ChatTemplateValue,
235    bos_tok: Option<String>,
236    eos_tok: Option<String>,
237    unk_tok: Option<String>,
238    tools: Vec<Tool>,
239) -> Result<String> {
240    let mut env = Environment::new();
241
242    // enable python methods such as .strip()
243    env.set_unknown_method_callback(minijinja_contrib::pycompat::unknown_method_callback);
244
245    // https://github.com/huggingface/transformers/blob/76a33a10923ccc1074917f6b6a1e719e626b7dc9/src/transformers/tokenization_utils_base.py#L1842
246    env.set_lstrip_blocks(true);
247    env.set_trim_blocks(true);
248
249    #[derive(Serialize, Deserialize)]
250    struct UntaggedContent(#[serde(with = "either::serde_untagged")] MessageContent);
251    let mut new_messages = Vec::new();
252    for message in messages {
253        let mut new_message = IndexMap::new();
254        for (k, v) in message {
255            new_message.insert(k, UntaggedContent(v));
256        }
257        new_messages.push(new_message);
258    }
259
260    let template = match &template.0 {
261        Either::Left(x) => x.clone(),
262        Either::Right(map) => {
263            let mut template = None;
264            let has_tool_use = map.iter().any(|t| {
265                t.get("name").is_some_and(|name| name == "tool_use") || t.contains_key("tool_use")
266            });
267            let must_use_tool_template = !tools.is_empty();
268
269            if must_use_tool_template && !has_tool_use {
270                anyhow::bail!(
271                    "Tools were provided but this chat template does not handle tool usage"
272                );
273            }
274
275            for t in map {
276                let name = t.get("name");
277                if let Some(name) = name {
278                    template = Some(t["template"].clone());
279                    #[allow(clippy::if_same_then_else)]
280                    if name == "tool_use" && !tools.is_empty() {
281                        break;
282                    } else if name == "default" && !must_use_tool_template {
283                        break;
284                    }
285                } else if t.contains_key("tool_use") && !tools.is_empty() {
286                    template = Some(t["tool_use"].clone());
287                    break;
288                } else if t.contains_key("default") && !must_use_tool_template {
289                    template = Some(t["default"].clone());
290                    break;
291                }
292            }
293
294            let Some(template) = template else {
295                anyhow::bail!("Chat template does not contain a `tool_use` or `default` key. Please ensure it contains at least a `default` key, although `tool_use` should be specified for using tools.");
296            };
297            template
298        }
299    };
300    let mut template = template.replace("[::-1]", "|reverse");
301    // Convert Python‑style descending ranges `range(..., -1, -1)` to a forward
302    // range followed by Jinja’s `|reverse` filter so it works even when
303    // negative‑step ranges aren’t supported.
304    let re = Regex::new(r"range\((?P<expr>[^,]+),\s*-1,\s*-1\)").unwrap();
305    template = re
306        .replace_all(&template, |caps: &regex::Captures| {
307            format!("range({})|reverse", &caps["expr"])
308        })
309        .into_owned();
310
311    if template.contains("{{ meta }}") {
312        //fix for GLM4 models
313        template = template.replace("{%- set meta = message.get(\"metadata\", \"\") %}", "");
314        template = template.replace("{{ meta }}", "");
315    }
316
317    env.add_template("chat_template", &template)?;
318    env.add_function("raise_exception", raise_exception);
319    env.add_filter("tojson", tojson);
320    env.add_function("strftime_now", strftime_now);
321    let tmpl = env.get_template("chat_template").unwrap();
322
323    let date = chrono::Utc::now();
324    let date_string = date.format("%d, %B, %Y").to_string();
325
326    if tools.is_empty() {
327        Ok(tmpl.render(context! {
328            messages => new_messages,
329            add_generation_prompt => add_generation_prompt,
330            bos_token => bos_tok,
331            eos_token => eos_tok,
332            unk_token => unk_tok,
333            date_string => date_string,
334            enable_thinking => enable_thinking,
335        })?)
336    } else {
337        Ok(tmpl.render(context! {
338            messages => new_messages,
339            add_generation_prompt => add_generation_prompt,
340            bos_token => bos_tok,
341            eos_token => eos_tok,
342            unk_token => unk_tok,
343            tools => tools,
344            date_string => date_string,
345            enable_thinking => enable_thinking,
346        })?)
347    }
348}