// mistralrs_core/utils/tokenizer.rs

use std::{collections::HashMap, path::Path};

use anyhow::{anyhow, Result};
use serde::Deserialize;
use serde_json::Value;
use tokenizers::{tokenizer, Tokenizer};

/// One entry of the `added_tokens` array inside a `tokenizer.json` file.
/// Only the fields needed for the vocab patch below are deserialized.
#[derive(Deserialize)]
struct AddedToken {
    /// Vocabulary id assigned to this token in the tokenizer file.
    id: usize,
    /// Literal text of the token.
    content: String,
}
13
14/// May fix the tokenizer according to: https://gist.github.com/jneuff/682d47b786329f19291d166957b3274a
15pub(crate) fn get_tokenizer<P: AsRef<Path> + Clone>(
16    p: P,
17    processor_added_tokens: Option<&[&str]>,
18) -> Result<Tokenizer> {
19    let mut tokenizer = {
20        let raw = std::fs::read(p.clone()).map_err(anyhow::Error::msg)?;
21        let mut tokenizer: Value = serde_json::from_slice(&raw).unwrap();
22        let added_tokens: Vec<AddedToken> =
23            serde_json::from_value(tokenizer["added_tokens"].clone()).unwrap();
24        let vocab: HashMap<String, usize> =
25            serde_json::from_value(tokenizer["model"]["vocab"].clone()).unwrap();
26        for token in added_tokens {
27            if !vocab.contains_key(&token.content) {
28                tokenizer["model"]["vocab"]
29                    .as_object_mut()
30                    .unwrap()
31                    .insert(token.content, token.id.into())
32                    .ok_or(())
33                    .unwrap_err();
34            }
35        }
36        let raw_fixed = serde_json::to_vec_pretty(&tokenizer).unwrap();
37        Tokenizer::from_bytes(&raw_fixed).map_err(anyhow::Error::msg)?
38    };
39    if let Some(added_tokens) = processor_added_tokens {
40        tokenizer.add_special_tokens(
41            &added_tokens
42                .iter()
43                .map(|x| tokenizer::AddedToken::from(x.to_string(), true))
44                .collect::<Vec<_>>(),
45        );
46    }
47    Ok(tokenizer)
48}