// mistralrs_core/gguf/gguf_tokenizer.rs

1// https://github.com/huggingface/transformers/blob/8685b3c5d2dd2550527773d2a02499495a759e31/src/transformers/convert_slow_tokenizer.py
2
3use std::{collections::HashMap, sync::atomic::Ordering};
4
5use anyhow::Result;
6use itertools::Itertools;
7use tokenizers::{
8    decoders::{
9        self, byte_fallback::ByteFallback, byte_level::ByteLevel, fuse::Fuse, strip::Strip,
10    },
11    models::{bpe::BpeBuilder, unigram::Unigram},
12    normalizers::{self, Prepend, Replace},
13    pre_tokenizers,
14    processors::{
15        self,
16        template::{self, TemplateProcessing},
17    },
18    AddedToken, DecoderWrapper, ModelWrapper, NormalizerWrapper, Tokenizer,
19};
20use tracing::info;
21
22use crate::utils::gguf_metadata::ContentMetadata;
23use crate::DEBUG;
24
25use super::Content;
26
/// Result of converting a GGUF-embedded tokenizer into a Hugging Face [`Tokenizer`],
/// together with the textual forms of the special tokens that were registered.
pub(crate) struct GgufTokenizerConversion {
    // The converted tokenizer.
    pub tokenizer: Tokenizer,
    // Beginning-of-sequence token text, if one was registered.
    pub bos: Option<String>,
    // End-of-sequence token text, if one was registered.
    pub eos: Option<String>,
    // Unknown token text, if one was registered.
    pub unk: Option<String>,
}
33
/// Tokenizer properties read from GGUF metadata under the `tokenizer.ggml.*` prefix.
/// Optional fields correspond to keys that may be absent from the file.
struct PropsGGUF {
    // `tokenizer.ggml.model`: tokenizer model name (e.g. "llama", "gpt2").
    model: String,
    // `tokenizer.ggml.tokens`: vocabulary; index in this Vec is the token id.
    tokens: Vec<String>,
    // `tokenizer.ggml.added_tokens` (optional).
    added_tokens: Option<Vec<String>>,
    // `tokenizer.ggml.scores` (optional): unigram scores, parallel to `tokens`.
    scores: Option<Vec<f32>>,
    // `tokenizer.ggml.merges` (optional): BPE merges, each a space-delimited pair.
    merges: Option<Vec<String>>,
    // `tokenizer.ggml.unknown_token_id` (optional).
    unk: Option<u32>,
    // `tokenizer.ggml.eos_token_id`.
    eos: u32,
    // `tokenizer.ggml.bos_token_id`.
    bos: u32,
    // `tokenizer.ggml.add_bos_token` (optional): whether to prepend BOS on encode.
    add_bos_token: Option<bool>,
}
45
46impl TryFrom<ContentMetadata<'_>> for PropsGGUF {
47    type Error = anyhow::Error;
48
49    fn try_from(c: ContentMetadata) -> Result<Self, Self::Error> {
50        let required = ["model", "tokens", "eos_token_id", "bos_token_id"];
51        c.has_required_keys(&required)?;
52
53        let props = Self {
54            model: c.get_value("model")?,
55            tokens: c.get_value("tokens")?,
56            added_tokens: c.get_value("added_tokens").ok(),
57            scores: c.get_value("scores").ok(),
58            merges: c.get_value("merges").ok(),
59            unk: c.get_value("unknown_token_id").ok(),
60            eos: c.get_value("eos_token_id")?,
61            bos: c.get_value("bos_token_id")?,
62            add_bos_token: c.get_value("add_bos_token").ok(),
63        };
64
65        Ok(props)
66    }
67}
68
/// String forms of the special tokens registered on a tokenizer.
/// `bos`/`eos` are always present for the supported models; `unk` may be absent.
struct AddedTokensCollection {
    bos: String,
    eos: String,
    unk: Option<String>,
}
74
75pub fn convert_gguf_to_hf_tokenizer<R: std::io::Seek + std::io::Read>(
76    content: &Content<'_, R>,
77) -> Result<GgufTokenizerConversion> {
78    let metadata = ContentMetadata {
79        path_prefix: "tokenizer.ggml",
80        metadata: content.get_metadata(),
81    };
82    let props = PropsGGUF::try_from(metadata)?;
83
84    let (tokenizer, kind, special_tokens) = match props.model.as_str() {
85        "llama" | "replit" => unigram_tokenizer(&props)?,
86        "gpt2" => bpe_tokenizer(&props)?,
87        other => {
88            anyhow::bail!("Tokenizer model `{other}` not supported.");
89        }
90    };
91
92    info!(
93        "GGUF tokenizer model is `{model}`, kind: `{kind:?}`, num tokens: {}, num added tokens: {}, num merges: {}, num scores: {}",
94        tokenizer.get_vocab_size(true),
95        props.added_tokens.as_ref().map(|x| x.len()).unwrap_or(0),
96        props.merges.as_ref().map(|x| x.len()).unwrap_or(0),
97        props.scores.as_ref().map(|x| x.len()).unwrap_or(0),
98        model = props.model,
99    );
100    if DEBUG.load(Ordering::Relaxed) {
101        info!("Tokenizer: {tokenizer:?}");
102    }
103
104    let AddedTokensCollection { bos, eos, unk } = special_tokens;
105
106    Ok(GgufTokenizerConversion {
107        tokenizer,
108        bos: Some(bos),
109        eos: Some(eos),
110        unk,
111    })
112}
113
// TODO: Add support for additional tokenizer models: WordPiece, WordLevel
// https://docs.rs/tokenizers/latest/tokenizers/models/enum.ModelWrapper.html
/// Which kind of tokenizer model was constructed, for logging purposes.
#[derive(Debug)]
enum TokenizerKind {
    // SentencePiece-style unigram (GGUF models "llama"/"replit").
    Unigram,
    // Byte-level BPE (GGUF model "gpt2").
    Bpe,
}
121
122/// Add the special tokens and return their string representations
123fn add_special_tokens(
124    p: &PropsGGUF,
125    tokenizer: &mut Tokenizer,
126    bos: u32,
127    eos: u32,
128    unk: Option<u32>,
129) -> AddedTokensCollection {
130    // Add special tokens (bos, eos, unk):
131    let mut special_tokens: [Option<String>; 3] = Default::default();
132
133    // A little bit awkward here since eos/bos are assumed not options so we need to handle an Option
134    for (i, token_id) in [Some(bos), Some(eos), unk].into_iter().enumerate() {
135        if let Some(token_id) = token_id {
136            let token = p.tokens[token_id as usize].as_str();
137            special_tokens[i] = Some(token.to_string());
138            tokenizer.add_special_tokens(&[AddedToken::from(token.to_string(), true)]);
139        }
140    }
141
142    // Destructure array of options:
143    let [bos_str, eos_str, unk_str] = special_tokens;
144    // Would need to unwrap bos/eos here, or change the struct types
145    AddedTokensCollection {
146        bos: bos_str.unwrap(),
147        eos: eos_str.unwrap(),
148        unk: unk_str,
149    }
150}
151
152fn unigram_tokenizer(p: &PropsGGUF) -> Result<(Tokenizer, TokenizerKind, AddedTokensCollection)> {
153    let PropsGGUF { unk, eos, bos, .. } = *p;
154    // Unigram (SentencePiece) default UNK is 0
155    let unk = unk.unwrap_or(0);
156
157    // Create the Tokenizer model:
158    let model = {
159        let vocab: Vec<(String, f64)> = {
160            let Some(s) = p.scores.as_ref() else {
161                anyhow::bail!(
162                    "`llama` unigram tokenizer is missing required metadata `tokenizer.ggml.scores`"
163                );
164            };
165            let scores = s.iter().cloned().map(|f_32| f_32 as f64);
166
167            p.tokens.iter().cloned().zip(scores).collect()
168        };
169
170        Unigram::from(vocab, Some(unk as usize), true).map_err(anyhow::Error::msg)?
171    };
172
173    // Decoder + Normalizer config reference:
174    // https://github.com/EricLBuehler/mistral.rs/pull/389#discussion_r1630620763
175    let decoder = Decoder::Sequence(vec![
176        Decoder::Replace("▁", " "),
177        Decoder::ByteFallback,
178        Decoder::Fuse,
179        Decoder::Strip(' ', 1, 0),
180    ]);
181
182    let normalizer = Normalizer::Sequence(vec![
183        Normalizer::Prepend("▁"),
184        Normalizer::Replace(" ", "▁"),
185    ]);
186
187    let mut tokenizer: Tokenizer = TokenizerX::new(
188        ModelWrapper::Unigram(model),
189        Some(decoder),
190        Some(normalizer),
191    )?;
192
193    // Add special tokens (bos, eos, unk):
194    let special_tokens = add_special_tokens(p, &mut tokenizer, bos, eos, Some(unk));
195
196    Ok((tokenizer, TokenizerKind::Unigram, special_tokens))
197}
198
199fn bpe_tokenizer(p: &PropsGGUF) -> Result<(Tokenizer, TokenizerKind, AddedTokensCollection)> {
200    // BPE merges have each string item as a space-delimited pair:
201    // https://github.com/EricLBuehler/mistral.rs/pull/397#discussion_r1631988370
202    let merges = p
203        .merges
204        .as_ref()
205        .ok_or(anyhow::Error::msg("BPE tokenizer must include merges"))?
206        .iter()
207        .map(|merge| {
208            let split: (&str, &str) = merge
209                .splitn(2, ' ')
210                .collect_tuple()
211                .expect("Failed to convert split into 2-tuple");
212            (split.0.to_string(), split.1.to_string())
213        })
214        .collect::<Vec<_>>();
215
216    let mut vocab = HashMap::new();
217    for (i, token) in p.tokens.iter().enumerate() {
218        #[allow(clippy::cast_possible_truncation)]
219        vocab.insert(token.clone(), i as u32);
220    }
221
222    let PropsGGUF {
223        eos,
224        bos,
225        unk,
226        add_bos_token,
227        ..
228    } = *p;
229
230    let mut bpe = BpeBuilder::new().vocab_and_merges(vocab, merges);
231    if let Some(unk) = unk {
232        bpe = bpe.unk_token(p.tokens[unk as usize].to_string());
233    };
234
235    let bpe = bpe.build().map_err(anyhow::Error::msg)?;
236
237    let mut tokenizer = TokenizerX::new(
238        ModelWrapper::BPE(bpe),
239        Some(Decoder::ByteLevel(true, true, true)),
240        None,
241    )?;
242    tokenizer.with_pre_tokenizer(Some(pre_tokenizers::byte_level::ByteLevel::new(
243        false, true, true,
244    )));
245    if add_bos_token.is_some_and(|x| x) {
246        let mut special_toks = HashMap::new();
247        special_toks.insert(
248            p.tokens[bos as usize].clone(),
249            template::SpecialToken::new(
250                p.tokens[bos as usize].clone(),
251                vec![bos],
252                vec![p.tokens[bos as usize].clone()],
253            )
254            .unwrap(),
255        );
256        tokenizer.with_post_processor(Some(
257            TemplateProcessing::builder()
258                .try_single(format!("{}:0 $A:0", p.tokens[bos as usize]))
259                .unwrap()
260                .try_pair(format!("{}:0 $A:0 $B:1", p.tokens[bos as usize]))
261                .unwrap()
262                .special_tokens(special_toks)
263                .build()
264                .unwrap(),
265        ));
266    } else {
267        tokenizer.with_post_processor(Some(processors::byte_level::ByteLevel::new(
268            true, false, true,
269        )));
270    }
271
272    let special_tokens = add_special_tokens(p, &mut tokenizer, bos, eos, unk);
273
274    Ok((tokenizer, TokenizerKind::Bpe, special_tokens))
275}
276
277// This is a workaround to have a better builder API.
278// Upstream `TokenizerBuilder` is difficult to work with:
279// https://github.com/huggingface/tokenizers/issues/1549
280struct TokenizerX;
281
282impl TokenizerX {
283    #[allow(clippy::new_ret_no_self)]
284    fn new<'a>(
285        model: ModelWrapper,
286        decoder: Option<Decoder<'a>>,
287        normalizer: Option<Normalizer<'a>>,
288    ) -> Result<Tokenizer> {
289        let mut tokenizer = Tokenizer::new(model);
290
291        // Handle local enum to remote enum type:
292        if let Some(decoder) = decoder {
293            let d = DecoderWrapper::try_from(decoder)?;
294            tokenizer.with_decoder(Some(d));
295        }
296        if let Some(normalizer) = normalizer {
297            let n: NormalizerWrapper = NormalizerWrapper::try_from(normalizer)?;
298            tokenizer.with_normalizer(Some(n));
299        }
300
301        Ok(tokenizer)
302    }
303}
304
// Convenient alternative to upstream:
// https://docs.rs/tokenizers/latest/tokenizers/decoders/enum.DecoderWrapper.html
/// Local, borrow-friendly mirror of the upstream `DecoderWrapper` variants
/// used in this file; converted via `TryFrom<Decoder>` below.
enum Decoder<'a> {
    ByteFallback,
    Fuse,
    // (pattern, replacement)
    Replace(&'a str, &'a str),
    // (content char, start count, stop count)
    Strip(char, usize, usize),
    Sequence(Vec<Self>),
    // (add_prefix_space, trim_offsets, use_regex)
    ByteLevel(bool, bool, bool),
}
315
316// Convert into upstream type wrapped enum variants:
317impl TryFrom<Decoder<'_>> for DecoderWrapper {
318    type Error = anyhow::Error;
319
320    fn try_from(variant: Decoder) -> Result<Self, Self::Error> {
321        let value: DecoderWrapper = match variant {
322            Decoder::ByteFallback => ByteFallback::default().into(),
323            Decoder::Fuse => Fuse::default().into(),
324            Decoder::Replace(pattern, content) => Replace::new(pattern, content)
325                .map_err(anyhow::Error::msg)?
326                .into(),
327            Decoder::Strip(content, start, stop) => Strip::new(content, start, stop).into(),
328            Decoder::Sequence(decoders) => {
329                let seq = decoders
330                    .into_iter()
331                    .map(DecoderWrapper::try_from)
332                    .collect::<Result<Vec<DecoderWrapper>>>()?;
333
334                decoders::sequence::Sequence::new(seq).into()
335            }
336            Decoder::ByteLevel(add_prefix_space, trim_offsets, use_regex) => {
337                ByteLevel::new(add_prefix_space, trim_offsets, use_regex).into()
338            }
339        };
340
341        Ok(value)
342    }
343}
344
// Convenient alternative to upstream:
// https://docs.rs/tokenizers/latest/tokenizers/normalizers/enum.NormalizerWrapper.html
/// Local, borrow-friendly mirror of the upstream `NormalizerWrapper` variants
/// used in this file; converted via `TryFrom<Normalizer>` below.
enum Normalizer<'a> {
    Prepend(&'a str),
    // (pattern, replacement)
    Replace(&'a str, &'a str),
    Sequence(Vec<Self>),
}
352
353impl TryFrom<Normalizer<'_>> for NormalizerWrapper {
354    type Error = anyhow::Error;
355
356    fn try_from(variant: Normalizer) -> Result<Self, Self::Error> {
357        let value: NormalizerWrapper = match variant {
358            Normalizer::Prepend(prepend) => Prepend::new(prepend.to_owned()).into(),
359            Normalizer::Replace(pattern, content) => Replace::new(pattern, content)
360                .map_err(anyhow::Error::msg)?
361                .into(),
362            Normalizer::Sequence(decoders) => {
363                let seq = decoders
364                    .into_iter()
365                    .map(NormalizerWrapper::try_from)
366                    .collect::<Result<Vec<NormalizerWrapper>>>()?;
367
368                normalizers::Sequence::new(seq).into()
369            }
370        };
371
372        Ok(value)
373    }
374}
375
#[cfg(test)]
mod tests {
    use anyhow::Result;
    use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
    use tokenizers::Tokenizer;

