mistralrs_core/gguf/
content.rs

1use std::{collections::HashMap, fs};
2
3use anyhow::Context;
4use candle_core::{
5    quantized::{
6        gguf_file::{self, TensorInfo, Value},
7        QTensor,
8    },
9    Device, Result,
10};
11use indexmap::IndexMap;
12use tracing::info;
13
14use crate::DEBUG;
15
16use super::GGUFArchitecture;
17
18fn parse_gguf_value(value: &Value) -> String {
19    match value {
20        Value::Array(vs) => vs
21            .iter()
22            .map(parse_gguf_value)
23            .collect::<Vec<String>>()
24            .join(", "),
25        Value::Bool(b) => b.to_string(),
26        Value::F32(x) => x.to_string(),
27        Value::F64(x) => x.to_string(),
28        Value::I8(x) => x.to_string(),
29        Value::I16(x) => x.to_string(),
30        Value::I32(x) => x.to_string(),
31        Value::I64(x) => x.to_string(),
32        Value::String(x) => x.to_string(),
33        Value::U8(x) => x.to_string(),
34        Value::U16(x) => x.to_string(),
35        Value::U32(x) => x.to_string(),
36        Value::U64(x) => x.to_string(),
37    }
38}
39
// Internal invariant: contents and readers must be paired.
/// This abstracts the files for a GGUF model and enables multiple files to be used.
pub struct Content<'a, R: std::io::Seek + std::io::Read> {
    /// Parsed GGUF content of each file, in the same order as `readers`.
    contents: Vec<gguf_file::Content>,
    /// The readers the contents were parsed from; `readers[i]` backs `contents[i]`
    /// and is used again when tensor data is read.
    readers: &'a mut [&'a mut R],
    /// Architecture parsed from the `general.architecture` metadata key.
    arch: GGUFArchitecture,
    /// Metadata of all files merged into one map; on duplicate keys the value
    /// from the later file replaces the earlier one.
    all_metadata: HashMap<String, Value>,
}
48
49impl<'a, R: std::io::Seek + std::io::Read> Content<'a, R> {
50    /// Create a `Content` from a set of file readers.
51    pub fn from_readers(readers: &'a mut [&'a mut R]) -> Result<Self> {
52        let mut contents = Vec::new();
53        let n_readers = readers.len();
54        for reader in readers.iter_mut() {
55            contents.push(gguf_file::Content::read(reader)?);
56        }
57        let n_splits = contents
58            .iter()
59            .filter_map(|ct| {
60                ct.metadata
61                    .get("split.count")
62                    .map(|val| val.to_u64().unwrap())
63            })
64            .fold(Vec::new(), |mut accum, x| {
65                if !accum.contains(&x) {
66                    accum.push(x);
67                }
68                accum
69            });
70        if n_splits.len() > 1 {
71            candle_core::bail!("GGUF files have differing `split.count` values: {n_splits:?}. Perhaps the GGUF files do not match?");
72        }
73        #[allow(clippy::cast_possible_truncation)]
74        if !n_splits.is_empty() && n_readers != n_splits[0] as usize {
75            candle_core::bail!(
76                "Number of GGUF files does not match the number of splits, expected {} files.",
77                n_splits[0]
78            );
79        } else if n_splits.len() == 1 {
80            info!("GGUF file has been split into {} shards", n_splits[0]);
81        }
82
83        let mut arch = None;
84        for ct in &contents {
85            if !ct.metadata.contains_key("general.architecture") {
86                continue;
87            }
88
89            arch = Some(
90                ct.metadata["general.architecture"]
91                    .to_string()
92                    .context("Model metadata should have declared an architecture")
93                    .and_then(GGUFArchitecture::from_value)
94                    .unwrap(),
95            );
96        }
97        let arch = arch.expect("GGUF files must specify `general.architecture`");
98
99        let mut all_metadata = HashMap::new();
100        for content in &contents {
101            all_metadata.extend(content.metadata.clone())
102        }
103
104        Ok(Self {
105            contents,
106            readers,
107            arch,
108            all_metadata,
109        })
110    }
111
112    pub fn arch(&self) -> GGUFArchitecture {
113        self.arch
114    }
115
116    /// Retrieve a tensor info, searching through each content.
117    pub fn tensor_info(&self, name: &str) -> Result<&TensorInfo> {
118        for ct in &self.contents {
119            if let Some(tensor_info) = ct.tensor_infos.get(name) {
120                return Ok(tensor_info);
121            }
122        }
123        candle_core::bail!("Cannot find tensor info for {name}")
124    }
125
126    /// Retrieve a tensor, searching through each content.
127    pub fn tensor(&mut self, name: &str, device: &Device) -> Result<QTensor> {
128        for (ct, reader) in self.contents.iter().zip(self.readers.iter_mut()) {
129            if let Some(tensor_info) = ct.tensor_infos.get(name) {
130                return tensor_info.read(reader, ct.tensor_data_offset, device);
131            }
132        }
133        candle_core::bail!("Cannot find tensor info for {name}")
134    }
135
136    /// Check for a tensor, searching through each content.
137    pub fn has_tensor(&self, name: &str) -> bool {
138        for ct in self.contents.iter() {
139            if ct.tensor_infos.contains_key(name) {
140                return true;
141            }
142        }
143        false
144    }
145
146    /// Print metadata for these contents.
147    /// This will also log tensor name, shape and dtype to `mistralrs_gguf_tensors.txt` is DEBUG is enabled.
148    pub fn print_metadata(&self) -> anyhow::Result<()> {
149        // Find the ct with general.architecture
150        let mut keys = Vec::new();
151        let mut metadatas = Vec::new();
152        let mut tensors = Vec::new();
153        for ct in &self.contents {
154            keys.extend(ct.metadata.keys());
155            metadatas.push(&ct.metadata);
156
157            if DEBUG.load(std::sync::atomic::Ordering::Relaxed) {
158                for (name, info) in &ct.tensor_infos {
159                    tensors.push(format!(
160                        "name = `{name}`, shape = {:?}, dtype = {:?}",
161                        info.shape.clone(),
162                        info.ggml_dtype
163                    ));
164                }
165            }
166        }
167
168        info!("Model config:");
169        keys.sort();
170        let mut output_keys = IndexMap::new();
171        for name in keys {
172            if !name.contains("tokenizer") {
173                for metadata in &metadatas {
174                    if let Some(val) = metadata.get(name) {
175                        output_keys.insert(name, parse_gguf_value(val));
176                    }
177                }
178            }
179        }
180        for (name, val) in output_keys {
181            println!("{name}: {val}")
182        }
183
184        if DEBUG.load(std::sync::atomic::Ordering::Relaxed) {
185            fs::write(
186                "mistralrs_gguf_tensors.txt",
187                serde_json::to_string_pretty(&tensors).expect("Serialization failed."),
188            )?;
189
190            info!("Debug is enabled, wrote the names and information about each tensor to `mistralrs_gguf_tensors.txt`.");
191        }
192
193        anyhow::Ok(())
194    }
195
196    /// Get all metadatas
197    pub fn get_metadata(&self) -> &HashMap<String, Value> {
198        &self.all_metadata
199    }
200}