use std::sync::atomic::Ordering;

use crate::utils::gguf_metadata::ContentMetadata;
use crate::DEBUG;
use ahash::AHashMap;
use anyhow::Result;
use candle_core::quantized::gguf_file::Value;
use itertools::Itertools;
use tokenizers::pre_tokenizers::{
    sequence::Sequence,
    split::{Split, SplitPattern},
    PreTokenizerWrapper,
};
use tokenizers::tokenizer::normalizer::SplitDelimiterBehavior;
use tokenizers::{
    decoders::{
        self, byte_fallback::ByteFallback, byte_level::ByteLevel, fuse::Fuse, strip::Strip,
    },
    models::{bpe::BpeBuilder, unigram::Unigram},
    normalizers::{self, Prepend, Replace},
    processors, AddedToken, DecoderWrapper, ModelWrapper, NormalizerWrapper, Tokenizer,
};
use tracing::info;

use super::Content;

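/// A GGUF tokenizer converted to a Hugging Face [`Tokenizer`], along with the
/// BOS/EOS/UNK token strings (when present in the GGUF metadata).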
pub(crate) struct GgufTokenizerConversion {
    pub tokenizer: Tokenizer,
    pub bos: Option<String>,
    pub eos: Option<String>,
    pub unk: Option<String>,
}

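/// Tokenizer properties read from the `tokenizer.ggml.*` GGUF metadata keys.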
struct PropsGGUF {
    model: String,
    tokens: Vec<String>,
    added_tokens: Option<Vec<String>>,
    scores: Option<Vec<f32>>,
    merges: Option<Vec<String>>,
    unk: Option<u32>,
    bos: Option<u32>,
    eos: u32,
}

impl TryFrom<ContentMetadata<'_>> for PropsGGUF {
    type Error = anyhow::Error;

    fn try_from(c: ContentMetadata) -> Result<Self, Self::Error> {
        let required = ["model", "tokens", "eos_token_id"];
        c.has_required_keys(&required)?;

        let props = Self {
            model: c.get_value("model")?,
            tokens: c.get_value("tokens")?,
            added_tokens: c.get_value("added_tokens").ok(),
            scores: c.get_value("scores").ok(),
            merges: c.get_value("merges").ok(),
            unk: c.get_value("unknown_token_id").ok(),
            eos: c.get_value("eos_token_id")?,
            bos: c.get_value("bos_token_id").ok(),
        };

        Ok(props)
    }
}

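/// Convert the tokenizer embedded in a GGUF file (described by the
/// `tokenizer.ggml.*` metadata keys) into a Hugging Face [`Tokenizer`], along
/// with any BOS/EOS/UNK token strings found in the metadata. Only the
/// `llama`/`replit` (unigram) and `gpt2` (BPE) tokenizer models are supported.
///
/// A minimal usage sketch, assuming `content` has already been loaded from a
/// GGUF reader:
///
/// ```ignore
/// let GgufTokenizerConversion { tokenizer, bos, eos, unk } =
///     convert_gguf_to_hf_tokenizer(&content)?;
/// let encoded = tokenizer
///     .encode_fast("Hello, world!", true)
///     .map_err(anyhow::Error::msg)?;
/// println!("ids: {:?}, eos: {eos:?}", encoded.get_ids());
/// ```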
pub fn convert_gguf_to_hf_tokenizer<R: std::io::Seek + std::io::Read>(
    content: &Content<'_, R>,
) -> Result<GgufTokenizerConversion> {
    let metadata = ContentMetadata {
        path_prefix: "tokenizer.ggml",
        metadata: content.get_metadata(),
    };

    let md_get = |s: &str| match metadata.metadata.get(s) {
        None => candle_core::bail!("cannot find {s} in metadata"),
        Some(v) => Ok(v),
    };

    let mut token_types = Vec::<i32>::new();
    if metadata.metadata.contains_key("tokenizer.ggml.token_type") {
        let vtypes: &Vec<Value> = md_get("tokenizer.ggml.token_type")
            .unwrap()
            .to_vec()
            .unwrap();
        let v: Vec<i32> = vtypes.iter().map(|v| v.to_i32().unwrap()).collect();
        token_types.extend(v);
    }

    let props = PropsGGUF::try_from(metadata)?;

    let (mut tokenizer, kind) = match props.model.as_str() {
        "llama" | "replit" => unigram_tokenizer(&props)?,
        "gpt2" => bpe_tokenizer(&props)?,
        other => {
            anyhow::bail!("Tokenizer model `{other}` not supported.");
        }
    };

    // In the GGUF convention token type 1 is a normal token; every other type
    // (control, user-defined, byte, ...) is registered as a special token.
    let mut num_special_tokens = 0;
    if token_types.len() == props.tokens.len() {
        for (token, token_type) in props.tokens.iter().zip(&token_types) {
            if *token_type != 1 {
                tokenizer.add_special_tokens(&[AddedToken::from(token.clone(), true)]);
                num_special_tokens += 1;
            }
        }
    }

    info!(
        "GGUF tokenizer model is `{model}`, kind: `{kind:?}`, num tokens: {}, num special tokens: {}, num added tokens: {}, num merges: {}, num scores: {}",
        tokenizer.get_vocab_size(true),
        num_special_tokens,
        props.added_tokens.as_ref().map(|x| x.len()).unwrap_or(0),
        props.merges.as_ref().map(|x| x.len()).unwrap_or(0),
        props.scores.as_ref().map(|x| x.len()).unwrap_or(0),
        model = props.model,
    );
    if DEBUG.load(Ordering::Relaxed) {
        info!("Tokenizer: {tokenizer:?}");
    }

    let unk = props.unk.map(|u| props.tokens[u as usize].clone());
    let bos = props.bos.map(|b| props.tokens[b as usize].clone());

    Ok(GgufTokenizerConversion {
        tokenizer,
        bos,
        eos: Some(props.tokens[props.eos as usize].clone()),
        unk,
    })
}

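/// Tokenizer family selected from the GGUF `tokenizer.ggml.model` field.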
#[derive(Debug)]
enum TokenizerKind {
    Unigram,
    Bpe,
}

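/// Build a SentencePiece-style unigram tokenizer (GGUF models `llama` and
/// `replit`): the normalizer rewrites spaces to the word-boundary marker `▁`,
/// and the decoder reverses that mapping, with byte fallback for raw bytes.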
fn unigram_tokenizer(p: &PropsGGUF) -> Result<(Tokenizer, TokenizerKind)> {
    let PropsGGUF { unk, eos, bos, .. } = *p;
    let unk = unk.unwrap_or(0);

    let model = {
        let vocab: Vec<(String, f64)> = {
            let Some(s) = p.scores.as_ref() else {
                anyhow::bail!(
                    "`llama` unigram tokenizer is missing required metadata `tokenizer.ggml.scores`"
                );
            };
            let scores = s.iter().cloned().map(|f_32| f_32 as f64);

            p.tokens.iter().cloned().zip(scores).collect()
        };

        Unigram::from(vocab, Some(unk as usize), true).map_err(anyhow::Error::msg)?
    };

    // Reverse the normalizer below: map `▁` back to a space, restore raw
    // bytes, fuse the pieces, and strip the leading space that `Prepend` adds.
    let decoder = Decoder::Sequence(vec![
        Decoder::Replace("▁", " "),
        Decoder::ByteFallback,
        Decoder::Fuse,
        Decoder::Strip(' ', 1, 0),
    ]);

    let normalizer = Normalizer::Sequence(vec![
        Normalizer::Prepend("▁"),
        Normalizer::Replace(" ", "▁"),
    ]);

    let mut tokenizer: Tokenizer = TokenizerX::new(
        ModelWrapper::Unigram(model),
        Some(decoder),
        Some(normalizer),
    )?;

    // Register BOS/EOS/UNK as special tokens so they survive round-trips.
    for v in [bos, Some(eos), Some(unk)].iter().flatten() {
        let tk = p.tokens[*v as usize].clone();
        tokenizer.add_special_tokens(&[AddedToken::from(tk, true)]);
    }
    Ok((tokenizer, TokenizerKind::Unigram))
}

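/// Build a GPT-2-style byte-level BPE tokenizer: the GPT-2 split regex runs as
/// a pre-tokenizer, followed by byte-level encoding, with matching byte-level
/// decoder and post-processor so decoding restores the original bytes.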
fn bpe_tokenizer(p: &PropsGGUF) -> Result<(Tokenizer, TokenizerKind)> {
    let merges = p
        .merges
        .as_ref()
        .ok_or(anyhow::Error::msg("BPE tokenizer must include merges"))?
        .iter()
        .map(|merge| {
            let split: (&str, &str) = merge
                .splitn(2, ' ')
                .collect_tuple()
                .expect("Failed to convert split into 2-tuple");
            (split.0.to_string(), split.1.to_string())
        })
        .collect::<Vec<_>>();

    let mut vocab = AHashMap::new();
    for (i, token) in p.tokens.iter().enumerate() {
        #[allow(clippy::cast_possible_truncation)]
        vocab.insert(token.clone(), i as u32);
    }

    let PropsGGUF { bos, eos, unk, .. } = *p;

    let mut bpe = BpeBuilder::new().vocab_and_merges(vocab, merges);
    if let Some(unk) = unk {
        bpe = bpe.unk_token(p.tokens[unk as usize].to_string());
    }

    let bpe = bpe.build().map_err(anyhow::Error::msg)?;

    let mut tokenizer = TokenizerX::new(
        ModelWrapper::BPE(bpe),
        Some(Decoder::ByteLevel(true, true, true)),
        None,
    )?;

    // The classic GPT-2 pre-tokenization regex: contractions, letter runs,
    // digits, punctuation runs, and whitespace handling.
    let split = Split::new(
        SplitPattern::Regex("(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+".to_string()),
        SplitDelimiterBehavior::Isolated,
        false,
    )
    .map_err(anyhow::Error::msg)?;

    let pre_tokenizer = Sequence::new(vec![
        PreTokenizerWrapper::Split(split),
        PreTokenizerWrapper::ByteLevel(ByteLevel::new(false, false, false)),
    ]);

    tokenizer.with_pre_tokenizer(Some(pre_tokenizer));

    tokenizer.with_decoder(Some(decoders::byte_level::ByteLevel::new(
        false, false, false,
    )));
    tokenizer.with_post_processor(Some(processors::byte_level::ByteLevel::new(
        false, false, false,
    )));

    for v in [bos, Some(eos), unk].iter().flatten() {
        let tk = p.tokens[*v as usize].clone();
        tokenizer.add_special_tokens(&[AddedToken::from(tk, true)]);
    }

    Ok((tokenizer, TokenizerKind::Bpe))
}

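/// Helper that assembles a [`Tokenizer`] from a model plus optional decoder
/// and normalizer descriptions (the local `Decoder` and `Normalizer` enums
/// below).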
struct TokenizerX;

impl TokenizerX {
    #[allow(clippy::new_ret_no_self)]
    fn new<'a>(
        model: ModelWrapper,
        decoder: Option<Decoder<'a>>,
        normalizer: Option<Normalizer<'a>>,
    ) -> Result<Tokenizer> {
        let mut tokenizer = Tokenizer::new(model);

        // Convert the local descriptions into the `tokenizers` wrapper types.
        if let Some(decoder) = decoder {
            let d = DecoderWrapper::try_from(decoder)?;
            tokenizer.with_decoder(Some(d));
        }
        if let Some(normalizer) = normalizer {
            let n = NormalizerWrapper::try_from(normalizer)?;
            tokenizer.with_normalizer(Some(n));
        }

        Ok(tokenizer)
    }
}

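/// Declarative description of a decoder pipeline; converted to a
/// [`DecoderWrapper`] via `TryFrom`.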
enum Decoder<'a> {
    ByteFallback,
    Fuse,
    Replace(&'a str, &'a str),
    Strip(char, usize, usize),
    Sequence(Vec<Self>),
    ByteLevel(bool, bool, bool),
}

impl TryFrom<Decoder<'_>> for DecoderWrapper {
    type Error = anyhow::Error;

    fn try_from(variant: Decoder) -> Result<Self, Self::Error> {
        let value: DecoderWrapper = match variant {
            Decoder::ByteFallback => ByteFallback::default().into(),
            Decoder::Fuse => Fuse::default().into(),
            Decoder::Replace(pattern, content) => Replace::new(pattern, content)
                .map_err(anyhow::Error::msg)?
                .into(),
            Decoder::Strip(content, start, stop) => Strip::new(content, start, stop).into(),
            Decoder::Sequence(decoders) => {
                let seq = decoders
                    .into_iter()
                    .map(DecoderWrapper::try_from)
                    .collect::<Result<Vec<DecoderWrapper>>>()?;

                decoders::sequence::Sequence::new(seq).into()
            }
            Decoder::ByteLevel(add_prefix_space, trim_offsets, use_regex) => {
                ByteLevel::new(add_prefix_space, trim_offsets, use_regex).into()
            }
        };

        Ok(value)
    }
}

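/// Declarative description of a normalizer pipeline; converted to a
/// [`NormalizerWrapper`] via `TryFrom`.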
enum Normalizer<'a> {
    Prepend(&'a str),
    Replace(&'a str, &'a str),
    Sequence(Vec<Self>),
}

impl TryFrom<Normalizer<'_>> for NormalizerWrapper {
    type Error = anyhow::Error;

    fn try_from(variant: Normalizer) -> Result<Self, Self::Error> {
        let value: NormalizerWrapper = match variant {
            Normalizer::Prepend(prepend) => Prepend::new(prepend.to_owned()).into(),
            Normalizer::Replace(pattern, content) => Replace::new(pattern, content)
                .map_err(anyhow::Error::msg)?
                .into(),
            Normalizer::Sequence(norms) => {
                let seq = norms
                    .into_iter()
                    .map(NormalizerWrapper::try_from)
                    .collect::<Result<Vec<NormalizerWrapper>>>()?;

                normalizers::Sequence::new(seq).into()
            }
        };

        Ok(value)
    }
}

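// The tests below fetch a reference HF `tokenizer.json` and a tokenizer
// serialized from the GGUF conversion (both from the `EricB/mistralrs_tests`
// repo), then check that the two produce identical round-trip decodings.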
#[cfg(test)]
mod tests {
    use anyhow::Result;
    use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
    use tokenizers::Tokenizer;

    #[allow(dead_code)]
    #[derive(Debug)]
    enum TokenizerType {
        Llama,
        Replit,
        Gpt2,
        Rwkv,
    }

    fn get_gguf_tokenizer(tokenizer: TokenizerType) -> Result<Tokenizer> {
        match tokenizer {
            TokenizerType::Llama => {
                let api = ApiBuilder::new().with_progress(true).build().unwrap();
                let api = api.repo(Repo::with_revision(
                    "EricB/mistralrs_tests".to_string(),
                    RepoType::Model,
                    "main".to_string(),
                ));

                let filename = api.get("llama_gguf_tokenizer.json").unwrap();
                let tokenizer = Tokenizer::from_file(filename).expect("Valid tokenizer");
                Ok(tokenizer)
            }
            TokenizerType::Gpt2 => {
                let api = ApiBuilder::new().with_progress(true).build().unwrap();
                let api = api.repo(Repo::with_revision(
                    "EricB/mistralrs_tests".to_string(),
                    RepoType::Model,
                    "main".to_string(),
                ));

                let filename = api.get("gpt2_gguf_tokenizer.json").unwrap();
                let tokenizer = Tokenizer::from_file(filename).expect("Valid tokenizer");
                Ok(tokenizer)
            }
            other => anyhow::bail!("Cannot get testing GGUF tokenizer for type {other:?}"),
        }
    }

    fn get_hf_tokenizer(tokenizer: TokenizerType) -> Result<Tokenizer> {
        match tokenizer {
            TokenizerType::Llama => {
                let api = ApiBuilder::new().with_progress(true).build().unwrap();
                let api = api.repo(Repo::with_revision(
                    "EricB/mistralrs_tests".to_string(),
                    RepoType::Model,
                    "main".to_string(),
                ));

                let tokenizer_filename = api.get("tokenizer.json").unwrap();
                Ok(Tokenizer::from_file(tokenizer_filename).unwrap())
            }
            TokenizerType::Gpt2 => {
                let api = ApiBuilder::new().with_progress(true).build().unwrap();
                let api = api.repo(Repo::with_revision(
                    "EricB/mistralrs_tests".to_string(),
                    RepoType::Model,
                    "main".to_string(),
                ));

                let tokenizer_filename = api.get("tokenizer_gpt2.json").unwrap();
                Ok(Tokenizer::from_file(tokenizer_filename).unwrap())
            }
            other => anyhow::bail!("Cannot get testing HF tokenizer for type {other:?}"),
        }
    }

    fn get_test_passage() -> String {
        let passage = "Hello, world! \n🚀 (normal) 😶‍🌫️ (compound emoji, zwj sequence) ✅ (emoji as single token)\n你好世界!\nNǐ hǎo shìjiè!";

        passage.to_owned()
    }

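    /// Encode `passage` and decode the resulting ids. Special tokens added
    /// during encoding are kept when decoding (and skipped when none were
    /// added), so a lossless tokenizer reproduces the input exactly.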
    fn codec_roundtrip(
        tokenizer: &Tokenizer,
        passage: &str,
        add_special_tokens: bool,
    ) -> Result<String> {
        let tokenized = tokenizer
            .encode_fast(passage, add_special_tokens)
            .map_err(anyhow::Error::msg)?;

        decode(tokenizer, tokenized.get_ids(), !add_special_tokens)
    }

    fn decode(
        tokenizer: &Tokenizer,
        token_ids: &[u32],
        skip_special_tokens: bool,
    ) -> Result<String> {
        tokenizer
            .decode(token_ids, skip_special_tokens)
            .map_err(anyhow::Error::msg)
    }

    #[test]
    fn test_encode_decode_llama() -> Result<()> {
        use rand::rng;
        use rand::seq::SliceRandom;

        let passage = get_test_passage();
        let hf_tokenizer = get_hf_tokenizer(TokenizerType::Llama)?;
        let gguf_tokenizer = get_gguf_tokenizer(TokenizerType::Llama)?;

        // Round-trip the passage through both tokenizers; they should agree
        // with each other and with the original text.
        let hf_decoded = codec_roundtrip(&hf_tokenizer, passage.as_str(), false)?;
        let gguf_decoded = codec_roundtrip(&gguf_tokenizer, passage.as_str(), false)?;
        assert_eq!(hf_decoded, gguf_decoded);
        assert_eq!(passage, gguf_decoded);

        // Decode the full shuffled vocabulary, with and without special
        // tokens, and check both tokenizers agree.
        #[allow(clippy::cast_possible_truncation)]
        let mut tokens = (0..hf_tokenizer.get_vocab_size(false) as u32).collect::<Vec<_>>();
        tokens.shuffle(&mut rng());

        let hf_decoded = decode(&hf_tokenizer, &tokens, false)?;
        let gguf_decoded = decode(&gguf_tokenizer, &tokens, false)?;
        assert_eq!(hf_decoded, gguf_decoded);

        let hf_decoded = decode(&hf_tokenizer, &tokens, true)?;
        let gguf_decoded = decode(&gguf_tokenizer, &tokens, true)?;
        assert_eq!(hf_decoded, gguf_decoded);

        Ok(())
    }

    #[test]
    fn test_encode_decode_gpt2() -> Result<()> {
        use rand::rng;
        use rand::seq::SliceRandom;

        let passage = get_test_passage();
        let hf_tokenizer = get_hf_tokenizer(TokenizerType::Gpt2)?;
        let gguf_tokenizer = get_gguf_tokenizer(TokenizerType::Gpt2)?;

        // Round-trip the passage through both tokenizers; they should agree
        // with each other and with the original text.
        let hf_decoded = codec_roundtrip(&hf_tokenizer, passage.as_str(), false)?;
        let gguf_decoded = codec_roundtrip(&gguf_tokenizer, passage.as_str(), false)?;
        assert_eq!(hf_decoded, gguf_decoded);
        assert_eq!(passage, gguf_decoded);

        // Decode the full shuffled vocabulary, with and without special
        // tokens, and check both tokenizers agree.
        #[allow(clippy::cast_possible_truncation)]
        let mut tokens = (0..hf_tokenizer.get_vocab_size(false) as u32).collect::<Vec<_>>();
        tokens.shuffle(&mut rng());

        let hf_decoded = decode(&hf_tokenizer, &tokens, false)?;
        let gguf_decoded = decode(&gguf_tokenizer, &tokens, false)?;
        assert_eq!(hf_decoded, gguf_decoded);

        let hf_decoded = decode(&hf_tokenizer, &tokens, true)?;
        let gguf_decoded = decode(&gguf_tokenizer, &tokens, true)?;
        assert_eq!(hf_decoded, gguf_decoded);

        Ok(())
    }
}