    // NOTE(review): these tests download fixture tokenizers from the Hugging
    // Face Hub repo `EricB/mistralrs_tests`, so they require network access.

    /// Tokenizer families the test fixtures cover (or plan to cover).
    #[allow(dead_code)]
    #[derive(Debug)]
    enum TokenizerType {
        /// Mistral v0.1 tokenizer
        Llama,
        Replit,
        Gpt2,
        Rwkv,
    }

    /// Fetch the tokenizer.json that was produced from GGUF metadata for the
    /// given type. Only `Llama` and `Gpt2` fixtures exist; others error.
    fn get_gguf_tokenizer(tokenizer: TokenizerType) -> Result<Tokenizer> {
        match tokenizer {
            TokenizerType::Llama => {
                let api = ApiBuilder::new().with_progress(true).build().unwrap();
                let api = api.repo(Repo::with_revision(
                    "EricB/mistralrs_tests".to_string(),
                    RepoType::Model,
                    "main".to_string(),
                ));

                let filename = api.get("llama_gguf_tokenizer.json").unwrap();
                let tokenizer = Tokenizer::from_file(filename).expect("Valid tokenizer");
                Ok(tokenizer)
            }
            TokenizerType::Gpt2 => {
                let api = ApiBuilder::new().with_progress(true).build().unwrap();
                let api = api.repo(Repo::with_revision(
                    "EricB/mistralrs_tests".to_string(),
                    RepoType::Model,
                    "main".to_string(),
                ));

                let filename = api.get("gpt2_gguf_tokenizer.json").unwrap();
                let tokenizer = Tokenizer::from_file(filename).expect("Valid tokenizer");
                Ok(tokenizer)
            }
            other => anyhow::bail!("Cannot get testing HF tokenizer for type {other:?}"),
        }
    }

