1use std::{collections::HashMap, sync::atomic::Ordering};
4
5use anyhow::Result;
6use itertools::Itertools;
7use tokenizers::{
8 decoders::{
9 self, byte_fallback::ByteFallback, byte_level::ByteLevel, fuse::Fuse, strip::Strip,
10 },
11 models::{bpe::BpeBuilder, unigram::Unigram},
12 normalizers::{self, Prepend, Replace},
13 pre_tokenizers,
14 processors::{
15 self,
16 template::{self, TemplateProcessing},
17 },
18 AddedToken, DecoderWrapper, ModelWrapper, NormalizerWrapper, Tokenizer,
19};
20use tracing::info;
21
22use crate::utils::gguf_metadata::ContentMetadata;
23use crate::DEBUG;
24
25use super::Content;
26
/// Result of converting a GGUF-embedded tokenizer into a Hugging Face
/// [`Tokenizer`], together with the special-token strings discovered
/// during conversion.
pub(crate) struct GgufTokenizerConversion {
    // The converted, ready-to-use tokenizer.
    pub tokenizer: Tokenizer,
    // Beginning-of-sequence token string, if one was found.
    pub bos: Option<String>,
    // End-of-sequence token string, if one was found.
    pub eos: Option<String>,
    // Unknown-token string, if one was found.
    pub unk: Option<String>,
}
33
/// Tokenizer properties pulled from GGUF metadata (keys under the
/// `tokenizer.ggml.` prefix; see the `TryFrom<ContentMetadata>` impl below).
struct PropsGGUF {
    // `model`: tokenizer family, e.g. "llama", "replit", or "gpt2".
    model: String,
    // `tokens`: the vocabulary, indexed by token id.
    tokens: Vec<String>,
    // `added_tokens`: extra tokens, if present (only counted for logging here).
    added_tokens: Option<Vec<String>>,
    // `scores`: per-token scores; required by the unigram model.
    scores: Option<Vec<f32>>,
    // `merges`: BPE merge rules; required by the BPE model.
    merges: Option<Vec<String>>,
    // `unknown_token_id`: id of the unknown token, if present.
    unk: Option<u32>,
    // `eos_token_id`: id of the end-of-sequence token.
    eos: u32,
    // `bos_token_id`: id of the beginning-of-sequence token.
    bos: u32,
    // `add_bos_token`: whether encoding should prepend BOS, if specified.
    add_bos_token: Option<bool>,
}
45
impl TryFrom<ContentMetadata<'_>> for PropsGGUF {
    type Error = anyhow::Error;

    /// Extracts tokenizer properties from GGUF metadata.
    ///
    /// Fails if any of the required keys (`model`, `tokens`, `eos_token_id`,
    /// `bos_token_id`) is absent; all other keys are optional.
    fn try_from(c: ContentMetadata) -> Result<Self, Self::Error> {
        // Validate presence of the mandatory keys up front so the error
        // names the missing key rather than failing mid-extraction.
        let required = ["model", "tokens", "eos_token_id", "bos_token_id"];
        c.has_required_keys(&required)?;

        let props = Self {
            model: c.get_value("model")?,
            tokens: c.get_value("tokens")?,
            // Optional keys: `.ok()` maps "absent or wrong type" to `None`.
            added_tokens: c.get_value("added_tokens").ok(),
            scores: c.get_value("scores").ok(),
            merges: c.get_value("merges").ok(),
            unk: c.get_value("unknown_token_id").ok(),
            eos: c.get_value("eos_token_id")?,
            bos: c.get_value("bos_token_id")?,
            add_bos_token: c.get_value("add_bos_token").ok(),
        };

        Ok(props)
    }
}
68
/// String forms of the special tokens registered on a tokenizer by
/// [`add_special_tokens`]. BOS and EOS are always present; UNK is optional.
struct AddedTokensCollection {
    bos: String,
    eos: String,
    unk: Option<String>,
}
74
75pub fn convert_gguf_to_hf_tokenizer<R: std::io::Seek + std::io::Read>(
76 content: &Content<'_, R>,
77) -> Result<GgufTokenizerConversion> {
78 let metadata = ContentMetadata {
79 path_prefix: "tokenizer.ggml",
80 metadata: content.get_metadata(),
81 };
82 let props = PropsGGUF::try_from(metadata)?;
83
84 let (tokenizer, kind, special_tokens) = match props.model.as_str() {
85 "llama" | "replit" => unigram_tokenizer(&props)?,
86 "gpt2" => bpe_tokenizer(&props)?,
87 other => {
88 anyhow::bail!("Tokenizer model `{other}` not supported.");
89 }
90 };
91
92 info!(
93 "GGUF tokenizer model is `{model}`, kind: `{kind:?}`, num tokens: {}, num added tokens: {}, num merges: {}, num scores: {}",
94 tokenizer.get_vocab_size(true),
95 props.added_tokens.as_ref().map(|x| x.len()).unwrap_or(0),
96 props.merges.as_ref().map(|x| x.len()).unwrap_or(0),
97 props.scores.as_ref().map(|x| x.len()).unwrap_or(0),
98 model = props.model,
99 );
100 if DEBUG.load(Ordering::Relaxed) {
101 info!("Tokenizer: {tokenizer:?}");
102 }
103
104 let AddedTokensCollection { bos, eos, unk } = special_tokens;
105
106 Ok(GgufTokenizerConversion {
107 tokenizer,
108 bos: Some(bos),
109 eos: Some(eos),
110 unk,
111 })
112}
113
/// Which tokenizer algorithm the GGUF metadata was converted with;
/// currently only used for logging in `convert_gguf_to_hf_tokenizer`.
#[derive(Debug)]
enum TokenizerKind {
    Unigram,
    Bpe,
}
121
122fn add_special_tokens(
124 p: &PropsGGUF,
125 tokenizer: &mut Tokenizer,
126 bos: u32,
127 eos: u32,
128 unk: Option<u32>,
129) -> AddedTokensCollection {
130 let mut special_tokens: [Option<String>; 3] = Default::default();
132
133 for (i, token_id) in [Some(bos), Some(eos), unk].into_iter().enumerate() {
135 if let Some(token_id) = token_id {
136 let token = p.tokens[token_id as usize].as_str();
137 special_tokens[i] = Some(token.to_string());
138 tokenizer.add_special_tokens(&[AddedToken::from(token.to_string(), true)]);
139 }
140 }
141
142 let [bos_str, eos_str, unk_str] = special_tokens;
144 AddedTokensCollection {
146 bos: bos_str.unwrap(),
147 eos: eos_str.unwrap(),
148 unk: unk_str,
149 }
150}
151
152fn unigram_tokenizer(p: &PropsGGUF) -> Result<(Tokenizer, TokenizerKind, AddedTokensCollection)> {
153 let PropsGGUF { unk, eos, bos, .. } = *p;
154 let unk = unk.unwrap_or(0);
156
157 let model = {
159 let vocab: Vec<(String, f64)> = {
160 let Some(s) = p.scores.as_ref() else {
161 anyhow::bail!(
162 "`llama` unigram tokenizer is missing required metadata `tokenizer.ggml.scores`"
163 );
164 };
165 let scores = s.iter().cloned().map(|f_32| f_32 as f64);
166
167 p.tokens.iter().cloned().zip(scores).collect()
168 };
169
170 Unigram::from(vocab, Some(unk as usize), true).map_err(anyhow::Error::msg)?
171 };
172
173 let decoder = Decoder::Sequence(vec![
176 Decoder::Replace("▁", " "),
177 Decoder::ByteFallback,
178 Decoder::Fuse,
179 Decoder::Strip(' ', 1, 0),
180 ]);
181
182 let normalizer = Normalizer::Sequence(vec![
183 Normalizer::Prepend("▁"),
184 Normalizer::Replace(" ", "▁"),
185 ]);
186
187 let mut tokenizer: Tokenizer = TokenizerX::new(
188 ModelWrapper::Unigram(model),
189 Some(decoder),
190 Some(normalizer),
191 )?;
192
193 let special_tokens = add_special_tokens(p, &mut tokenizer, bos, eos, Some(unk));
195
196 Ok((tokenizer, TokenizerKind::Unigram, special_tokens))
197}
198
199fn bpe_tokenizer(p: &PropsGGUF) -> Result<(Tokenizer, TokenizerKind, AddedTokensCollection)> {
200 let merges = p
203 .merges
204 .as_ref()
205 .ok_or(anyhow::Error::msg("BPE tokenizer must include merges"))?
206 .iter()
207 .map(|merge| {
208 let split: (&str, &str) = merge
209 .splitn(2, ' ')
210 .collect_tuple()
211 .expect("Failed to convert split into 2-tuple");
212 (split.0.to_string(), split.1.to_string())
213 })
214 .collect::<Vec<_>>();
215
216 let mut vocab = HashMap::new();
217 for (i, token) in p.tokens.iter().enumerate() {
218 #[allow(clippy::cast_possible_truncation)]
219 vocab.insert(token.clone(), i as u32);
220 }
221
222 let PropsGGUF {
223 eos,
224 bos,
225 unk,
226 add_bos_token,
227 ..
228 } = *p;
229
230 let mut bpe = BpeBuilder::new().vocab_and_merges(vocab, merges);
231 if let Some(unk) = unk {
232 bpe = bpe.unk_token(p.tokens[unk as usize].to_string());
233 };
234
235 let bpe = bpe.build().map_err(anyhow::Error::msg)?;
236
237 let mut tokenizer = TokenizerX::new(
238 ModelWrapper::BPE(bpe),
239 Some(Decoder::ByteLevel(true, true, true)),
240 None,
241 )?;
242 tokenizer.with_pre_tokenizer(Some(pre_tokenizers::byte_level::ByteLevel::new(
243 false, true, true,
244 )));
245 if add_bos_token.is_some_and(|x| x) {
246 let mut special_toks = HashMap::new();
247 special_toks.insert(
248 p.tokens[bos as usize].clone(),
249 template::SpecialToken::new(
250 p.tokens[bos as usize].clone(),
251 vec![bos],
252 vec![p.tokens[bos as usize].clone()],
253 )
254 .unwrap(),
255 );
256 tokenizer.with_post_processor(Some(
257 TemplateProcessing::builder()
258 .try_single(format!("{}:0 $A:0", p.tokens[bos as usize]))
259 .unwrap()
260 .try_pair(format!("{}:0 $A:0 $B:1", p.tokens[bos as usize]))
261 .unwrap()
262 .special_tokens(special_toks)
263 .build()
264 .unwrap(),
265 ));
266 } else {
267 tokenizer.with_post_processor(Some(processors::byte_level::ByteLevel::new(
268 true, false, true,
269 )));
270 }
271
272 let special_tokens = add_special_tokens(p, &mut tokenizer, bos, eos, unk);
273
274 Ok((tokenizer, TokenizerKind::Bpe, special_tokens))
275}
276
277struct TokenizerX;
281
282impl TokenizerX {
283 #[allow(clippy::new_ret_no_self)]
284 fn new<'a>(
285 model: ModelWrapper,
286 decoder: Option<Decoder<'a>>,
287 normalizer: Option<Normalizer<'a>>,
288 ) -> Result<Tokenizer> {
289 let mut tokenizer = Tokenizer::new(model);
290
291 if let Some(decoder) = decoder {
293 let d = DecoderWrapper::try_from(decoder)?;
294 tokenizer.with_decoder(Some(d));
295 }
296 if let Some(normalizer) = normalizer {
297 let n: NormalizerWrapper = NormalizerWrapper::try_from(normalizer)?;
298 tokenizer.with_normalizer(Some(n));
299 }
300
301 Ok(tokenizer)
302 }
303}
304
/// Lightweight, borrowed description of a decoder component; converted into
/// the `tokenizers` crate's [`DecoderWrapper`] via the `TryFrom` impl below.
enum Decoder<'a> {
    // Maps to `decoders::byte_fallback::ByteFallback`.
    ByteFallback,
    // Maps to `decoders::fuse::Fuse`.
    Fuse,
    // Maps to `Replace::new(pattern, content)`.
    Replace(&'a str, &'a str),
    // Maps to `Strip::new(content, start, stop)`.
    Strip(char, usize, usize),
    // A sequence of decoders, applied in order.
    Sequence(Vec<Self>),
    // Maps to `ByteLevel::new(add_prefix_space, trim_offsets, use_regex)`.
    ByteLevel(bool, bool, bool),
}
315
316impl TryFrom<Decoder<'_>> for DecoderWrapper {
318 type Error = anyhow::Error;
319
320 fn try_from(variant: Decoder) -> Result<Self, Self::Error> {
321 let value: DecoderWrapper = match variant {
322 Decoder::ByteFallback => ByteFallback::default().into(),
323 Decoder::Fuse => Fuse::default().into(),
324 Decoder::Replace(pattern, content) => Replace::new(pattern, content)
325 .map_err(anyhow::Error::msg)?
326 .into(),
327 Decoder::Strip(content, start, stop) => Strip::new(content, start, stop).into(),
328 Decoder::Sequence(decoders) => {
329 let seq = decoders
330 .into_iter()
331 .map(DecoderWrapper::try_from)
332 .collect::<Result<Vec<DecoderWrapper>>>()?;
333
334 decoders::sequence::Sequence::new(seq).into()
335 }
336 Decoder::ByteLevel(add_prefix_space, trim_offsets, use_regex) => {
337 ByteLevel::new(add_prefix_space, trim_offsets, use_regex).into()
338 }
339 };
340
341 Ok(value)
342 }
343}
344
/// Lightweight, borrowed description of a normalizer component; converted
/// into the `tokenizers` crate's [`NormalizerWrapper`] via `TryFrom` below.
enum Normalizer<'a> {
    // Maps to `Prepend::new(prefix)`.
    Prepend(&'a str),
    // Maps to `Replace::new(pattern, content)`.
    Replace(&'a str, &'a str),
    // A sequence of normalizers, applied in order.
    Sequence(Vec<Self>),
}
352
353impl TryFrom<Normalizer<'_>> for NormalizerWrapper {
354 type Error = anyhow::Error;
355
356 fn try_from(variant: Normalizer) -> Result<Self, Self::Error> {
357 let value: NormalizerWrapper = match variant {
358 Normalizer::Prepend(prepend) => Prepend::new(prepend.to_owned()).into(),
359 Normalizer::Replace(pattern, content) => Replace::new(pattern, content)
360 .map_err(anyhow::Error::msg)?
361 .into(),
362 Normalizer::Sequence(decoders) => {
363 let seq = decoders
364 .into_iter()
365 .map(NormalizerWrapper::try_from)
366 .collect::<Result<Vec<NormalizerWrapper>>>()?;
367
368 normalizers::Sequence::new(seq).into()
369 }
370 };
371
372 Ok(value)
373 }
374}
375
#[cfg(test)]
mod tests {
    use anyhow::Result;
    use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
    use tokenizers::Tokenizer;

