mistralrs_core/pipeline/chat_template.rs

1use std::collections::HashMap;
2
3use anyhow::Result;
4use either::Either;
5use indexmap::IndexMap;
6use itertools::Itertools;
7use minijinja::{context, value::Kwargs, Environment, Error, ErrorKind, Value};
8use serde::{Deserialize, Serialize};
9use tokenizers::Tokenizer;
10use tracing::info;
11
12use crate::{MessageContent, Tool};
13
/// Extra end-of-turn markers that `calculate_eos_tokens` adds as EOS tokens whenever the
/// tokenizer vocabulary contains them, checked in this order.
const SUPPORTED_ALTERNATE_EOS: &[&str] = &[
    // ChatML
    "<|im_end|>",
    // Gemma 2 chat turns
    "<end_of_turn>",
    // Hermes
    "<|end_of_text|>",
];
19
20#[allow(dead_code)]
21#[derive(Debug, Deserialize)]
22pub struct AddedTokensDecoder {
23    __type: Option<String>,
24    pub content: String,
25    lstrip: bool,
26    normalized: bool,
27    rstrip: bool,
28    single_word: bool,
29    special: Option<bool>,
30}
31
32fn raise_exception(msg: String) -> Result<String, minijinja::Error> {
33    Err(minijinja::Error::new(ErrorKind::InvalidOperation, msg))
34}
35
36#[derive(Debug, Deserialize)]
37pub struct BeginEndUnkPadTok(
38    #[serde(with = "either::serde_untagged")] pub Either<String, AddedTokensDecoder>,
39);
40
41#[derive(Debug, Deserialize)]
42pub struct ChatTemplateValue(
43    #[serde(with = "either::serde_untagged")] pub Either<String, Vec<HashMap<String, String>>>,
44);
45
46#[allow(dead_code)]
47#[derive(Debug, Deserialize, Default)]
48/// Template for chat models including bos/eos/unk as well as the chat template.
49pub struct ChatTemplate {
50    add_bos_token: Option<bool>,
51    add_eos_token: Option<bool>,
52    added_tokens_decoder: Option<HashMap<String, AddedTokensDecoder>>,
53    additional_special_tokens: Option<Vec<String>>,
54    pub bos_token: Option<BeginEndUnkPadTok>,
55
56    /// Jinja format [chat templating] for chat completion.
57    ///
58    /// [chat templating]: https://huggingface.co/docs/transformers/chat_templating
59    pub chat_template: Option<ChatTemplateValue>,
60    clean_up_tokenization_spaces: Option<bool>,
61    device_map: Option<String>,
62    pub eos_token: Option<BeginEndUnkPadTok>,
63    legacy: Option<bool>,
64    model_max_length: Option<f64>,
65    pub pad_token: Option<BeginEndUnkPadTok>,
66    sp_model_kwargs: Option<HashMap<String, String>>,
67    spaces_between_special_tokens: Option<bool>,
68    tokenizer_class: Option<String>,
69    truncation_size: Option<String>,
70    pub unk_token: Option<BeginEndUnkPadTok>,
71    use_default_system_prompt: Option<bool>,
72}
73
74impl ChatTemplate {
75    pub fn has_chat_template(&self) -> bool {
76        self.chat_template.is_some()
77    }
78
79    pub fn eos_tok(&self) -> Option<String> {
80        match self.eos_token.as_ref()?.0 {
81            Either::Left(ref lit) => Some(lit.clone()),
82            Either::Right(ref added) => Some(added.content.clone()),
83        }
84    }
85
86    pub fn bos_tok(&self) -> Option<String> {
87        match self.bos_token.as_ref()?.0 {
88            Either::Left(ref lit) => Some(lit.clone()),
89            Either::Right(ref added) => Some(added.content.clone()),
90        }
91    }
92
93    pub fn unk_tok(&self) -> Option<String> {
94        match self.unk_token.as_ref()?.0 {
95            Either::Left(ref lit) => Some(lit.clone()),
96            Either::Right(ref added) => Some(added.content.clone()),
97        }
98    }
99}
100
101pub fn calculate_eos_tokens(
102    chat_template: &ChatTemplate,
103    gen_conf: Option<GenerationConfig>,
104    tokenizer: &Tokenizer,
105) -> Vec<u32> {
106    let mut eos_tok_ids = chat_template.eos_tok().map(|x| vec![x]).unwrap_or_default();
107    let mut bos_tok_ids = chat_template.bos_tok().map(|b| vec![b]).unwrap_or_default();
108
109    for alternate in SUPPORTED_ALTERNATE_EOS {
110        if tokenizer.get_vocab(true).contains_key(*alternate) {
111            eos_tok_ids.push(alternate.to_string())
112        }
113    }
114
115    if let Some(gen_conf) = gen_conf {
116        let ids = match gen_conf.eos_token_id {
117            Either::Left(id) => vec![id],
118            Either::Right(ids) => ids,
119        };
120        for id in ids {
121            let s = tokenizer
122                .decode(&[id], false)
123                .unwrap_or_else(|_| panic!("Unable to decode id {id})"));
124            if !eos_tok_ids.contains(&s) {
125                eos_tok_ids.push(s);
126            }
127        }
128
129        let ids = match gen_conf.bos_token_id {
130            Either::Left(id) => vec![id],
131            Either::Right(ids) => ids,
132        };
133        for id in ids {
134            let s = tokenizer
135                .decode(&[id], false)
136                .unwrap_or_else(|_| panic!("Unable to decode id {id})"));
137            if !bos_tok_ids.contains(&s) {
138                bos_tok_ids.push(s);
139            }
140        }
141    }
142
143    eos_tok_ids = eos_tok_ids.into_iter().dedup().collect::<Vec<_>>();
144    bos_tok_ids = bos_tok_ids.into_iter().dedup().collect::<Vec<_>>();
145
146    let bos_render = bos_tok_ids
147        .iter()
148        .map(|val| format!("{:?}", val))
149        .collect::<Vec<String>>()
150        .join(", ");
151    let eos_render = eos_tok_ids
152        .iter()
153        .map(|val| format!("{:?}", val))
154        .collect::<Vec<String>>()
155        .join(", ");
156
157    info!(
158        "bos_toks = {bos_render}, eos_toks = {eos_render}, unk_tok = {}",
159        chat_template.unk_tok().unwrap_or("`None`".to_string()),
160    );
161
162    let mut eos_toks = Vec::new();
163    for eos_tok in eos_tok_ids {
164        eos_toks.push(
165            tokenizer
166                .get_vocab(true)
167                .get(&eos_tok)
168                .copied()
169                .unwrap_or_else(|| panic!("Unable to extract `{eos_tok}` EOS token.")),
170        )
171    }
172    eos_toks
173}
174
175#[allow(dead_code)]
176#[derive(Debug, Deserialize)]
177pub struct GenerationConfig {
178    #[serde(with = "either::serde_untagged")]
179    bos_token_id: Either<u32, Vec<u32>>,
180    #[serde(with = "either::serde_untagged")]
181    eos_token_id: Either<u32, Vec<u32>>,
182}
183
184fn tojson(value: Value, kwargs: Kwargs) -> Result<Value, Error> {
185    if let Ok(indent) = kwargs.get("indent") {
186        let mut buf = Vec::new();
187        let repeat = b" ".repeat(indent);
188        let formatter = serde_json::ser::PrettyFormatter::with_indent(&repeat);
189        let mut ser = serde_json::Serializer::with_formatter(&mut buf, formatter);
190        value.serialize(&mut ser).unwrap();
191        String::from_utf8(buf).map_err(|err| {
192            Error::new(ErrorKind::BadSerialization, "cannot serialize to JSON").with_source(err)
193        })
194    } else {
195        serde_json::to_string(&value).map_err(|err| {
196            Error::new(ErrorKind::BadSerialization, "cannot serialize to JSON").with_source(err)
197        })
198    }
199    .map_err(|err| {
200        Error::new(ErrorKind::InvalidOperation, "cannot serialize to JSON").with_source(err)
201    })
202    .map(|s| {
203        // When this filter is used the return value is safe for both HTML and JSON
204        let mut rv = String::with_capacity(s.len());
205        for c in s.chars() {
206            match c {
207                '<' => rv.push_str("\\u003c"),
208                '>' => rv.push_str("\\u003e"),
209                '&' => rv.push_str("\\u0026"),
210                '\'' => rv.push_str("\\u0027"),
211                _ => rv.push(c),
212            }
213        }
214        Value::from_safe_string(rv)
215    })
216}
217
218fn strftime_now(fmt: String) -> Result<String, minijinja::Error> {
219    let date = chrono::Utc::now();
220    let date_string = date.format(&fmt).to_string();
221    Ok(date_string)
222}
223
224pub fn apply_chat_template_to(
225    messages: Vec<IndexMap<String, MessageContent>>,
226    add_generation_prompt: bool,
227    template: &ChatTemplateValue,
228    bos_tok: Option<String>,
229    eos_tok: Option<String>,
230    unk_tok: Option<String>,
231    tools: Vec<Tool>,
232) -> Result<String> {
233    let mut env = Environment::new();
234
235    // enable python methods such as .strip()
236    env.set_unknown_method_callback(minijinja_contrib::pycompat::unknown_method_callback);
237
238    // https://github.com/huggingface/transformers/blob/76a33a10923ccc1074917f6b6a1e719e626b7dc9/src/transformers/tokenization_utils_base.py#L1842
239    env.set_lstrip_blocks(true);
240    env.set_trim_blocks(true);
241
242    #[derive(Serialize, Deserialize)]
243    struct UntaggedContent(#[serde(with = "either::serde_untagged")] MessageContent);
244    let mut new_messages = Vec::new();
245    for message in messages {
246        let mut new_message = IndexMap::new();
247        for (k, v) in message {
248            new_message.insert(k, UntaggedContent(v));
249        }
250        new_messages.push(new_message);
251    }
252
253    let template = match &template.0 {
254        Either::Left(x) => x.clone(),
255        Either::Right(map) => {
256            let mut template = None;
257            let has_tool_use = map.iter().any(|t| {
258                t.get("name").is_some_and(|name| name == "tool_use") || t.contains_key("tool_use")
259            });
260            let must_use_tool_template = !tools.is_empty();
261
262            if must_use_tool_template && !has_tool_use {
263                anyhow::bail!(
264                    "Tools were provided but this chat template does not handle tool usage"
265                );
266            }
267
268            for t in map {
269                let name = t.get("name");
270                if let Some(name) = name {
271                    template = Some(t["template"].clone());
272                    #[allow(clippy::if_same_then_else)]
273                    if name == "tool_use" && !tools.is_empty() {
274                        break;
275                    } else if name == "default" && !must_use_tool_template {
276                        break;
277                    }
278                } else if t.contains_key("tool_use") && !tools.is_empty() {
279                    template = Some(t["tool_use"].clone());
280                    break;
281                } else if t.contains_key("default") && !must_use_tool_template {
282                    template = Some(t["default"].clone());
283                    break;
284                }
285            }
286
287            let Some(template) = template else {
288                anyhow::bail!("Chat template does not contain a `tool_use` or `default` key. Please ensure it contains at least a `default` key, although `tool_use` should be specified for using tools.");
289            };
290            template
291        }
292    };
293
294    env.add_template("chat_template", &template)?;
295    env.add_function("raise_exception", raise_exception);
296    env.add_filter("tojson", tojson);
297    env.add_function("strftime_now", strftime_now);
298    let tmpl = env.get_template("chat_template").unwrap();
299
300    let date = chrono::Utc::now();
301    let date_string = date.format("%d, %B, %Y").to_string();
302
303    if tools.is_empty() {
304        Ok(tmpl.render(context! {
305            messages => new_messages,
306            add_generation_prompt => add_generation_prompt,
307            bos_token => bos_tok,
308            eos_token => eos_tok,
309            unk_token => unk_tok,
310            date_string => date_string,
311        })?)
312    } else {
313        Ok(tmpl.render(context! {
314            messages => new_messages,
315            add_generation_prompt => add_generation_prompt,
316            bos_token => bos_tok,
317            eos_token => eos_tok,
318            unk_token => unk_tok,
319            tools => tools,
320            date_string => date_string,
321        })?)
322    }
323}