use std::{collections::HashMap, sync::atomic::Ordering};

use crate::utils::gguf_metadata::ContentMetadata;
use crate::DEBUG;
use anyhow::Result;
use candle_core::quantized::gguf_file::Value;
use itertools::Itertools;
use tokenizers::pre_tokenizers::{
    sequence::Sequence,
    split::{Split, SplitPattern},
    PreTokenizerWrapper,
};
use tokenizers::tokenizer::normalizer::SplitDelimiterBehavior;
use tokenizers::{
    decoders::{
        self, byte_fallback::ByteFallback, byte_level::ByteLevel, fuse::Fuse, strip::Strip,
    },
    models::{bpe::BpeBuilder, unigram::Unigram},
    normalizers::{self, Prepend, Replace},
    processors, AddedToken, DecoderWrapper, ModelWrapper, NormalizerWrapper, Tokenizer,
};
use tracing::info;

use super::Content;

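/// Output of a GGUF → Hugging Face tokenizer conversion: the rebuilt
/// [`Tokenizer`] plus the BOS/EOS/UNK token strings, when present.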
pub(crate) struct GgufTokenizerConversion {
    pub tokenizer: Tokenizer,
    pub bos: Option<String>,
    pub eos: Option<String>,
    pub unk: Option<String>,
}

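/// The `tokenizer.ggml.*` metadata fields read from a GGUF file.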
struct PropsGGUF {
    model: String,
    tokens: Vec<String>,
    added_tokens: Option<Vec<String>>,
    scores: Option<Vec<f32>>,
    merges: Option<Vec<String>>,
    unk: Option<u32>,
    eos: u32,
    bos: u32,
}

impl TryFrom<ContentMetadata<'_>> for PropsGGUF {
    type Error = anyhow::Error;

    fn try_from(c: ContentMetadata) -> Result<Self, Self::Error> {
        let required = ["model", "tokens", "eos_token_id", "bos_token_id"];
        c.has_required_keys(&required)?;

        let props = Self {
            model: c.get_value("model")?,
            tokens: c.get_value("tokens")?,
            added_tokens: c.get_value("added_tokens").ok(),
            scores: c.get_value("scores").ok(),
            merges: c.get_value("merges").ok(),
            unk: c.get_value("unknown_token_id").ok(),
            eos: c.get_value("eos_token_id")?,
            bos: c.get_value("bos_token_id")?,
        };

        Ok(props)
    }
}

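/// Convert the tokenizer embedded in a GGUF file (`tokenizer.ggml.*`
/// metadata) into a Hugging Face [`Tokenizer`]. Supports `llama`/`replit`
/// (SentencePiece-style unigram) and `gpt2` (byte-level BPE) models.
///
/// A minimal usage sketch; the construction of `Content` is elided here
/// (its API lives in `super`) and the file name is hypothetical:
///
/// ```ignore
/// let mut file = std::fs::File::open("model.gguf")?;
/// // ... build `content: Content` from the GGUF reader(s) ...
/// let conversion = convert_gguf_to_hf_tokenizer(&content)?;
/// println!("bos: {:?}, eos: {:?}", conversion.bos, conversion.eos);
/// let encoding = conversion.tokenizer.encode_fast("Hello, world!", true);
/// ```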
pub fn convert_gguf_to_hf_tokenizer<R: std::io::Seek + std::io::Read>(
    content: &Content<'_, R>,
) -> Result<GgufTokenizerConversion> {
    let metadata = ContentMetadata {
        path_prefix: "tokenizer.ggml",
        metadata: content.get_metadata(),
    };

    let md_get = |s: &str| match metadata.metadata.get(s) {
        None => candle_core::bail!("cannot find {s} in metadata"),
        Some(v) => Ok(v),
    };

    // `token_type` is optional metadata; collect it when present so that
    // non-normal tokens can be registered as special tokens below.
    let mut token_types = Vec::<i32>::new();
    if metadata.metadata.contains_key("tokenizer.ggml.token_type") {
        let vtypes: &Vec<Value> = md_get("tokenizer.ggml.token_type")
            .unwrap()
            .to_vec()
            .unwrap();
        let v: Vec<i32> = vtypes.iter().map(|v| v.to_i32().unwrap()).collect();
        token_types.extend(v);
    }

    let props = PropsGGUF::try_from(metadata)?;

    let (mut tokenizer, kind) = match props.model.as_str() {
        "llama" | "replit" => unigram_tokenizer(&props)?,
        "gpt2" => bpe_tokenizer(&props)?,
        other => {
            anyhow::bail!("Tokenizer model `{other}` not supported.");
        }
    };

    // GGUF token type 1 is "normal"; anything else (unknown, control,
    // user-defined, unused, byte) is treated as a special token.
    let mut num_special_tokens = 0;
    #[allow(clippy::needless_range_loop)]
    if token_types.len() == props.tokens.len() {
        for i in 0..props.tokens.len() {
            if token_types[i] != 1i32 {
                tokenizer.add_special_tokens(&[AddedToken::from(props.tokens[i].clone(), true)]);
                num_special_tokens += 1;
            }
        }
    }

    info!(
        "GGUF tokenizer model is `{model}`, kind: `{kind:?}`, num tokens: {}, num special tokens: {}, num added tokens: {}, num merges: {}, num scores: {}",
        tokenizer.get_vocab_size(true),
        num_special_tokens,
        props.added_tokens.as_ref().map(|x| x.len()).unwrap_or(0),
        props.merges.as_ref().map(|x| x.len()).unwrap_or(0),
        props.scores.as_ref().map(|x| x.len()).unwrap_or(0),
        model = props.model,
    );
    if DEBUG.load(Ordering::Relaxed) {
        info!("Tokenizer: {tokenizer:?}");
    }

    let unk = props.unk.map(|u| props.tokens[u as usize].clone());

    Ok(GgufTokenizerConversion {
        tokenizer,
        bos: Some(props.tokens[props.bos as usize].clone()),
        eos: Some(props.tokens[props.eos as usize].clone()),
        unk,
    })
}

#[derive(Debug)]
enum TokenizerKind {
    Unigram,
    Bpe,
}

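/// Build a SentencePiece-style unigram tokenizer (used by `llama` and
/// `replit` GGUF models) from `(token, score)` pairs, wiring up the `▁`
/// word-boundary handling in both the normalizer and the decoder.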
fn unigram_tokenizer(p: &PropsGGUF) -> Result<(Tokenizer, TokenizerKind)> {
    let PropsGGUF { unk, eos, bos, .. } = *p;
    let unk = unk.unwrap_or(0);

    let model = {
        let vocab: Vec<(String, f64)> = {
            let Some(s) = p.scores.as_ref() else {
                anyhow::bail!(
                    "`llama` unigram tokenizer is missing required metadata `tokenizer.ggml.scores`"
                );
            };
            let scores = s.iter().cloned().map(|f_32| f_32 as f64);

            p.tokens.iter().cloned().zip(scores).collect()
        };

        Unigram::from(vocab, Some(unk as usize), true).map_err(anyhow::Error::msg)?
    };

    // SentencePiece-style decoding: turn `▁` back into spaces, fall back to
    // raw bytes for byte tokens, fuse the pieces, then strip the leading space.
    let decoder = Decoder::Sequence(vec![
        Decoder::Replace("▁", " "),
        Decoder::ByteFallback,
        Decoder::Fuse,
        Decoder::Strip(' ', 1, 0),
    ]);

    // Mirror image of the decoder on the input side.
    let normalizer = Normalizer::Sequence(vec![
        Normalizer::Prepend("▁"),
        Normalizer::Replace(" ", "▁"),
    ]);

    let mut tokenizer: Tokenizer = TokenizerX::new(
        ModelWrapper::Unigram(model),
        Some(decoder),
        Some(normalizer),
    )?;

    for i in [bos, eos, unk] {
        tokenizer.add_special_tokens(&[AddedToken::from(p.tokens[i as usize].clone(), true)]);
    }
    Ok((tokenizer, TokenizerKind::Unigram))
}

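/// Build a byte-level BPE tokenizer (used by `gpt2` GGUF models) from the
/// vocab and merge list, with GPT-2-style pre-tokenization and byte-level
/// decoding/post-processing.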
fn bpe_tokenizer(p: &PropsGGUF) -> Result<(Tokenizer, TokenizerKind)> {
    let merges = p
        .merges
        .as_ref()
        .ok_or(anyhow::Error::msg("BPE tokenizer must include merges"))?
        .iter()
        .map(|merge| {
            let split: (&str, &str) = merge
                .splitn(2, ' ')
                .collect_tuple()
                .expect("Failed to convert split into 2-tuple");
            (split.0.to_string(), split.1.to_string())
        })
        .collect::<Vec<_>>();

    let mut vocab = HashMap::new();
    for (i, token) in p.tokens.iter().enumerate() {
        #[allow(clippy::cast_possible_truncation)]
        vocab.insert(token.clone(), i as u32);
    }

    let PropsGGUF { eos, bos, unk, .. } = *p;

    let mut bpe = BpeBuilder::new().vocab_and_merges(vocab, merges);
    if let Some(unk) = unk {
        bpe = bpe.unk_token(p.tokens[unk as usize].to_string());
    }

    let bpe = bpe.build().map_err(anyhow::Error::msg)?;

    let mut tokenizer = TokenizerX::new(
        ModelWrapper::BPE(bpe),
        Some(Decoder::ByteLevel(true, true, true)),
        None,
    )?;

    // GPT-2-style pre-tokenization regex: contractions, letter runs, digits,
    // punctuation runs, and whitespace handling.
    let split = Split::new(
        SplitPattern::Regex("(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+".to_string()),
        SplitDelimiterBehavior::Isolated,
        false,
    )
    .map_err(anyhow::Error::msg)?;

    let pre_tokenizer = Sequence::new(vec![
        PreTokenizerWrapper::Split(split),
        PreTokenizerWrapper::ByteLevel(ByteLevel::new(false, false, false)),
    ]);

    tokenizer.with_pre_tokenizer(Some(pre_tokenizer));

    tokenizer.with_decoder(Some(decoders::byte_level::ByteLevel::new(
        false, false, false,
    )));
    tokenizer.with_post_processor(Some(processors::byte_level::ByteLevel::new(
        false, false, false,
    )));

    for i in [bos, eos] {
        tokenizer.add_special_tokens(&[AddedToken::from(p.tokens[i as usize].clone(), true)]);
    }
    if let Some(unk) = unk {
        tokenizer.add_special_tokens(&[AddedToken::from(p.tokens[unk as usize].clone(), true)]);
    }

    Ok((tokenizer, TokenizerKind::Bpe))
}

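/// Helper for assembling a [`Tokenizer`] from a model plus optional decoder
/// and normalizer specs.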
struct TokenizerX;

impl TokenizerX {
    #[allow(clippy::new_ret_no_self)]
    fn new<'a>(
        model: ModelWrapper,
        decoder: Option<Decoder<'a>>,
        normalizer: Option<Normalizer<'a>>,
    ) -> Result<Tokenizer> {
        let mut tokenizer = Tokenizer::new(model);

        if let Some(decoder) = decoder {
            let d = DecoderWrapper::try_from(decoder)?;
            tokenizer.with_decoder(Some(d));
        }
        if let Some(normalizer) = normalizer {
            let n: NormalizerWrapper = NormalizerWrapper::try_from(normalizer)?;
            tokenizer.with_normalizer(Some(n));
        }

        Ok(tokenizer)
    }
}

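/// Lightweight, borrow-friendly decoder specs, converted into the
/// `tokenizers` crate's [`DecoderWrapper`] via `TryFrom`. [`Normalizer`]
/// below plays the same role for normalizers.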
enum Decoder<'a> {
    ByteFallback,
    Fuse,
    Replace(&'a str, &'a str),
    Strip(char, usize, usize),
    Sequence(Vec<Self>),
    ByteLevel(bool, bool, bool),
}

impl TryFrom<Decoder<'_>> for DecoderWrapper {
    type Error = anyhow::Error;

    fn try_from(variant: Decoder) -> Result<Self, Self::Error> {
        let value: DecoderWrapper = match variant {
            Decoder::ByteFallback => ByteFallback::default().into(),
            Decoder::Fuse => Fuse::default().into(),
            Decoder::Replace(pattern, content) => Replace::new(pattern, content)
                .map_err(anyhow::Error::msg)?
                .into(),
            Decoder::Strip(content, start, stop) => Strip::new(content, start, stop).into(),
            Decoder::Sequence(decoders) => {
                let seq = decoders
                    .into_iter()
                    .map(DecoderWrapper::try_from)
                    .collect::<Result<Vec<DecoderWrapper>>>()?;

                decoders::sequence::Sequence::new(seq).into()
            }
            Decoder::ByteLevel(add_prefix_space, trim_offsets, use_regex) => {
                ByteLevel::new(add_prefix_space, trim_offsets, use_regex).into()
            }
        };

        Ok(value)
    }
}

enum Normalizer<'a> {
    Prepend(&'a str),
    Replace(&'a str, &'a str),
    Sequence(Vec<Self>),
}

impl TryFrom<Normalizer<'_>> for NormalizerWrapper {
    type Error = anyhow::Error;

    fn try_from(variant: Normalizer) -> Result<Self, Self::Error> {
        let value: NormalizerWrapper = match variant {
            Normalizer::Prepend(prepend) => Prepend::new(prepend.to_owned()).into(),
            Normalizer::Replace(pattern, content) => Replace::new(pattern, content)
                .map_err(anyhow::Error::msg)?
                .into(),
            Normalizer::Sequence(norms) => {
                let seq = norms
                    .into_iter()
                    .map(NormalizerWrapper::try_from)
                    .collect::<Result<Vec<NormalizerWrapper>>>()?;

                normalizers::Sequence::new(seq).into()
            }
        };

        Ok(value)
    }
}

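// The tests below fetch reference tokenizers from the Hugging Face Hub and
// check that a GGUF-converted tokenizer and the reference `tokenizer.json`
// agree on a fixed passage and on arbitrary token IDs.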
#[cfg(test)]
mod tests {
    use anyhow::Result;
    use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
    use tokenizers::Tokenizer;

    #[allow(dead_code)]
    #[derive(Debug)]
    enum TokenizerType {
        Llama,
        Replit,
        Gpt2,
        Rwkv,
    }

    fn get_gguf_tokenizer(tokenizer: TokenizerType) -> Result<Tokenizer> {
        match tokenizer {
            TokenizerType::Llama => {
                let api = ApiBuilder::new().with_progress(true).build().unwrap();
                let api = api.repo(Repo::with_revision(
                    "EricB/mistralrs_tests".to_string(),
                    RepoType::Model,
                    "main".to_string(),
                ));

                let filename = api.get("llama_gguf_tokenizer.json").unwrap();
                let tokenizer = Tokenizer::from_file(filename).expect("Valid tokenizer");
                Ok(tokenizer)
            }
            TokenizerType::Gpt2 => {
                let api = ApiBuilder::new().with_progress(true).build().unwrap();
                let api = api.repo(Repo::with_revision(
                    "EricB/mistralrs_tests".to_string(),
                    RepoType::Model,
                    "main".to_string(),
                ));

                let filename = api.get("gpt2_gguf_tokenizer.json").unwrap();
                let tokenizer = Tokenizer::from_file(filename).expect("Valid tokenizer");
                Ok(tokenizer)
            }
            other => anyhow::bail!("Cannot get testing GGUF tokenizer for type {other:?}"),
        }
    }

    fn get_hf_tokenizer(tokenizer: TokenizerType) -> Result<Tokenizer> {
        match tokenizer {
            TokenizerType::Llama => {
                let api = ApiBuilder::new().with_progress(true).build().unwrap();
                let api = api.repo(Repo::with_revision(
                    "EricB/mistralrs_tests".to_string(),
                    RepoType::Model,
                    "main".to_string(),
                ));

                let tokenizer_filename = api.get("tokenizer.json").unwrap();
                Ok(Tokenizer::from_file(tokenizer_filename).unwrap())
            }
            TokenizerType::Gpt2 => {
                let api = ApiBuilder::new().with_progress(true).build().unwrap();
                let api = api.repo(Repo::with_revision(
                    "EricB/mistralrs_tests".to_string(),
                    RepoType::Model,
                    "main".to_string(),
                ));

                let tokenizer_filename = api.get("tokenizer_gpt2.json").unwrap();
                Ok(Tokenizer::from_file(tokenizer_filename).unwrap())
            }
            other => anyhow::bail!("Cannot get testing HF tokenizer for type {other:?}"),
        }
    }

    fn get_test_passage() -> String {
        // The compound emoji is a ZWJ sequence; the U+200D joiner is written
        // explicitly so it cannot be lost in copy/paste.
        let passage = "Hello, world! \n🚀 (normal) 😶\u{200D}🌫️ (compound emoji, zwj sequence) ✅ (emoji as single token)\n你好世界!\nNǐ hǎo shìjiè!";

        passage.to_owned()
    }

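    // Encode `passage`, then decode the resulting IDs. Special tokens are
    // skipped during decoding only when they were not added during encoding,
    // so the roundtrip output stays comparable to the original text.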
    fn codec_roundtrip(
        tokenizer: &Tokenizer,
        passage: &str,
        add_special_tokens: bool,
    ) -> Result<String> {
        let tokenized = tokenizer
            .encode_fast(passage, add_special_tokens)
            .map_err(anyhow::Error::msg)?;

        decode(tokenizer, tokenized.get_ids(), !add_special_tokens)
    }

    fn decode(
        tokenizer: &Tokenizer,
        token_ids: &[u32],
        skip_special_tokens: bool,
    ) -> Result<String> {
        tokenizer
            .decode(token_ids, skip_special_tokens)
            .map_err(anyhow::Error::msg)
    }

    #[test]
    fn test_encode_decode_llama() -> Result<()> {
        use rand::rng;
        use rand::seq::SliceRandom;

        let passage = get_test_passage();
        let hf_tokenizer = get_hf_tokenizer(TokenizerType::Llama)?;
        let gguf_tokenizer = get_gguf_tokenizer(TokenizerType::Llama)?;

        // The roundtrip must agree between the two tokenizers and reproduce
        // the original passage.
        let hf_decoded = codec_roundtrip(&hf_tokenizer, passage.as_str(), false)?;
        let gguf_decoded = codec_roundtrip(&gguf_tokenizer, passage.as_str(), false)?;
        assert_eq!(hf_decoded, gguf_decoded);
        assert_eq!(passage, gguf_decoded);

        // Decoding one shuffled pass over the full vocab must also agree,
        // both with and without skipping special tokens.
        #[allow(clippy::cast_possible_truncation)]
        let mut tokens = (0..hf_tokenizer.get_vocab_size(false) as u32).collect::<Vec<_>>();
        tokens.shuffle(&mut rng());

        let hf_decoded = decode(&hf_tokenizer, &tokens, false)?;
        let gguf_decoded = decode(&gguf_tokenizer, &tokens, false)?;
        assert_eq!(hf_decoded, gguf_decoded);

        let hf_decoded = decode(&hf_tokenizer, &tokens, true)?;
        let gguf_decoded = decode(&gguf_tokenizer, &tokens, true)?;
        assert_eq!(hf_decoded, gguf_decoded);

        Ok(())
    }

    #[test]
    fn test_encode_decode_gpt2() -> Result<()> {
        use rand::rng;
        use rand::seq::SliceRandom;

        let passage = get_test_passage();
        let hf_tokenizer = get_hf_tokenizer(TokenizerType::Gpt2)?;
        let gguf_tokenizer = get_gguf_tokenizer(TokenizerType::Gpt2)?;

        let hf_decoded = codec_roundtrip(&hf_tokenizer, passage.as_str(), false)?;
        let gguf_decoded = codec_roundtrip(&gguf_tokenizer, passage.as_str(), false)?;
        assert_eq!(hf_decoded, gguf_decoded);
        assert_eq!(passage, gguf_decoded);

        #[allow(clippy::cast_possible_truncation)]
        let mut tokens = (0..hf_tokenizer.get_vocab_size(false) as u32).collect::<Vec<_>>();
        tokens.shuffle(&mut rng());

        let hf_decoded = decode(&hf_tokenizer, &tokens, false)?;
        let gguf_decoded = decode(&gguf_tokenizer, &tokens, false)?;
        assert_eq!(hf_decoded, gguf_decoded);

        let hf_decoded = decode(&hf_tokenizer, &tokens, true)?;
        let gguf_decoded = decode(&gguf_tokenizer, &tokens, true)?;
        assert_eq!(hf_decoded, gguf_decoded);

        Ok(())
    }
}