mistralrs_core/utils/gguf_metadata.rs

use akin::akin;
use anyhow::ensure;
use anyhow::Result;
use candle_core::quantized::gguf_file;
use candle_core::DType;
use std::collections::HashMap;
use std::fs;
use tracing::warn;

use crate::gguf::Content;
use crate::paged_attention::ModelConfigLike;
use crate::pipeline::AutoDeviceMapParams;
use crate::pipeline::DeviceMappedModelLoader;
use crate::GGUFArchitecture;

#[derive(Debug)]
pub struct ContentConfig {
    max_seq_len: usize,
    hidden_size: usize,
    num_attn_heads: usize,
    num_kv_heads: usize,
    num_layers: usize,
    key_length: Option<usize>,
    value_length: Option<usize>,
}

#[allow(clippy::cast_possible_truncation)]
impl<'a, R: std::io::Seek + std::io::Read> From<&Content<'a, R>> for ContentConfig {
    fn from(value: &Content<'a, R>) -> Self {
        let metadata = value.get_metadata();
        // GGUF metadata keys are namespaced by architecture, e.g. `llama.context_length`.
        let arch = metadata["general.architecture"].to_string().unwrap();
        Self {
            max_seq_len: metadata[&format!("{arch}.context_length")]
                .to_u64()
                .unwrap() as usize,
            hidden_size: metadata[&format!("{arch}.embedding_length")]
                .to_u64()
                .unwrap() as usize,
            num_attn_heads: metadata[&format!("{arch}.attention.head_count")]
                .to_u64()
                .unwrap() as usize,
            num_kv_heads: metadata[&format!("{arch}.attention.head_count_kv")]
                .to_u64()
                .unwrap() as usize,
            num_layers: metadata[&format!("{arch}.block_count")].to_u64().unwrap() as usize,
            key_length: metadata
                .get(&format!("{arch}.attention.key_length"))
                .map(|x| x.to_u64().unwrap() as usize),
            value_length: metadata
                .get(&format!("{arch}.attention.value_length"))
                .map(|x| x.to_u64().unwrap() as usize),
        }
    }
}

impl ModelConfigLike for ContentConfig {
    fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }
    fn hidden_size(&self) -> usize {
        self.hidden_size
    }
    fn num_attn_heads(&self) -> usize {
        self.num_attn_heads
    }
    fn num_kv_heads(&self) -> usize {
        self.num_kv_heads
    }
    fn num_layers(&self) -> usize {
        self.num_layers
    }
    // Some GGUFs store explicit key/value lengths; otherwise the head dim
    // falls back to hidden_size / num_attn_heads.
    fn k_head_dim(&self) -> usize {
        self.key_length
            .unwrap_or(self.hidden_size / self.num_attn_heads)
    }
    fn v_head_dim(&self) -> usize {
        self.value_length
            .unwrap_or(self.hidden_size / self.num_attn_heads)
    }
}
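
// A quick sketch of the head-dim fallback above, with a hand-built config
// rather than one parsed from a real GGUF file: absent explicit key/value
// lengths, both head dims derive from hidden_size / num_attn_heads.
#[cfg(test)]
mod head_dim_fallback_sketch {
    use super::*;

    #[test]
    fn derives_head_dim_from_hidden_size() {
        let cfg = ContentConfig {
            max_seq_len: 4096,
            hidden_size: 4096,
            num_attn_heads: 32,
            num_kv_heads: 8,
            num_layers: 32,
            key_length: None,
            value_length: None,
        };
        // 4096 / 32 = 128 for both K and V when no explicit lengths are set.
        assert_eq!(cfg.k_head_dim(), 128);
        assert_eq!(cfg.v_head_dim(), 128);
    }
}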

// Reads typed values out of GGUF metadata, keyed under a common path prefix
// (e.g. `llama` or `tokenizer.ggml`).
pub struct ContentMetadata<'a> {
    pub path_prefix: &'a str,
    pub metadata: &'a HashMap<String, gguf_file::Value>,
}

impl ContentMetadata<'_> {
    // Retrieve a required value by querying the metadata under the path prefix:
    pub fn get_value<T: TryFromValue>(&self, field_name: &str) -> Result<T, anyhow::Error> {
        let prop_key = format!("{prefix}.{field_name}", prefix = self.path_prefix);
        let value = self.metadata.get(&prop_key).cloned();

        // Unwrap the inner value of the `Value` enum via the trait method,
        // otherwise format the error with the prop key as context:
        value
            .try_value_into()
            .or_else(|e| anyhow::bail!("`{prop_key}` `{e}`"))
    }

    // Retrieve an optional value; a missing key yields `Ok(None)` rather than an error:
    pub fn get_option_value<T: TryFromValue>(
        &self,
        field_name: &str,
    ) -> Result<Option<T>, anyhow::Error> {
        let prop_key = format!("{prefix}.{field_name}", prefix = self.path_prefix);
        let value = self.metadata.get(&prop_key).cloned();

        // Unwrap the inner value of the `Value` enum via the trait method,
        // otherwise format the error with the prop key as context:
        value
            .map(|v| {
                v.try_value_into()
                    .or_else(|e| anyhow::bail!("`{prop_key}` `{e}`"))
            })
            .map_or(Ok(None), |res| res.map(Some))
    }

    // Fail early - catch all missing mandatory keys upfront:
    pub fn has_required_keys(&self, fields: &[&str]) -> Result<()> {
        let mut all_props_are_present = true;

        for field_name in fields {
            let prop_key = format!("{prefix}.{field_name}", prefix = self.path_prefix);

            if !self.metadata.contains_key(&prop_key) {
                all_props_are_present = false;
                warn!("Expected GGUF metadata to have key: `{prop_key}`");
            }
        }

        ensure!(all_props_are_present, "GGUF metadata is missing required keys");
        Ok(())
    }

    // Reference: https://github.com/ggerganov/ggml/blob/master/docs/gguf.md#required
    pub fn verify_arch(&self, expected_arch: &str) -> Result<()> {
        let actual_arch: String = self
            .metadata
            .get("general.architecture")
            .cloned()
            .try_value_into()?;

        anyhow::ensure!(
            actual_arch == expected_arch,
            "Expected `{expected_arch}` architecture, got `{actual_arch}`."
        );

        Ok(())
    }
}
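
// A minimal usage sketch for `ContentMetadata`, using a hand-built metadata
// map rather than one read from a real GGUF file: it shows prefix resolution,
// the `Ok(None)` behavior for absent optional keys, and the upfront
// required-key check.
#[cfg(test)]
mod content_metadata_sketch {
    use super::*;
    use std::collections::HashMap;

    #[test]
    fn reads_prefixed_values() -> Result<()> {
        let mut metadata = HashMap::new();
        metadata.insert(
            "llama.context_length".to_string(),
            gguf_file::Value::U64(4096),
        );
        let c = ContentMetadata {
            path_prefix: "llama",
            metadata: &metadata,
        };
        // `get_value` resolves `llama.context_length` and converts via `TryFromValue`:
        let ctx: u64 = c.get_value("context_length")?;
        assert_eq!(ctx, 4096);
        // Absent optional keys come back as `Ok(None)` instead of an error:
        let freq_base: Option<f32> = c.get_option_value("rope.freq_base")?;
        assert!(freq_base.is_none());
        // `has_required_keys` warns per missing key, then fails as a whole:
        assert!(c.has_required_keys(&["context_length", "block_count"]).is_err());
        Ok(())
    }
}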

// The traits below are a workaround for converting candle's GGUF `Value` enum
// type wrapper. A better upstream approach might instead be to provide
// serialize/deserialize support.
pub trait TryFromValue {
    fn try_from_value(value: gguf_file::Value) -> Result<Self, candle_core::Error>
    where
        Self: Sized;
}

// Value-wrapped types, each with its own conversion method:
// NOTE: Type conversion methods internally bail with "not a <into type> <input value>"
// https://docs.rs/candle-core/latest/candle_core/quantized/gguf_file/enum.Value.html#variants
akin! {
    let &types = [String, bool, f32, f64, i8, i16, i32, i64, u8, u16, u32, u64];
    let &to_type = [
        value.to_string().cloned(),
        value.to_bool(),
        value.to_f32(),
        value.to_f64(),
        value.to_i8(),
        value.to_i16(),
        value.to_i32(),
        value.to_i64(),
        value.to_u8(),
        value.to_u16(),
        value.to_u32(),
        value.to_u64(),
    ];

    impl TryFromValue for *types {
        fn try_from_value(value: gguf_file::Value) -> Result<Self, candle_core::Error> {
            *to_type.or_else(|_| candle_core::bail!("value is not a `*types`"))
        }
    }
}
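
// For reference, the `akin!` block above expands to one impl per listed type;
// the `u64` instance looks roughly like this (a sketch of the expansion, shown
// as a comment rather than compiled a second time):
//
//     impl TryFromValue for u64 {
//         fn try_from_value(value: gguf_file::Value) -> Result<Self, candle_core::Error> {
//             value.to_u64().or_else(|_| candle_core::bail!("value is not a `u64`"))
//         }
//     }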

// Vec<Value> to Vec<T> for the types above:
impl<T: TryFromValue> TryFromValue for Vec<T> {
    fn try_from_value(value_vec: gguf_file::Value) -> Result<Self, candle_core::Error> {
        value_vec
            .to_vec()
            .or_else(|_| candle_core::bail!("value is not a `Vec`"))?
            .clone()
            .into_iter()
            .map(T::try_from_value)
            .collect()
    }
}

pub trait TryValueInto<T>: Sized {
    fn try_value_into(self) -> Result<T, candle_core::Error>;
}

impl<T: TryFromValue> TryValueInto<T> for gguf_file::Value {
    fn try_value_into(self) -> Result<T, candle_core::Error> {
        T::try_from_value(self)
    }
}

impl<T: TryFromValue> TryValueInto<T> for Option<gguf_file::Value> {
    fn try_value_into(self) -> Result<T, candle_core::Error> {
        match self {
            Some(value) => value.try_value_into(),
            None => candle_core::bail!("Expected `Option<gguf_file::Value>` to contain a value"),
        }
    }
}
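
// Together, these impls let call sites write
// `metadata.get(key).cloned().try_value_into()` and receive either the
// converted value or a uniform error for a missing key or mismatched type;
// the `Vec<T>` impl above extends this to GGUF array values.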

// Size of a tensor on disk, from its GGUF tensor-info entry. The first arm
// uses the quantized layout (elements stored in fixed-size blocks); the second
// computes the size when held in the given `DType` instead.
macro_rules! tensor_info_size_in_bytes {
    ($t:expr) => {
        $t.shape.elem_count() / $t.ggml_dtype.block_size() * $t.ggml_dtype.type_size()
    };
    ($t:expr, $ty:expr) => {
        $t.shape.elem_count() * $ty.size_in_bytes()
    };
}
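
// A worked check of the block math in the first arm, assuming GGML's Q4_0
// packing (32-element blocks stored in 18 bytes: one f16 scale plus sixteen
// bytes of 4-bit quants). The figures are illustrative, not read from a file.
#[cfg(test)]
mod tensor_size_arithmetic_sketch {
    #[test]
    fn q4_0_block_math() {
        let elem_count = 4096usize * 4096;
        let (block_size, type_size) = (32usize, 18usize);
        // First arm: elem_count / block_size * type_size.
        assert_eq!(elem_count / block_size * type_size, 9_437_184);
        // Second arm with a 4-byte dtype such as `DType::F32`:
        assert_eq!(elem_count * 4, 67_108_864);
    }
}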

// Provides the sizing data the automatic device mapper needs, read directly
// from a GGUF file's tensor infos and metadata.
pub struct GgufDeviceMapLoaderInner<'a, 'f> {
    pub model: &'a Content<'f, fs::File>,
    pub arch: GGUFArchitecture,
}

impl DeviceMappedModelLoader for GgufDeviceMapLoaderInner<'_, '_> {
    fn mapped_max_act_size_elems(
        &self,
        _config: &str,
        params: &AutoDeviceMapParams,
        prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Text {
            max_seq_len: _,
            max_batch_size,
        } = params
        else {
            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
        };
        let num_heads = self.model.get_metadata()[&format!("{}.attention.head_count", self.arch)]
            .to_u32()? as usize;
        // Peak activation is the attention score matrix:
        // [max_batch_size, num_heads, prompt_chunksize, prompt_chunksize].
        Ok(max_batch_size * num_heads * prompt_chunksize * prompt_chunksize)
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }

    fn non_mapped_size_in_bytes(
        &self,
        _config: &str,
        _dtype: DType,
        _weight_pack_factor: usize,
    ) -> Result<usize> {
        // Non-mapped weights are the embedding, the final norm, and the output
        // head. `token_embd` and `output_norm` are counted at F32, matching how
        // they are held for compute; the head keeps its quantized size.
        let size_in_bytes = match self.arch {
            GGUFArchitecture::Llama | GGUFArchitecture::Phi3 | GGUFArchitecture::Qwen2 => {
                let token_embd = tensor_info_size_in_bytes!(
                    self.model.tensor_info("token_embd.weight")?,
                    DType::F32
                );
                let output_norm = tensor_info_size_in_bytes!(
                    self.model.tensor_info("output_norm.weight")?,
                    DType::F32
                );
                // With tied embeddings there is no separate `output.weight`:
                let output = if !self.model.has_tensor("output.weight") {
                    tensor_info_size_in_bytes!(self.model.tensor_info("token_embd.weight")?)
                } else {
                    tensor_info_size_in_bytes!(self.model.tensor_info("output.weight")?)
                };
                token_embd + output_norm + output
            }
            // Phi2 and Starcoder2 additionally carry an `output_norm.bias`:
            GGUFArchitecture::Phi2 | GGUFArchitecture::Starcoder2 => {
                let token_embd = tensor_info_size_in_bytes!(
                    self.model.tensor_info("token_embd.weight")?,
                    DType::F32
                );
                let output_norm = tensor_info_size_in_bytes!(
                    self.model.tensor_info("output_norm.weight")?,
                    DType::F32
                ) + tensor_info_size_in_bytes!(self.model.tensor_info("output_norm.bias")?);
                let output = if !self.model.has_tensor("output.weight") {
                    tensor_info_size_in_bytes!(self.model.tensor_info("token_embd.weight")?)
                } else {
                    tensor_info_size_in_bytes!(self.model.tensor_info("output.weight")?)
                };
                token_embd + output_norm + output
            }
            _ => unimplemented!(),
        };
        Ok(size_in_bytes)
    }
    fn num_layers(&self, _config: &str) -> Result<usize> {
        Ok(self.model.get_metadata()[&format!("{}.block_count", self.arch)].to_u32()? as usize)
    }
    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        _dtype: DType,
        _weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        // Measure layer 0 and assume all layers share its shape. Norms are
        // counted at F32; other weights keep their quantized sizes.
        let size_in_bytes = match self.arch {
            GGUFArchitecture::Llama => {
                let attn_norm = tensor_info_size_in_bytes!(
                    self.model.tensor_info("blk.0.attn_norm.weight")?,
                    DType::F32
                );
                let ffn_norm = tensor_info_size_in_bytes!(
                    self.model.tensor_info("blk.0.ffn_norm.weight")?,
                    DType::F32
                );

                let attn_q =
                    tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.attn_q.weight")?);
                let attn_k =
                    tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.attn_k.weight")?);
                let attn_v =
                    tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.attn_v.weight")?);
                let attn_output =
                    tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.attn_output.weight")?);

                // Feed-forward: MoE if the metadata declares experts, else a plain MLP.
                #[allow(clippy::cast_possible_truncation)]
                let n_expert = self
                    .model
                    .get_metadata()
                    .get(&format!("{}.expert_count", self.arch))
                    .map(|x| x.to_u64().unwrap() as usize)
                    .unwrap_or(0);
                let moe_or_mlp = if n_expert <= 1 {
                    let ffn_gate =
                        tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.ffn_gate.weight")?);
                    let ffn_up =
                        tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.ffn_up.weight")?);
                    let ffn_down =
                        tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.ffn_down.weight")?);
                    ffn_gate + ffn_up + ffn_down
                } else {
                    let mut moe_count = tensor_info_size_in_bytes!(
                        self.model.tensor_info("blk.0.ffn_gate_inp.weight")?
                    );
                    // Experts are stored either fused (`ffn_*_exps`) or as one
                    // tensor per expert:
                    match self.model.tensor_info("blk.0.ffn_gate_exps.weight") {
                        Ok(ffn_gate_exps) => {
                            moe_count += tensor_info_size_in_bytes!(ffn_gate_exps);
                            moe_count += tensor_info_size_in_bytes!(
                                self.model.tensor_info("blk.0.ffn_down_exps.weight")?
                            );
                            moe_count += tensor_info_size_in_bytes!(
                                self.model.tensor_info("blk.0.ffn_up_exps.weight")?
                            );
                        }
                        Err(_) => {
                            for i in 0..n_expert {
                                moe_count += tensor_info_size_in_bytes!(
                                    self.model.tensor_info(&format!("blk.0.ffn_gate.{i}.weight"))?
                                );
                                moe_count += tensor_info_size_in_bytes!(
                                    self.model.tensor_info(&format!("blk.0.ffn_down.{i}.weight"))?
                                );
                                moe_count += tensor_info_size_in_bytes!(
                                    self.model.tensor_info(&format!("blk.0.ffn_up.{i}.weight"))?
                                );
                            }
                        }
                    }

                    moe_count
                };
                attn_norm + ffn_norm + attn_q + attn_k + attn_v + attn_output + moe_or_mlp
            }
            GGUFArchitecture::Phi2 => {
                let attn_norm = tensor_info_size_in_bytes!(
                    self.model.tensor_info("blk.0.attn_norm.weight")?,
                    DType::F32
                ) + tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.attn_norm.bias")?);

                let attn_qkv =
                    tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.attn_qkv.weight")?);
                let attn_output =
                    tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.attn_output.weight")?);

                let ffn_up =
                    tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.ffn_up.weight")?);
                let ffn_down =
                    tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.ffn_down.weight")?);

                attn_norm + attn_qkv + attn_output + ffn_up + ffn_down
            }
            GGUFArchitecture::Phi3 => {
                let attn_norm = tensor_info_size_in_bytes!(
                    self.model.tensor_info("blk.0.attn_norm.weight")?,
                    DType::F32
                );
                let ffn_norm = tensor_info_size_in_bytes!(
                    self.model.tensor_info("blk.0.ffn_norm.weight")?,
                    DType::F32
                );

                let attn_qkv =
                    tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.attn_qkv.weight")?);
                let attn_output =
                    tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.attn_output.weight")?);

                let ffn_up =
                    tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.ffn_up.weight")?);
                let ffn_down =
                    tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.ffn_down.weight")?);

                attn_norm + ffn_norm + attn_qkv + attn_output + ffn_up + ffn_down
            }
            GGUFArchitecture::Qwen2 => {
                let attn_norm = tensor_info_size_in_bytes!(
                    self.model.tensor_info("blk.0.attn_norm.weight")?,
                    DType::F32
                );
                let ffn_norm = tensor_info_size_in_bytes!(
                    self.model.tensor_info("blk.0.ffn_norm.weight")?,
                    DType::F32
                );

                // Qwen2 carries biases on Q/K/V but not on the output projection:
                let attn_q =
                    tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.attn_q.weight")?)
                        + tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.attn_q.bias")?);
                let attn_k =
                    tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.attn_k.weight")?)
                        + tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.attn_k.bias")?);
                let attn_v =
                    tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.attn_v.weight")?)
                        + tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.attn_v.bias")?);
                let attn_output =
                    tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.attn_output.weight")?);

                let ffn_gate =
                    tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.ffn_gate.weight")?);
                let ffn_up =
                    tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.ffn_up.weight")?);
                let ffn_down =
                    tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.ffn_down.weight")?);

                attn_norm
                    + ffn_norm
                    + attn_q
                    + attn_k
                    + attn_v
                    + attn_output
                    + ffn_gate
                    + ffn_up
                    + ffn_down
            }
            GGUFArchitecture::Starcoder2 => {
                // Starcoder2 carries biases on every weight below:
                let attn_norm = tensor_info_size_in_bytes!(
                    self.model.tensor_info("blk.0.attn_norm.weight")?,
                    DType::F32
                ) + tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.attn_norm.bias")?);
                let ffn_norm = tensor_info_size_in_bytes!(
                    self.model.tensor_info("blk.0.ffn_norm.weight")?,
                    DType::F32
                ) + tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.ffn_norm.bias")?);

                let attn_q =
                    tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.attn_q.weight")?)
                        + tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.attn_q.bias")?);
                let attn_k =
                    tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.attn_k.weight")?)
                        + tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.attn_k.bias")?);
                let attn_v =
                    tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.attn_v.weight")?)
                        + tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.attn_v.bias")?);
                let attn_output =
                    tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.attn_output.weight")?)
                        + tensor_info_size_in_bytes!(
                            self.model.tensor_info("blk.0.attn_output.bias")?
                        );

                let ffn_up =
                    tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.ffn_up.weight")?)
                        + tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.ffn_up.bias")?);
                let ffn_down =
                    tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.ffn_down.weight")?)
                        + tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.ffn_down.bias")?);

                attn_norm + ffn_norm + attn_q + attn_k + attn_v + attn_output + ffn_up + ffn_down
            }
            _ => unimplemented!(),
        };
        Ok(vec![size_in_bytes; self.num_layers(config)?])
    }
    fn model_config(&self, _config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let model_config_metadata: ContentConfig = self.model.into();
        Ok(Box::new(model_config_metadata))
    }
}
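
// An end-to-end usage sketch, kept in a comment because the `Content`
// constructor is an assumption here rather than part of this file: given an
// opened GGUF file, the loader supplies the auto device mapper with layer
// counts and per-layer byte sizes.
//
//     let mut file = fs::File::open("model.gguf")?;
//     let content = Content::from_readers(&mut [&mut file])?; // assumed constructor
//     let loader = GgufDeviceMapLoaderInner {
//         model: &content,
//         arch: GGUFArchitecture::Llama,
//     };
//     let layers = loader.num_layers("")?;
//     let per_layer = loader.layer_sizes_in_bytes("", DType::F32, 1)?;
//     assert_eq!(per_layer.len(), layers);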