    /// Fetch the reference Hugging Face tokenizer.json for the given type,
    /// used as ground truth to compare against the GGUF-derived tokenizer.
    fn get_hf_tokenizer(tokenizer: TokenizerType) -> Result<Tokenizer> {
        match tokenizer {
            TokenizerType::Llama => {
                let api = ApiBuilder::new().with_progress(true).build().unwrap();
                let api = api.repo(Repo::with_revision(
                    "EricB/mistralrs_tests".to_string(),
                    RepoType::Model,
                    "main".to_string(),
                ));

                let tokenizer_filename = api.get("tokenizer.json").unwrap();
                Ok(Tokenizer::from_file(tokenizer_filename).unwrap())
            }
            TokenizerType::Gpt2 => {
                let api = ApiBuilder::new().with_progress(true).build().unwrap();
                let api = api.repo(Repo::with_revision(
                    "EricB/mistralrs_tests".to_string(),
                    RepoType::Model,
                    "main".to_string(),
                ));

                let tokenizer_filename = api.get("tokenizer_gpt2.json").unwrap();
                Ok(Tokenizer::from_file(tokenizer_filename).unwrap())
            }
            other => anyhow::bail!("Cannot get testing HF tokenizer for type {other:?}"),
        }
    }

    // Content based upon https://github.com/ggerganov/llama.cpp/blob/master/tests/test-tokenizer-random.py#L99-L161
    /// Mixed-script passage exercising whitespace, emoji (including a ZWJ
    /// compound), CJK text, and accented Latin.
    fn get_test_passage() -> String {
        let passage = "Hello, world! \n🚀 (normal) 😶‍🌫️ (compound emoji, zwj sequence) ✅ (emoji as single token)\n你好世界!\nNǐ hǎo shìjiè!";

        passage.to_owned()
    }

