// mistralrs_core/utils/tokenizer.rs
use std::{collections::HashMap, path::Path};

use anyhow::Result;
use serde::Deserialize;
use serde_json::Value;
use tokenizers::{tokenizer, Tokenizer};

/// Minimal view of one entry in the `added_tokens` array of a
/// `tokenizer.json` file; only the fields needed to patch the vocab
/// are deserialized.
#[derive(Deserialize)]
struct AddedToken {
    // Vocabulary id assigned to the token.
    id: usize,
    // Literal text of the token (e.g. "<s>").
    content: String,
}
13
14pub(crate) fn get_tokenizer<P: AsRef<Path> + Clone>(
16 p: P,
17 processor_added_tokens: Option<&[&str]>,
18) -> Result<Tokenizer> {
19 let mut tokenizer = {
20 let raw = std::fs::read(p.clone()).map_err(anyhow::Error::msg)?;
21 let mut tokenizer: Value = serde_json::from_slice(&raw).unwrap();
22 let added_tokens: Vec<AddedToken> =
23 serde_json::from_value(tokenizer["added_tokens"].clone()).unwrap();
24 let vocab: HashMap<String, usize> =
25 serde_json::from_value(tokenizer["model"]["vocab"].clone()).unwrap();
26 for token in added_tokens {
27 if !vocab.contains_key(&token.content) {
28 tokenizer["model"]["vocab"]
29 .as_object_mut()
30 .unwrap()
31 .insert(token.content, token.id.into())
32 .ok_or(())
33 .unwrap_err();
34 }
35 }
36 let raw_fixed = serde_json::to_vec_pretty(&tokenizer).unwrap();
37 Tokenizer::from_bytes(&raw_fixed).map_err(anyhow::Error::msg)?
38 };
39 if let Some(added_tokens) = processor_added_tokens {
40 tokenizer.add_special_tokens(
41 &added_tokens
42 .iter()
43 .map(|x| tokenizer::AddedToken::from(x.to_string(), true))
44 .collect::<Vec<_>>(),
45 );
46 }
47 Ok(tokenizer)
48}