    #[allow(dead_code)]
    #[derive(Debug, Clone, Copy)]
    enum TokenizerType {
        Llama,
        Replit,
        Gpt2,
        Rwkv,
    }

    /// Downloads `filename` from the shared `EricB/mistralrs_tests` repo and
    /// loads it as a tokenizer. Shared by both fixture getters below, which
    /// previously duplicated this boilerplate four times.
    fn load_tokenizer_from_test_repo(filename: &str) -> Result<Tokenizer> {
        let api = ApiBuilder::new().with_progress(true).build()?;
        let api = api.repo(Repo::with_revision(
            "EricB/mistralrs_tests".to_string(),
            RepoType::Model,
            "main".to_string(),
        ));

        let path = api.get(filename)?;
        Tokenizer::from_file(path).map_err(anyhow::Error::msg)
    }

    /// Fetches the GGUF-converted tokenizer fixture for `tokenizer`.
    fn get_gguf_tokenizer(tokenizer: TokenizerType) -> Result<Tokenizer> {
        match tokenizer {
            TokenizerType::Llama => load_tokenizer_from_test_repo("llama_gguf_tokenizer.json"),
            TokenizerType::Gpt2 => load_tokenizer_from_test_repo("gpt2_gguf_tokenizer.json"),
            // Fixed: this message previously said "HF tokenizer" (copy-paste).
            other => anyhow::bail!("Cannot get testing GGUF tokenizer for type {other:?}"),
        }
    }