    // The provided passage should encode and decode back into the same passage string:
    fn codec_roundtrip(
        tokenizer: &Tokenizer,
        passage: &str,
        add_special_tokens: bool,
    ) -> Result<String> {
        let tokenized = tokenizer
            .encode_fast(passage, add_special_tokens)
            .map_err(anyhow::Error::msg)?;

        // NOTE: The special tokens bool param meaning differs between encode() / decode():
        decode(tokenizer, tokenized.get_ids(), !add_special_tokens)
    }

    /// Thin wrapper over `Tokenizer::decode` that adapts the error type.
    fn decode(
        tokenizer: &Tokenizer,
        token_ids: &[u32],
        skip_special_tokens: bool,
    ) -> Result<String> {
        tokenizer
            .decode(token_ids, skip_special_tokens)
            .map_err(anyhow::Error::msg)
    }

    /// The GGUF-derived llama tokenizer must round-trip text and decode
    /// arbitrary token-id sequences identically to the reference HF tokenizer.
    #[test]
    fn test_encode_decode_llama() -> Result<()> {
        use rand::rng;
        use rand::seq::SliceRandom;

        let passage = get_test_passage();
        let hf_tokenizer = get_hf_tokenizer(TokenizerType::Llama)?;
        let gguf_tokenizer = get_gguf_tokenizer(TokenizerType::Llama)?;

        // Without adding special tokens
        let hf_decoded = codec_roundtrip(&hf_tokenizer, passage.as_str(), false)?;
        let gguf_decoded = codec_roundtrip(&gguf_tokenizer, passage.as_str(), false)?;
        assert_eq!(hf_decoded, gguf_decoded);
        assert_eq!(passage, gguf_decoded);

        // With special tokens added
        // SKIPPED:
        // - Bugged the GGUF tokenizer does not prepend `<s> `
        // - Due to HF tokenizer using BPE (tokenizer.json) while GGUF tokenizer uses Unigram (metadata)?
        /*
        let hf_decoded = codec_roundtrip(&hf_tokenizer, passage.as_str(), true)?;
        let gguf_decoded = codec_roundtrip(&gguf_tokenizer, passage.as_str(), true)?;
        assert_eq!(hf_decoded, gguf_decoded);
        */

        // Decode every vocab id once, in random order, and compare outputs.
        #[allow(clippy::cast_possible_truncation)]
        let mut tokens = (0..hf_tokenizer.get_vocab_size(false) as u32).collect::<Vec<_>>();
        tokens.shuffle(&mut rng());

        // Without skipping special tokens
        let hf_decoded = decode(&hf_tokenizer, &tokens, false)?;
        let gguf_decoded = decode(&gguf_tokenizer, &tokens, false)?;
        assert_eq!(hf_decoded, gguf_decoded);

        // With skipping special tokens
        let hf_decoded = decode(&hf_tokenizer, &tokens, true)?;
        let gguf_decoded = decode(&gguf_tokenizer, &tokens, true)?;
        assert_eq!(hf_decoded, gguf_decoded);

        Ok(())
    }

