mistralrs_core/gguf/content.rs

use std::{collections::HashMap, fs};

use anyhow::Context;
use candle_core::{
    quantized::{
        gguf_file::{self, TensorInfo, Value},
        QTensor,
    },
    Device, Result,
};
use indexmap::IndexMap;
use tracing::info;

use crate::DEBUG;

use super::GGUFArchitecture;

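/// Render a GGUF metadata [`Value`] as a human-readable string; arrays are
/// formatted element by element and joined with `", "`.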
fn parse_gguf_value(value: &Value) -> String {
    match value {
        Value::Array(vs) => vs
            .iter()
            .map(parse_gguf_value)
            .collect::<Vec<String>>()
            .join(", "),
        Value::Bool(b) => b.to_string(),
        Value::F32(x) => x.to_string(),
        Value::F64(x) => x.to_string(),
        Value::I8(x) => x.to_string(),
        Value::I16(x) => x.to_string(),
        Value::I32(x) => x.to_string(),
        Value::I64(x) => x.to_string(),
        Value::String(x) => x.to_string(),
        Value::U8(x) => x.to_string(),
        Value::U16(x) => x.to_string(),
        Value::U32(x) => x.to_string(),
        Value::U64(x) => x.to_string(),
    }
}

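/// The parsed contents of one or more GGUF files, which may be shards of a
/// single model split via the `split.count` metadata. Tensor and metadata
/// lookups search every shard.
///
/// A minimal usage sketch (hypothetical file name, error handling elided):
///
/// ```ignore
/// let mut file = std::fs::File::open("model.gguf")?;
/// let mut readers = vec![&mut file];
/// let content = Content::from_readers(&mut readers)?;
/// let arch = content.arch();
/// ```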
pub struct Content<'a, R: std::io::Seek + std::io::Read> {
    contents: Vec<gguf_file::Content>,
    readers: &'a mut [&'a mut R],
    arch: GGUFArchitecture,
    all_metadata: HashMap<String, Value>,
}

impl<'a, R: std::io::Seek + std::io::Read> Content<'a, R> {
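    /// Read GGUF content from each reader. When the shards carry a
    /// `split.count`, every shard must report the same value and the number of
    /// readers must match it. Panics if no shard declares
    /// `general.architecture` or the architecture is unrecognized.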
    pub fn from_readers(readers: &'a mut [&'a mut R]) -> Result<Self> {
        let mut contents = Vec::new();
        let n_readers = readers.len();
        for reader in readers.iter_mut() {
            contents.push(gguf_file::Content::read(reader)?);
        }
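        // Collect the distinct `split.count` values reported by the shards; a
        // correctly matched set of shards yields exactly one value.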
        let n_splits = contents
            .iter()
            .filter_map(|ct| {
                ct.metadata
                    .get("split.count")
                    .map(|val| val.to_u64().unwrap())
            })
            .fold(Vec::new(), |mut accum, x| {
                if !accum.contains(&x) {
                    accum.push(x);
                }
                accum
            });
        if n_splits.len() > 1 {
            candle_core::bail!("GGUF files have differing `split.count` values: {n_splits:?}. Perhaps the GGUF files do not match?");
        }
        #[allow(clippy::cast_possible_truncation)]
        if !n_splits.is_empty() && n_readers != n_splits[0] as usize {
            candle_core::bail!(
                "Number of GGUF files does not match the number of splits, expected {} files.",
                n_splits[0]
            );
        } else if n_splits.len() == 1 {
            info!("GGUF file has been split into {} shards", n_splits[0]);
        }

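        // Resolve the architecture from `general.architecture`; the last shard
        // that declares it wins.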
        let mut arch = None;
        for ct in &contents {
            if !ct.metadata.contains_key("general.architecture") {
                continue;
            }

            arch = Some(
                ct.metadata["general.architecture"]
                    .to_string()
                    .context("Model metadata should have declared an architecture")
                    .and_then(GGUFArchitecture::from_value)
                    .unwrap(),
            );
        }
        let arch = arch.expect("GGUF files must specify `general.architecture`");

        let mut all_metadata = HashMap::new();
        for content in &contents {
            all_metadata.extend(content.metadata.clone())
        }

        Ok(Self {
            contents,
            readers,
            arch,
            all_metadata,
        })
    }

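    /// The architecture declared by the `general.architecture` metadata key.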
    pub fn arch(&self) -> GGUFArchitecture {
        self.arch
    }

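    /// Look up the [`TensorInfo`] for `name`, searching every shard.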
    pub fn tensor_info(&self, name: &str) -> Result<&TensorInfo> {
        for ct in &self.contents {
            if let Some(tensor_info) = ct.tensor_infos.get(name) {
                return Ok(tensor_info);
            }
        }
        candle_core::bail!("Cannot find tensor info for {name}")
    }

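    /// Read the quantized tensor `name` from whichever shard contains it and
    /// load it onto `device`.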
    pub fn tensor(&mut self, name: &str, device: &Device) -> Result<QTensor> {
        for (ct, reader) in self.contents.iter().zip(self.readers.iter_mut()) {
            if let Some(tensor_info) = ct.tensor_infos.get(name) {
                return tensor_info.read(reader, ct.tensor_data_offset, device);
            }
        }
        candle_core::bail!("Cannot find tensor info for {name}")
    }

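    /// Returns `true` if any shard contains a tensor named `name`.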
    pub fn has_tensor(&self, name: &str) -> bool {
        for ct in self.contents.iter() {
            if ct.tensor_infos.contains_key(name) {
                return true;
            }
        }
        false
    }

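    /// Print all non-tokenizer metadata keys and, when the global `DEBUG` flag
    /// is set, write each tensor's name, shape, and dtype to
    /// `mistralrs_gguf_tensors.txt`.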
    pub fn print_metadata(&self) -> anyhow::Result<()> {
        let mut keys = Vec::new();
        let mut metadatas = Vec::new();
        let mut tensors = Vec::new();
        for ct in &self.contents {
            keys.extend(ct.metadata.keys());
            metadatas.push(&ct.metadata);

            if DEBUG.load(std::sync::atomic::Ordering::Relaxed) {
                for (name, info) in &ct.tensor_infos {
                    tensors.push(format!(
                        "name = `{name}`, shape = {:?}, dtype = {:?}",
                        info.shape.clone(),
                        info.ggml_dtype
                    ));
                }
            }
        }

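        // Print every non-tokenizer metadata key; values from later shards
        // overwrite earlier ones.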
        info!("Model config:");
        keys.sort();
        let mut output_keys = IndexMap::new();
        for name in keys {
            if !name.contains("tokenizer") {
                for metadata in &metadatas {
                    if let Some(val) = metadata.get(name) {
                        output_keys.insert(name, parse_gguf_value(val));
                    }
                }
            }
        }
        for (name, val) in output_keys {
            println!("{name}: {val}")
        }

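        // When the global `DEBUG` flag is set, dump the collected tensor
        // descriptions to `mistralrs_gguf_tensors.txt`.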
        if DEBUG.load(std::sync::atomic::Ordering::Relaxed) {
            fs::write(
                "mistralrs_gguf_tensors.txt",
                serde_json::to_string_pretty(&tensors).expect("Serialization failed."),
            )?;

            info!("Debug is enabled, wrote the names and information about each tensor to `mistralrs_gguf_tensors.txt`.");
        }

        anyhow::Ok(())
    }

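    /// All metadata, merged across every shard.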
    pub fn get_metadata(&self) -> &HashMap<String, Value> {
        &self.all_metadata
    }
}