    /// Fetches the reference Hugging Face tokenizer for `tokenizer`.
    fn get_hf_tokenizer(tokenizer: TokenizerType) -> Result<Tokenizer> {
        match tokenizer {
            TokenizerType::Llama => load_tokenizer_from_test_repo("tokenizer.json"),
            TokenizerType::Gpt2 => load_tokenizer_from_test_repo("tokenizer_gpt2.json"),
            other => anyhow::bail!("Cannot get testing HF tokenizer for type {other:?}"),
        }
    }

    /// A passage mixing ASCII, emoji (simple and ZWJ-composed), and CJK to
    /// exercise multi-byte and multi-token paths.
    fn get_test_passage() -> String {
        let passage = "Hello, world! \n🚀 (normal) 😶‍🌫️ (compound emoji, zwj sequence) ✅ (emoji as single token)\n你好世界!\nNǐ hǎo shìjiè!";

        passage.to_owned()
    }

    /// Encodes `passage` and immediately decodes the resulting ids,
    /// returning the decoded string.
    fn codec_roundtrip(
        tokenizer: &Tokenizer,
        passage: &str,
        add_special_tokens: bool,
    ) -> Result<String> {
        let tokenized = tokenizer
            .encode_fast(passage, add_special_tokens)
            .map_err(anyhow::Error::msg)?;

        decode(tokenizer, tokenized.get_ids(), !add_special_tokens)
    }