    /// Same comparison as `test_encode_decode_llama`, but for the GPT-2 (BPE)
    /// tokenizer fixtures.
    #[test]
    fn test_encode_decode_gpt2() -> Result<()> {
        use rand::rng;
        use rand::seq::SliceRandom;

        let passage = get_test_passage();
        let hf_tokenizer = get_hf_tokenizer(TokenizerType::Gpt2)?;
        let gguf_tokenizer = get_gguf_tokenizer(TokenizerType::Gpt2)?;

        // Without adding special tokens
        let hf_decoded = codec_roundtrip(&hf_tokenizer, passage.as_str(), false)?;
        let gguf_decoded = codec_roundtrip(&gguf_tokenizer, passage.as_str(), false)?;
        assert_eq!(hf_decoded, gguf_decoded);
        assert_eq!(passage, gguf_decoded);

        // With special tokens added
        // SKIPPED:
        // - Bugged the GGUF tokenizer does not prepend `<s> `
        // - Due to HF tokenizer using BPE (tokenizer.json) while GGUF tokenizer uses Unigram (metadata)?
        /*
        let hf_decoded = codec_roundtrip(&hf_tokenizer, passage.as_str(), true)?;
        let gguf_decoded = codec_roundtrip(&gguf_tokenizer, passage.as_str(), true)?;
        assert_eq!(hf_decoded, gguf_decoded);
        */

        // Decode every vocab id once, in random order, and compare outputs.
        #[allow(clippy::cast_possible_truncation)]
        let mut tokens = (0..hf_tokenizer.get_vocab_size(false) as u32).collect::<Vec<_>>();
        tokens.shuffle(&mut rng());

        // Without skipping special tokens
        let hf_decoded = decode(&hf_tokenizer, &tokens, false)?;
        let gguf_decoded = decode(&gguf_tokenizer, &tokens, false)?;
        assert_eq!(hf_decoded, gguf_decoded);

        // With skipping special tokens
        let hf_decoded = decode(&hf_tokenizer, &tokens, true)?;
        let gguf_decoded = decode(&gguf_tokenizer, &tokens, true)?;
        assert_eq!(hf_decoded, gguf_decoded);

        Ok(())
    }
}