mistralrs_core/gguf/gguf_tokenizer.rs

// https://github.com/huggingface/transformers/blob/8685b3c5d2dd2550527773d2a02499495a759e31/src/transformers/convert_slow_tokenizer.py

use std::{collections::HashMap, sync::atomic::Ordering};

use crate::utils::gguf_metadata::ContentMetadata;
use crate::DEBUG;
use anyhow::Result;
use candle_core::quantized::gguf_file::Value;
use itertools::Itertools;
use tokenizers::pre_tokenizers::{
    sequence::Sequence,
    split::{Split, SplitPattern},
    PreTokenizerWrapper,
};
use tokenizers::tokenizer::normalizer::SplitDelimiterBehavior;
use tokenizers::{
    decoders::{
        self, byte_fallback::ByteFallback, byte_level::ByteLevel, fuse::Fuse, strip::Strip,
    },
    models::{bpe::BpeBuilder, unigram::Unigram},
    normalizers::{self, Prepend, Replace},
    processors, AddedToken, DecoderWrapper, ModelWrapper, NormalizerWrapper, Tokenizer,
};
use tracing::info;

use super::Content;

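/// Result of converting GGUF tokenizer metadata into an HF `Tokenizer`,
/// along with the BOS/EOS/UNK token strings when present in the metadata.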
pub(crate) struct GgufTokenizerConversion {
    pub tokenizer: Tokenizer,
    pub bos: Option<String>,
    pub eos: Option<String>,
    pub unk: Option<String>,
}

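/// Tokenizer properties read from the `tokenizer.ggml.*` GGUF metadata keys.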
struct PropsGGUF {
    model: String,
    tokens: Vec<String>,
    added_tokens: Option<Vec<String>>,
    scores: Option<Vec<f32>>,
    merges: Option<Vec<String>>,
    unk: Option<u32>,
    eos: u32,
    bos: u32,
}

impl TryFrom<ContentMetadata<'_>> for PropsGGUF {
    type Error = anyhow::Error;

    fn try_from(c: ContentMetadata) -> Result<Self, Self::Error> {
        let required = ["model", "tokens", "eos_token_id", "bos_token_id"];
        c.has_required_keys(&required)?;

        let props = Self {
            model: c.get_value("model")?,
            tokens: c.get_value("tokens")?,
            added_tokens: c.get_value("added_tokens").ok(),
            scores: c.get_value("scores").ok(),
            merges: c.get_value("merges").ok(),
            unk: c.get_value("unknown_token_id").ok(),
            eos: c.get_value("eos_token_id")?,
            bos: c.get_value("bos_token_id")?,
        };

        Ok(props)
    }
}

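/// Convert the `tokenizer.ggml.*` metadata of a GGUF file into an HF [`Tokenizer`].
///
/// Illustrative usage (a sketch; assumes `content` is an already-loaded GGUF [`Content`]):
///
/// ```ignore
/// let conversion = convert_gguf_to_hf_tokenizer(&content)?;
/// let encoding = conversion
///     .tokenizer
///     .encode_fast("Hello, world!", false)
///     .map_err(anyhow::Error::msg)?;
/// println!("token ids: {:?}", encoding.get_ids());
/// ```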
pub fn convert_gguf_to_hf_tokenizer<R: std::io::Seek + std::io::Read>(
    content: &Content<'_, R>,
) -> Result<GgufTokenizerConversion> {
    let metadata = ContentMetadata {
        path_prefix: "tokenizer.ggml",
        metadata: content.get_metadata(),
    };

    let md_get = |s: &str| match metadata.metadata.get(s) {
        None => candle_core::bail!("cannot find {s} in metadata"),
        Some(v) => Ok(v),
    };

    let mut token_types = Vec::<i32>::new();
    if metadata.metadata.contains_key("tokenizer.ggml.token_type") {
        let vtypes: &Vec<Value> = md_get("tokenizer.ggml.token_type")
            .unwrap()
            .to_vec()
            .unwrap();
        let v: Vec<i32> = vtypes.iter().map(|v| v.to_i32().unwrap()).collect();
        token_types.extend(v);
    }

    let props = PropsGGUF::try_from(metadata)?;

    let (mut tokenizer, kind) = match props.model.as_str() {
        "llama" | "replit" => unigram_tokenizer(&props)?,
        "gpt2" => bpe_tokenizer(&props)?,
        other => {
            anyhow::bail!("Tokenizer model `{other}` not supported.");
        }
    };

    // Token types other than 1 (normal) are treated as special tokens.
    let mut num_special_tokens = 0;
    if token_types.len() == props.tokens.len() {
        for (token, token_type) in props.tokens.iter().zip(token_types.iter()) {
            if *token_type != 1i32 {
                tokenizer.add_special_tokens(&[AddedToken::from(token.clone(), true)]);
                num_special_tokens += 1;
            }
        }
    }

    info!(
        "GGUF tokenizer model is `{model}`, kind: `{kind:?}`, num tokens: {}, num special tokens: {}, num added tokens: {}, num merges: {}, num scores: {}",
        tokenizer.get_vocab_size(true),
        num_special_tokens,
        props.added_tokens.as_ref().map(|x| x.len()).unwrap_or(0),
        props.merges.as_ref().map(|x| x.len()).unwrap_or(0),
        props.scores.as_ref().map(|x| x.len()).unwrap_or(0),
        model = props.model,
    );
    if DEBUG.load(Ordering::Relaxed) {
        info!("Tokenizer: {tokenizer:?}");
    }

    let unk = props.unk.map(|u| props.tokens[u as usize].clone());

    Ok(GgufTokenizerConversion {
        tokenizer,
        bos: Some(props.tokens[props.bos as usize].clone()),
        eos: Some(props.tokens[props.eos as usize].clone()),
        unk,
    })
}

// TODO: Add support for additional tokenizer models: WordPiece, WordLevel
// https://docs.rs/tokenizers/latest/tokenizers/models/enum.ModelWrapper.html
#[derive(Debug)]
enum TokenizerKind {
    Unigram,
    Bpe,
}

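/// Build a SentencePiece-style Unigram tokenizer from GGUF metadata
/// (token strings plus per-token scores).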
fn unigram_tokenizer(p: &PropsGGUF) -> Result<(Tokenizer, TokenizerKind)> {
    let PropsGGUF { unk, eos, bos, .. } = *p;
    // Unigram (SentencePiece) default UNK is 0
    let unk = unk.unwrap_or(0);

    // Create the Tokenizer model:
    let model = {
        let vocab: Vec<(String, f64)> = {
            let Some(s) = p.scores.as_ref() else {
                anyhow::bail!(
                    "`llama` unigram tokenizer is missing required metadata `tokenizer.ggml.scores`"
                );
            };
            let scores = s.iter().cloned().map(|f_32| f_32 as f64);

            p.tokens.iter().cloned().zip(scores).collect()
        };

        Unigram::from(vocab, Some(unk as usize), true).map_err(anyhow::Error::msg)?
    };

    // Decoder + Normalizer config reference:
    // https://github.com/EricLBuehler/mistral.rs/pull/389#discussion_r1630620763
    let decoder = Decoder::Sequence(vec![
        Decoder::Replace("▁", " "),
        Decoder::ByteFallback,
        Decoder::Fuse,
        Decoder::Strip(' ', 1, 0),
    ]);

    let normalizer = Normalizer::Sequence(vec![
        Normalizer::Prepend("▁"),
        Normalizer::Replace(" ", "▁"),
    ]);

    let mut tokenizer: Tokenizer = TokenizerX::new(
        ModelWrapper::Unigram(model),
        Some(decoder),
        Some(normalizer),
    )?;

    // Add special tokens (bos, eos, unk):
    for i in [bos, eos, unk] {
        let tk = p.tokens[i as usize].clone();
        tokenizer.add_special_tokens(&[AddedToken::from(tk, true)]);
    }
    Ok((tokenizer, TokenizerKind::Unigram))
}

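/// Build a byte-level BPE (GPT-2 style) tokenizer from GGUF metadata
/// (vocab tokens plus space-delimited merge pairs).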
fn bpe_tokenizer(p: &PropsGGUF) -> Result<(Tokenizer, TokenizerKind)> {
    // BPE merges have each string item as a space-delimited pair:
    // https://github.com/EricLBuehler/mistral.rs/pull/397#discussion_r1631988370
    let merges = p
        .merges
        .as_ref()
        .ok_or(anyhow::Error::msg("BPE tokenizer must include merges"))?
        .iter()
        .map(|merge| {
            let split: (&str, &str) = merge
                .splitn(2, ' ')
                .collect_tuple()
                .expect("Failed to convert split into 2-tuple");
            (split.0.to_string(), split.1.to_string())
        })
        .collect::<Vec<_>>();

    let mut vocab = HashMap::new();
    for (i, token) in p.tokens.iter().enumerate() {
        #[allow(clippy::cast_possible_truncation)]
        vocab.insert(token.clone(), i as u32);
    }

    let PropsGGUF { eos, bos, unk, .. } = *p;

    let mut bpe = BpeBuilder::new().vocab_and_merges(vocab, merges);
    if let Some(unk) = unk {
        bpe = bpe.unk_token(p.tokens[unk as usize].to_string());
    }

    let bpe = bpe.build().map_err(anyhow::Error::msg)?;

    let mut tokenizer = TokenizerX::new(
        ModelWrapper::BPE(bpe),
        Some(Decoder::ByteLevel(true, true, true)),
        None,
    )?;

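    // Pre-tokenization split pattern (GPT-2 style: contractions, letter runs,
    // digits, punctuation, and whitespace handling):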
    let split = Split::new(
        SplitPattern::Regex("(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+".to_string()),
        SplitDelimiterBehavior::Isolated,
        false,
    ).unwrap();

    // example:
    // "type": "ByteLevel",
    // "add_prefix_space": false,
    // "trim_offsets": false,
    // "use_regex": false
    let pre_tokenizer = Sequence::new(vec![
        PreTokenizerWrapper::Split(split),
        PreTokenizerWrapper::ByteLevel(ByteLevel::new(false, false, false)),
    ]);

    tokenizer.with_pre_tokenizer(Some(pre_tokenizer));

    tokenizer.with_decoder(Some(decoders::byte_level::ByteLevel::new(
        false, false, false,
    )));
    tokenizer.with_post_processor(Some(processors::byte_level::ByteLevel::new(
        false, false, false,
    )));

    for i in [bos, eos] {
        let tk = p.tokens[i as usize].clone();
        tokenizer.add_special_tokens(&[AddedToken::from(tk, true)]);
    }
    if let Some(unk) = unk {
        let tk = p.tokens[unk as usize].clone();
        tokenizer.add_special_tokens(&[AddedToken::from(tk, true)]);
    }

    Ok((tokenizer, TokenizerKind::Bpe))
}

// This is a workaround to have a better builder API.
// Upstream `TokenizerBuilder` is difficult to work with:
// https://github.com/huggingface/tokenizers/issues/1549
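// Example (illustrative sketch; `model` is assumed to be an already-built Unigram):
//   let tokenizer = TokenizerX::new(ModelWrapper::Unigram(model), Some(Decoder::Fuse), None)?;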
struct TokenizerX;

impl TokenizerX {
    #[allow(clippy::new_ret_no_self)]
    fn new<'a>(
        model: ModelWrapper,
        decoder: Option<Decoder<'a>>,
        normalizer: Option<Normalizer<'a>>,
    ) -> Result<Tokenizer> {
        let mut tokenizer = Tokenizer::new(model);

        // Convert the local enum variants into the upstream wrapper types:
        if let Some(decoder) = decoder {
            let d = DecoderWrapper::try_from(decoder)?;
            tokenizer.with_decoder(Some(d));
        }
        if let Some(normalizer) = normalizer {
            let n: NormalizerWrapper = NormalizerWrapper::try_from(normalizer)?;
            tokenizer.with_normalizer(Some(n));
        }

        Ok(tokenizer)
    }
}

// Convenient alternative to upstream:
// https://docs.rs/tokenizers/latest/tokenizers/decoders/enum.DecoderWrapper.html
enum Decoder<'a> {
    ByteFallback,
    Fuse,
    Replace(&'a str, &'a str),
    Strip(char, usize, usize),
    Sequence(Vec<Self>),
    ByteLevel(bool, bool, bool),
}

// Convert into upstream type wrapped enum variants:
impl TryFrom<Decoder<'_>> for DecoderWrapper {
    type Error = anyhow::Error;

    fn try_from(variant: Decoder) -> Result<Self, Self::Error> {
        let value: DecoderWrapper = match variant {
            Decoder::ByteFallback => ByteFallback::default().into(),
            Decoder::Fuse => Fuse::default().into(),
            Decoder::Replace(pattern, content) => Replace::new(pattern, content)
                .map_err(anyhow::Error::msg)?
                .into(),
            Decoder::Strip(content, start, stop) => Strip::new(content, start, stop).into(),
            Decoder::Sequence(decoders) => {
                let seq = decoders
                    .into_iter()
                    .map(DecoderWrapper::try_from)
                    .collect::<Result<Vec<DecoderWrapper>>>()?;

                decoders::sequence::Sequence::new(seq).into()
            }
            Decoder::ByteLevel(add_prefix_space, trim_offsets, use_regex) => {
                ByteLevel::new(add_prefix_space, trim_offsets, use_regex).into()
            }
        };

        Ok(value)
    }
}

// Convenient alternative to upstream:
// https://docs.rs/tokenizers/latest/tokenizers/normalizers/enum.NormalizerWrapper.html
enum Normalizer<'a> {
    Prepend(&'a str),
    Replace(&'a str, &'a str),
    Sequence(Vec<Self>),
}

impl TryFrom<Normalizer<'_>> for NormalizerWrapper {
    type Error = anyhow::Error;

    fn try_from(variant: Normalizer) -> Result<Self, Self::Error> {
        let value: NormalizerWrapper = match variant {
            Normalizer::Prepend(prepend) => Prepend::new(prepend.to_owned()).into(),
            Normalizer::Replace(pattern, content) => Replace::new(pattern, content)
                .map_err(anyhow::Error::msg)?
                .into(),
            Normalizer::Sequence(norms) => {
                let seq = norms
                    .into_iter()
                    .map(NormalizerWrapper::try_from)
                    .collect::<Result<Vec<NormalizerWrapper>>>()?;

                normalizers::Sequence::new(seq).into()
            }
        };

        Ok(value)
    }
}

#[cfg(test)]
mod tests {
    use anyhow::Result;
    use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
    use tokenizers::Tokenizer;

    #[allow(dead_code)]
    #[derive(Debug)]
    enum TokenizerType {
        /// Mistral v0.1 tokenizer
        Llama,
        Replit,
        Gpt2,
        Rwkv,
    }

    fn get_gguf_tokenizer(tokenizer: TokenizerType) -> Result<Tokenizer> {
        match tokenizer {
            TokenizerType::Llama => {
                let api = ApiBuilder::new().with_progress(true).build().unwrap();
                let api = api.repo(Repo::with_revision(
                    "EricB/mistralrs_tests".to_string(),
                    RepoType::Model,
                    "main".to_string(),
                ));

                let filename = api.get("llama_gguf_tokenizer.json").unwrap();
                let tokenizer = Tokenizer::from_file(filename).expect("Valid tokenizer");
                Ok(tokenizer)
            }
            TokenizerType::Gpt2 => {
                let api = ApiBuilder::new().with_progress(true).build().unwrap();
                let api = api.repo(Repo::with_revision(
                    "EricB/mistralrs_tests".to_string(),
                    RepoType::Model,
                    "main".to_string(),
                ));

                let filename = api.get("gpt2_gguf_tokenizer.json").unwrap();
                let tokenizer = Tokenizer::from_file(filename).expect("Valid tokenizer");
                Ok(tokenizer)
            }
            other => anyhow::bail!("Cannot get testing GGUF tokenizer for type {other:?}"),
        }
    }

    fn get_hf_tokenizer(tokenizer: TokenizerType) -> Result<Tokenizer> {
        match tokenizer {
            TokenizerType::Llama => {
                let api = ApiBuilder::new().with_progress(true).build().unwrap();
                let api = api.repo(Repo::with_revision(
                    "EricB/mistralrs_tests".to_string(),
                    RepoType::Model,
                    "main".to_string(),
                ));

                let tokenizer_filename = api.get("tokenizer.json").unwrap();
                Ok(Tokenizer::from_file(tokenizer_filename).unwrap())
            }
            TokenizerType::Gpt2 => {
                let api = ApiBuilder::new().with_progress(true).build().unwrap();
                let api = api.repo(Repo::with_revision(
                    "EricB/mistralrs_tests".to_string(),
                    RepoType::Model,
                    "main".to_string(),
                ));

                let tokenizer_filename = api.get("tokenizer_gpt2.json").unwrap();
                Ok(Tokenizer::from_file(tokenizer_filename).unwrap())
            }
            other => anyhow::bail!("Cannot get testing HF tokenizer for type {other:?}"),
        }
    }

    // Content based upon https://github.com/ggerganov/llama.cpp/blob/master/tests/test-tokenizer-random.py#L99-L161
    fn get_test_passage() -> String {
        let passage = "Hello, world! \n🚀 (normal) 😶‍🌫️ (compound emoji, zwj sequence) ✅ (emoji as single token)\n你好世界!\nNǐ hǎo shìjiè!";

        passage.to_owned()
    }

    // The provided passage should encode and decode back into the same passage string:
    fn codec_roundtrip(
        tokenizer: &Tokenizer,
        passage: &str,
        add_special_tokens: bool,
    ) -> Result<String> {
        let tokenized = tokenizer
            .encode_fast(passage, add_special_tokens)
            .map_err(anyhow::Error::msg)?;

        // NOTE: The meaning of the special tokens bool param differs between encode() and decode():
        decode(tokenizer, tokenized.get_ids(), !add_special_tokens)
    }

    fn decode(
        tokenizer: &Tokenizer,
        token_ids: &[u32],
        skip_special_tokens: bool,
    ) -> Result<String> {
        tokenizer
            .decode(token_ids, skip_special_tokens)
            .map_err(anyhow::Error::msg)
    }

    #[test]
    fn test_encode_decode_llama() -> Result<()> {
        use rand::rng;
        use rand::seq::SliceRandom;

        let passage = get_test_passage();
        let hf_tokenizer = get_hf_tokenizer(TokenizerType::Llama)?;
        let gguf_tokenizer = get_gguf_tokenizer(TokenizerType::Llama)?;

        // Without adding special tokens
        let hf_decoded = codec_roundtrip(&hf_tokenizer, passage.as_str(), false)?;
        let gguf_decoded = codec_roundtrip(&gguf_tokenizer, passage.as_str(), false)?;
        assert_eq!(hf_decoded, gguf_decoded);
        assert_eq!(passage, gguf_decoded);

        // With special tokens added
        // SKIPPED:
        // - Bugged: the GGUF tokenizer does not prepend `<s> `
        // - Possibly because the HF tokenizer uses BPE (tokenizer.json) while the GGUF tokenizer uses Unigram (metadata)?
        /*
        let hf_decoded = codec_roundtrip(&hf_tokenizer, passage.as_str(), true)?;
        let gguf_decoded = codec_roundtrip(&gguf_tokenizer, passage.as_str(), true)?;
        assert_eq!(hf_decoded, gguf_decoded);
        */

        #[allow(clippy::cast_possible_truncation)]
        let mut tokens = (0..hf_tokenizer.get_vocab_size(false) as u32).collect::<Vec<_>>();
        tokens.shuffle(&mut rng());

        // Without skipping special tokens
        let hf_decoded = decode(&hf_tokenizer, &tokens, false)?;
        let gguf_decoded = decode(&gguf_tokenizer, &tokens, false)?;
        assert_eq!(hf_decoded, gguf_decoded);

        // With skipping special tokens
        let hf_decoded = decode(&hf_tokenizer, &tokens, true)?;
        let gguf_decoded = decode(&gguf_tokenizer, &tokens, true)?;
        assert_eq!(hf_decoded, gguf_decoded);

        Ok(())
    }

    #[test]
    fn test_encode_decode_gpt2() -> Result<()> {
        use rand::rng;
        use rand::seq::SliceRandom;

        let passage = get_test_passage();
        let hf_tokenizer = get_hf_tokenizer(TokenizerType::Gpt2)?;
        let gguf_tokenizer = get_gguf_tokenizer(TokenizerType::Gpt2)?;

        // Without adding special tokens
        let hf_decoded = codec_roundtrip(&hf_tokenizer, passage.as_str(), false)?;
        let gguf_decoded = codec_roundtrip(&gguf_tokenizer, passage.as_str(), false)?;
        assert_eq!(hf_decoded, gguf_decoded);
        assert_eq!(passage, gguf_decoded);

        // With special tokens added
        // SKIPPED:
        // - Bugged: the GGUF tokenizer does not prepend `<s> `
        /*
        let hf_decoded = codec_roundtrip(&hf_tokenizer, passage.as_str(), true)?;
        let gguf_decoded = codec_roundtrip(&gguf_tokenizer, passage.as_str(), true)?;
        assert_eq!(hf_decoded, gguf_decoded);
        */

        #[allow(clippy::cast_possible_truncation)]
        let mut tokens = (0..hf_tokenizer.get_vocab_size(false) as u32).collect::<Vec<_>>();
        tokens.shuffle(&mut rng());

        // Without skipping special tokens
        let hf_decoded = decode(&hf_tokenizer, &tokens, false)?;
        let gguf_decoded = decode(&gguf_tokenizer, &tokens, false)?;
        assert_eq!(hf_decoded, gguf_decoded);

        // With skipping special tokens
        let hf_decoded = decode(&hf_tokenizer, &tokens, true)?;
        let gguf_decoded = decode(&gguf_tokenizer, &tokens, true)?;
        assert_eq!(hf_decoded, gguf_decoded);

        Ok(())
    }
}