mistralrs_core/gguf/gguf_tokenizer.rs

// https://github.com/huggingface/transformers/blob/8685b3c5d2dd2550527773d2a02499495a759e31/src/transformers/convert_slow_tokenizer.py

use std::sync::atomic::Ordering;

use crate::utils::gguf_metadata::ContentMetadata;
use crate::DEBUG;
use ahash::AHashMap;
use anyhow::Result;
use candle_core::quantized::gguf_file::Value;
use itertools::Itertools;
use tokenizers::pre_tokenizers::{
    sequence::Sequence,
    split::{Split, SplitPattern},
    PreTokenizerWrapper,
};
use tokenizers::tokenizer::normalizer::SplitDelimiterBehavior;
use tokenizers::{
    decoders::{
        self, byte_fallback::ByteFallback, byte_level::ByteLevel, fuse::Fuse, strip::Strip,
    },
    models::{bpe::BpeBuilder, unigram::Unigram},
    normalizers::{self, Prepend, Replace},
    processors, AddedToken, DecoderWrapper, ModelWrapper, NormalizerWrapper, Tokenizer,
};
use tracing::info;

use super::Content;

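/// The result of converting a GGUF-embedded tokenizer: the reconstructed HF
/// `Tokenizer` plus the BOS/EOS/UNK token strings, when present in the metadata.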
pub(crate) struct GgufTokenizerConversion {
    pub tokenizer: Tokenizer,
    pub bos: Option<String>,
    pub eos: Option<String>,
    pub unk: Option<String>,
}

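// Tokenizer properties read from the `tokenizer.ggml.*` GGUF metadata keys.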
struct PropsGGUF {
    model: String,
    tokens: Vec<String>,
    added_tokens: Option<Vec<String>>,
    scores: Option<Vec<f32>>,
    merges: Option<Vec<String>>,
    unk: Option<u32>,
    bos: Option<u32>,
    eos: u32,
}

impl TryFrom<ContentMetadata<'_>> for PropsGGUF {
    type Error = anyhow::Error;

    fn try_from(c: ContentMetadata) -> Result<Self, Self::Error> {
        let required = ["model", "tokens", "eos_token_id"];
        c.has_required_keys(&required)?;

        let props = Self {
            model: c.get_value("model")?,
            tokens: c.get_value("tokens")?,
            added_tokens: c.get_value("added_tokens").ok(),
            scores: c.get_value("scores").ok(),
            merges: c.get_value("merges").ok(),
            unk: c.get_value("unknown_token_id").ok(),
            eos: c.get_value("eos_token_id")?,
            bos: c.get_value("bos_token_id").ok(),
        };

        Ok(props)
    }
}

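/// Convert a tokenizer stored in GGUF metadata (under the `tokenizer.ggml`
/// prefix) into a Hugging Face `tokenizers::Tokenizer`, returned together with
/// the BOS/EOS/UNK token strings resolved from their ids.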
pub fn convert_gguf_to_hf_tokenizer<R: std::io::Seek + std::io::Read>(
    content: &Content<'_, R>,
) -> Result<GgufTokenizerConversion> {
    let metadata = ContentMetadata {
        path_prefix: "tokenizer.ggml",
        metadata: content.get_metadata(),
    };

    let md_get = |s: &str| match metadata.metadata.get(s) {
        None => candle_core::bail!("cannot find {s} in metadata"),
        Some(v) => Ok(v),
    };

    let mut token_types = Vec::<i32>::new();
    if metadata.metadata.contains_key("tokenizer.ggml.token_type") {
        let vtypes: &Vec<Value> = md_get("tokenizer.ggml.token_type")
            .unwrap()
            .to_vec()
            .unwrap();
        let v: Vec<i32> = vtypes.iter().map(|v| v.to_i32().unwrap()).collect();
        token_types.extend(v);
    }

    let props = PropsGGUF::try_from(metadata)?;

    let (mut tokenizer, kind) = match props.model.as_str() {
        "llama" | "replit" => unigram_tokenizer(&props)?,
        "gpt2" => bpe_tokenizer(&props)?,
        other => {
            anyhow::bail!("Tokenizer model `{other}` not supported.");
        }
    };

    // Token types other than 1 (normal) are treated as special tokens
    let mut num_special_tokens = 0;
    #[allow(clippy::needless_range_loop)]
    if token_types.len() == props.tokens.len() {
        for i in 0..props.tokens.len() {
            if token_types[i] != 1i32 {
                let tk = props.tokens[i].clone();
                tokenizer.add_special_tokens(&[AddedToken::from(tk.to_string(), true)]);
                num_special_tokens += 1;
            }
        }
    }

    info!(
        "GGUF tokenizer model is `{model}`, kind: `{kind:?}`, num tokens: {}, num special tokens: {}, num added tokens: {}, num merges: {}, num scores: {}",
        tokenizer.get_vocab_size(true),
        num_special_tokens,
        props.added_tokens.as_ref().map(|x| x.len()).unwrap_or(0),
        props.merges.as_ref().map(|x| x.len()).unwrap_or(0),
        props.scores.as_ref().map(|x| x.len()).unwrap_or(0),
        model = props.model,
    );
    if DEBUG.load(Ordering::Relaxed) {
        info!("Tokenizer: {tokenizer:?}");
    }

    let unk = props.unk.map(|u| props.tokens[u as usize].clone());
    let bos = props.bos.map(|b| props.tokens[b as usize].clone());

    Ok(GgufTokenizerConversion {
        tokenizer,
        bos,
        eos: Some(props.tokens[props.eos as usize].clone()),
        unk,
    })
}
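
// Usage sketch (illustrative only; `content` here stands for a GGUF `Content`
// built elsewhere in the crate):
//
//     let GgufTokenizerConversion { tokenizer, eos, .. } =
//         convert_gguf_to_hf_tokenizer(&content)?;
//     // Resolve the EOS string back to its id through the rebuilt vocab:
//     let eos_id = eos.as_deref().and_then(|t| tokenizer.token_to_id(t));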

// TODO: Add support for additional tokenizer models: WordPiece, WordLevel
// https://docs.rs/tokenizers/latest/tokenizers/models/enum.ModelWrapper.html
#[derive(Debug)]
enum TokenizerKind {
    Unigram,
    Bpe,
}

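/// Construct a Unigram (SentencePiece-style) tokenizer from GGUF metadata, as
/// used by the `llama` and `replit` tokenizer models.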
fn unigram_tokenizer(p: &PropsGGUF) -> Result<(Tokenizer, TokenizerKind)> {
    let PropsGGUF { unk, eos, bos, .. } = *p;
    // Unigram (SentencePiece) default UNK is 0
    let unk = unk.unwrap_or(0);

    // Create the Tokenizer model:
    let model = {
        let vocab: Vec<(String, f64)> = {
            let Some(s) = p.scores.as_ref() else {
                anyhow::bail!(
                    "`llama` unigram tokenizer is missing required metadata `tokenizer.ggml.scores`"
                );
            };
            let scores = s.iter().cloned().map(|f_32| f_32 as f64);

            p.tokens.iter().cloned().zip(scores).collect()
        };

        Unigram::from(vocab, Some(unk as usize), true).map_err(anyhow::Error::msg)?
    };

    // Decoder + Normalizer config reference:
    // https://github.com/EricLBuehler/mistral.rs/pull/389#discussion_r1630620763
    let decoder = Decoder::Sequence(vec![
        Decoder::Replace("▁", " "),
        Decoder::ByteFallback,
        Decoder::Fuse,
        Decoder::Strip(' ', 1, 0),
    ]);

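    // The normalizer below maps e.g. "Hello world" to "▁Hello▁world" before the
    // Unigram model runs; the decoder's Replace("▁", " ") + Strip(' ', 1, 0)
    // reverse this mapping on decode.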
    let normalizer = Normalizer::Sequence(vec![
        Normalizer::Prepend("▁"),
        Normalizer::Replace(" ", "▁"),
    ]);

    let mut tokenizer: Tokenizer = TokenizerX::new(
        ModelWrapper::Unigram(model),
        Some(decoder),
        Some(normalizer),
    )?;

    // Add special tokens (bos, eos, unk):
    for v in [bos, Some(eos), Some(unk)].iter().flatten() {
        let tk = p.tokens[*v as usize].clone();
        tokenizer.add_special_tokens(&[AddedToken::from(tk.to_string(), true)]);
    }
    Ok((tokenizer, TokenizerKind::Unigram))
}

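/// Construct a byte-level BPE tokenizer from GGUF metadata, as used by the
/// `gpt2` tokenizer model.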
fn bpe_tokenizer(p: &PropsGGUF) -> Result<(Tokenizer, TokenizerKind)> {
    // BPE merges have each string item as a space-delimited pair:
    // https://github.com/EricLBuehler/mistral.rs/pull/397#discussion_r1631988370
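    // e.g. the merge entry "Ġ t" splits into the pair ("Ġ", "t").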
    let merges = p
        .merges
        .as_ref()
        .ok_or(anyhow::Error::msg("BPE tokenizer must include merges"))?
        .iter()
        .map(|merge| {
            let split: (&str, &str) = merge
                .splitn(2, ' ')
                .collect_tuple()
                .expect("Failed to convert split into 2-tuple");
            (split.0.to_string(), split.1.to_string())
        })
        .collect::<Vec<_>>();

    let mut vocab = AHashMap::new();
    for (i, token) in p.tokens.iter().enumerate() {
        #[allow(clippy::cast_possible_truncation)]
        vocab.insert(token.clone(), i as u32);
    }

    let PropsGGUF { bos, eos, unk, .. } = *p;

    let mut bpe = BpeBuilder::new().vocab_and_merges(vocab, merges);
    if let Some(unk) = unk {
        bpe = bpe.unk_token(p.tokens[unk as usize].to_string());
    }

    let bpe = bpe.build().map_err(anyhow::Error::msg)?;

    let mut tokenizer = TokenizerX::new(
        ModelWrapper::BPE(bpe),
        Some(Decoder::ByteLevel(true, true, true)),
        None,
    )?;

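    // Pre-tokenization regex in the GPT-2 family: isolates contractions, letter
    // runs, digits, punctuation, and whitespace spans before byte-level BPE.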
    let split = Split::new(
        SplitPattern::Regex("(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+".to_string()),
        SplitDelimiterBehavior::Isolated,
        false,
    )
    .map_err(anyhow::Error::msg)?;

    // example:
    // "type": "ByteLevel",
    // "add_prefix_space": false,
    // "trim_offsets": false,
    // "use_regex": false
    let pre_tokenizer = Sequence::new(vec![
        PreTokenizerWrapper::Split(split),
        PreTokenizerWrapper::ByteLevel(ByteLevel::new(false, false, false)),
    ]);

    tokenizer.with_pre_tokenizer(Some(pre_tokenizer));

    tokenizer.with_decoder(Some(decoders::byte_level::ByteLevel::new(
        false, false, false,
    )));
    tokenizer.with_post_processor(Some(processors::byte_level::ByteLevel::new(
        false, false, false,
    )));

    for v in [bos, Some(eos), unk].iter().flatten() {
        let tk = p.tokens[*v as usize].clone();
        tokenizer.add_special_tokens(&[AddedToken::from(tk.to_string(), true)]);
    }

    Ok((tokenizer, TokenizerKind::Bpe))
}

// This is a workaround to have a better builder API.
// Upstream `TokenizerBuilder` is difficult to work with:
// https://github.com/huggingface/tokenizers/issues/1549
struct TokenizerX;

impl TokenizerX {
    #[allow(clippy::new_ret_no_self)]
    fn new<'a>(
        model: ModelWrapper,
        decoder: Option<Decoder<'a>>,
        normalizer: Option<Normalizer<'a>>,
    ) -> Result<Tokenizer> {
        let mut tokenizer = Tokenizer::new(model);

        // Handle local enum to remote enum type:
        if let Some(decoder) = decoder {
            let d = DecoderWrapper::try_from(decoder)?;
            tokenizer.with_decoder(Some(d));
        }
        if let Some(normalizer) = normalizer {
            let n: NormalizerWrapper = NormalizerWrapper::try_from(normalizer)?;
            tokenizer.with_normalizer(Some(n));
        }

        Ok(tokenizer)
    }
}

// Convenient alternative to upstream:
// https://docs.rs/tokenizers/latest/tokenizers/decoders/enum.DecoderWrapper.html
enum Decoder<'a> {
    ByteFallback,
    Fuse,
    Replace(&'a str, &'a str),
    Strip(char, usize, usize),
    Sequence(Vec<Self>),
    ByteLevel(bool, bool, bool),
}
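
// e.g. Decoder::Strip(' ', 1, 0) converts into the upstream
// decoders::strip::Strip::new(' ', 1, 0) via the TryFrom impl below.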

// Convert into upstream type wrapped enum variants:
impl TryFrom<Decoder<'_>> for DecoderWrapper {
    type Error = anyhow::Error;

    fn try_from(variant: Decoder) -> Result<Self, Self::Error> {
        let value: DecoderWrapper = match variant {
            Decoder::ByteFallback => ByteFallback::default().into(),
            Decoder::Fuse => Fuse::default().into(),
            Decoder::Replace(pattern, content) => Replace::new(pattern, content)
                .map_err(anyhow::Error::msg)?
                .into(),
            Decoder::Strip(content, start, stop) => Strip::new(content, start, stop).into(),
            Decoder::Sequence(decoders) => {
                let seq = decoders
                    .into_iter()
                    .map(DecoderWrapper::try_from)
                    .collect::<Result<Vec<DecoderWrapper>>>()?;

                decoders::sequence::Sequence::new(seq).into()
            }
            Decoder::ByteLevel(add_prefix_space, trim_offsets, use_regex) => {
                ByteLevel::new(add_prefix_space, trim_offsets, use_regex).into()
            }
        };

        Ok(value)
    }
}

// Convenient alternative to upstream:
// https://docs.rs/tokenizers/latest/tokenizers/normalizers/enum.NormalizerWrapper.html
enum Normalizer<'a> {
    Prepend(&'a str),
    Replace(&'a str, &'a str),
    Sequence(Vec<Self>),
}

impl TryFrom<Normalizer<'_>> for NormalizerWrapper {
    type Error = anyhow::Error;

    fn try_from(variant: Normalizer) -> Result<Self, Self::Error> {
        let value: NormalizerWrapper = match variant {
            Normalizer::Prepend(prepend) => Prepend::new(prepend.to_owned()).into(),
            Normalizer::Replace(pattern, content) => Replace::new(pattern, content)
                .map_err(anyhow::Error::msg)?
                .into(),
            Normalizer::Sequence(norms) => {
                let seq = norms
                    .into_iter()
                    .map(NormalizerWrapper::try_from)
                    .collect::<Result<Vec<NormalizerWrapper>>>()?;

                normalizers::Sequence::new(seq).into()
            }
        };

        Ok(value)
    }
}

#[cfg(test)]
mod tests {
    use anyhow::Result;
    use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
    use tokenizers::Tokenizer;

    #[allow(dead_code)]
    #[derive(Debug)]
    enum TokenizerType {
        /// Mistral v0.1 tokenizer
        Llama,
        Replit,
        Gpt2,
        Rwkv,
    }

    fn get_gguf_tokenizer(tokenizer: TokenizerType) -> Result<Tokenizer> {
        match tokenizer {
            TokenizerType::Llama => {
                let api = ApiBuilder::new().with_progress(true).build().unwrap();
                let api = api.repo(Repo::with_revision(
                    "EricB/mistralrs_tests".to_string(),
                    RepoType::Model,
                    "main".to_string(),
                ));

                let filename = api.get("llama_gguf_tokenizer.json").unwrap();
                let tokenizer = Tokenizer::from_file(filename).expect("Valid tokenizer");
                Ok(tokenizer)
            }
            TokenizerType::Gpt2 => {
                let api = ApiBuilder::new().with_progress(true).build().unwrap();
                let api = api.repo(Repo::with_revision(
                    "EricB/mistralrs_tests".to_string(),
                    RepoType::Model,
                    "main".to_string(),
                ));

                let filename = api.get("gpt2_gguf_tokenizer.json").unwrap();
                let tokenizer = Tokenizer::from_file(filename).expect("Valid tokenizer");
                Ok(tokenizer)
            }
            other => anyhow::bail!("Cannot get testing GGUF tokenizer for type {other:?}"),
        }
    }

    fn get_hf_tokenizer(tokenizer: TokenizerType) -> Result<Tokenizer> {
        match tokenizer {
            TokenizerType::Llama => {
                let api = ApiBuilder::new().with_progress(true).build().unwrap();
                let api = api.repo(Repo::with_revision(
                    "EricB/mistralrs_tests".to_string(),
                    RepoType::Model,
                    "main".to_string(),
                ));

                let tokenizer_filename = api.get("tokenizer.json").unwrap();
                Ok(Tokenizer::from_file(tokenizer_filename).unwrap())
            }
            TokenizerType::Gpt2 => {
                let api = ApiBuilder::new().with_progress(true).build().unwrap();
                let api = api.repo(Repo::with_revision(
                    "EricB/mistralrs_tests".to_string(),
                    RepoType::Model,
                    "main".to_string(),
                ));

                let tokenizer_filename = api.get("tokenizer_gpt2.json").unwrap();
                Ok(Tokenizer::from_file(tokenizer_filename).unwrap())
            }
            other => anyhow::bail!("Cannot get testing HF tokenizer for type {other:?}"),
        }
    }

    // Content based upon https://github.com/ggerganov/llama.cpp/blob/master/tests/test-tokenizer-random.py#L99-L161
    fn get_test_passage() -> String {
        let passage = "Hello, world! \n🚀 (normal) 😶‍🌫️ (compound emoji, zwj sequence) ✅ (emoji as single token)\n你好世界!\nNǐ hǎo shìjiè!";

        passage.to_owned()
    }

    // The provided passage should encode and decode back into the same passage string:
    fn codec_roundtrip(
        tokenizer: &Tokenizer,
        passage: &str,
        add_special_tokens: bool,
    ) -> Result<String> {
        let tokenized = tokenizer
            .encode_fast(passage, add_special_tokens)
            .map_err(anyhow::Error::msg)?;

        // NOTE: The special tokens bool param meaning differs between encode() / decode():
        decode(tokenizer, tokenized.get_ids(), !add_special_tokens)
    }

    fn decode(
        tokenizer: &Tokenizer,
        token_ids: &[u32],
        skip_special_tokens: bool,
    ) -> Result<String> {
        tokenizer
            .decode(token_ids, skip_special_tokens)
            .map_err(anyhow::Error::msg)
    }

    #[test]
    fn test_encode_decode_llama() -> Result<()> {
        use rand::rng;
        use rand::seq::SliceRandom;

        let passage = get_test_passage();
        let hf_tokenizer = get_hf_tokenizer(TokenizerType::Llama)?;
        let gguf_tokenizer = get_gguf_tokenizer(TokenizerType::Llama)?;

        // Without adding special tokens
        let hf_decoded = codec_roundtrip(&hf_tokenizer, passage.as_str(), false)?;
        let gguf_decoded = codec_roundtrip(&gguf_tokenizer, passage.as_str(), false)?;
        assert_eq!(hf_decoded, gguf_decoded);
        assert_eq!(passage, gguf_decoded);

        // With special tokens added
        // SKIPPED:
        // - Bugged: the GGUF tokenizer does not prepend `<s> `
        // - Possibly because the HF tokenizer uses BPE (tokenizer.json) while the GGUF tokenizer uses Unigram (metadata)?
        /*
        let hf_decoded = codec_roundtrip(&hf_tokenizer, passage.as_str(), true)?;
        let gguf_decoded = codec_roundtrip(&gguf_tokenizer, passage.as_str(), true)?;
        assert_eq!(hf_decoded, gguf_decoded);
        */

        #[allow(clippy::cast_possible_truncation)]
        let mut tokens = (0..hf_tokenizer.get_vocab_size(false) as u32).collect::<Vec<_>>();
        tokens.shuffle(&mut rng());

        // Without skipping special tokens
        let hf_decoded = decode(&hf_tokenizer, &tokens, false)?;
        let gguf_decoded = decode(&gguf_tokenizer, &tokens, false)?;
        assert_eq!(hf_decoded, gguf_decoded);

        // With skipping special tokens
        let hf_decoded = decode(&hf_tokenizer, &tokens, true)?;
        let gguf_decoded = decode(&gguf_tokenizer, &tokens, true)?;
        assert_eq!(hf_decoded, gguf_decoded);

        Ok(())
    }

    #[test]
    fn test_encode_decode_gpt2() -> Result<()> {
        use rand::rng;
        use rand::seq::SliceRandom;

        let passage = get_test_passage();
        let hf_tokenizer = get_hf_tokenizer(TokenizerType::Gpt2)?;
        let gguf_tokenizer = get_gguf_tokenizer(TokenizerType::Gpt2)?;

        // Without adding special tokens
        let hf_decoded = codec_roundtrip(&hf_tokenizer, passage.as_str(), false)?;
        let gguf_decoded = codec_roundtrip(&gguf_tokenizer, passage.as_str(), false)?;
        assert_eq!(hf_decoded, gguf_decoded);
        assert_eq!(passage, gguf_decoded);

        // With special tokens added
        // SKIPPED:
        // - Skipped for the same reason as the llama test above: the round-trip
        //   with special tokens added does not match between the two tokenizers
        /*
        let hf_decoded = codec_roundtrip(&hf_tokenizer, passage.as_str(), true)?;
        let gguf_decoded = codec_roundtrip(&gguf_tokenizer, passage.as_str(), true)?;
        assert_eq!(hf_decoded, gguf_decoded);
        */

        #[allow(clippy::cast_possible_truncation)]
        let mut tokens = (0..hf_tokenizer.get_vocab_size(false) as u32).collect::<Vec<_>>();
        tokens.shuffle(&mut rng());

        // Without skipping special tokens
        let hf_decoded = decode(&hf_tokenizer, &tokens, false)?;
        let gguf_decoded = decode(&gguf_tokenizer, &tokens, false)?;
        assert_eq!(hf_decoded, gguf_decoded);

        // With skipping special tokens
        let hf_decoded = decode(&hf_tokenizer, &tokens, true)?;
        let gguf_decoded = decode(&gguf_tokenizer, &tokens, true)?;
        assert_eq!(hf_decoded, gguf_decoded);

        Ok(())
    }
}