    /// Decodes `token_ids`, optionally skipping special tokens.
    fn decode(
        tokenizer: &Tokenizer,
        token_ids: &[u32],
        skip_special_tokens: bool,
    ) -> Result<String> {
        tokenizer
            .decode(token_ids, skip_special_tokens)
            .map_err(anyhow::Error::msg)
    }

    /// Shared body of the per-model parity tests: the HF and GGUF tokenizers
    /// must round-trip the passage identically, and must decode a shuffled
    /// full-vocab id sequence identically (with and without skipping
    /// special tokens).
    fn run_encode_decode_parity(tokenizer_type: TokenizerType) -> Result<()> {
        use rand::rng;
        use rand::seq::SliceRandom;

        let passage = get_test_passage();
        let hf_tokenizer = get_hf_tokenizer(tokenizer_type)?;
        let gguf_tokenizer = get_gguf_tokenizer(tokenizer_type)?;

        let hf_decoded = codec_roundtrip(&hf_tokenizer, passage.as_str(), false)?;
        let gguf_decoded = codec_roundtrip(&gguf_tokenizer, passage.as_str(), false)?;
        assert_eq!(hf_decoded, gguf_decoded);
        assert_eq!(passage, gguf_decoded);

        #[allow(clippy::cast_possible_truncation)]
        let mut tokens = (0..hf_tokenizer.get_vocab_size(false) as u32).collect::<Vec<_>>();
        tokens.shuffle(&mut rng());

        let hf_decoded = decode(&hf_tokenizer, &tokens, false)?;
        let gguf_decoded = decode(&gguf_tokenizer, &tokens, false)?;
        assert_eq!(hf_decoded, gguf_decoded);

        let hf_decoded = decode(&hf_tokenizer, &tokens, true)?;
        let gguf_decoded = decode(&gguf_tokenizer, &tokens, true)?;
        assert_eq!(hf_decoded, gguf_decoded);

        Ok(())
    }

    #[test]
    fn test_encode_decode_llama() -> Result<()> {
        run_encode_decode_parity(TokenizerType::Llama)
    }

    #[test]
    fn test_encode_decode_gpt2() -> Result<()> {
        run_encode_decode_parity(TokenizerType::Gpt2)
    }
}