mistralrs_core/utils/gguf_metadata.rs

use akin::akin;
use anyhow::ensure;
use anyhow::Result;
use candle_core::quantized::gguf_file;
use candle_core::DType;
use std::collections::HashMap;
use std::fs;
use tracing::warn;

use crate::attention::ATTENTION_CHUNK_SIZE;
use crate::gguf::Content;
use crate::matformer::MatformerSliceConfig;
use crate::paged_attention::ModelConfigLike;
use crate::pipeline::AutoDeviceMapParams;
use crate::pipeline::DeviceMappedModelLoader;
use crate::GGUFArchitecture;

#[derive(Debug)]
pub struct ContentConfig {
    max_seq_len: usize,
    hidden_size: usize,
    num_attn_heads: usize,
    num_kv_heads: usize,
    num_layers: usize,
    key_length: Option<usize>,
    value_length: Option<usize>,
}

#[allow(clippy::cast_possible_truncation)]
impl<'a, R: std::io::Seek + std::io::Read> From<&Content<'a, R>> for ContentConfig {
    fn from(value: &Content<'a, R>) -> Self {
        let metadata = value.get_metadata();
        let arch = metadata["general.architecture"].to_string().unwrap();
        Self {
            max_seq_len: metadata[&format!("{arch}.context_length")]
                .to_u64()
                .unwrap() as usize,
            hidden_size: metadata[&format!("{arch}.embedding_length")]
                .to_u64()
                .unwrap() as usize,
            num_attn_heads: metadata[&format!("{arch}.attention.head_count")]
                .to_u64()
                .unwrap() as usize,
            num_kv_heads: metadata[&format!("{arch}.attention.head_count_kv")]
                .to_u64()
                .unwrap() as usize,
            num_layers: metadata[&format!("{arch}.block_count")].to_u64().unwrap() as usize,
            key_length: metadata
                .get(&format!("{arch}.attention.key_length"))
                .map(|x| x.to_u64().unwrap() as usize),
            value_length: metadata
                .get(&format!("{arch}.attention.value_length"))
                .map(|x| x.to_u64().unwrap() as usize),
        }
    }
}
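
// Keys follow the GGUF naming convention: each field above is read from a
// metadata key prefixed by `general.architecture`. E.g. for a llama-arch file:
//
//     llama.context_length          -> max_seq_len
//     llama.embedding_length        -> hidden_size
//     llama.attention.head_count    -> num_attn_heads
//     llama.attention.head_count_kv -> num_kv_heads
//     llama.block_count             -> num_layers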

impl ModelConfigLike for ContentConfig {
    fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }
    fn hidden_size(&self) -> usize {
        self.hidden_size
    }
    fn num_attn_heads(&self) -> usize {
        self.num_attn_heads
    }
    fn num_kv_heads(&self) -> usize {
        self.num_kv_heads
    }
    fn num_layers(&self) -> usize {
        self.num_layers
    }
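    // When the GGUF records no explicit key/value lengths, head dims fall back
    // to the uniform hidden_size / num_attn_heads, e.g. 4096 / 32 = 128.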
    fn k_head_dim(&self) -> usize {
        self.key_length
            .unwrap_or(self.hidden_size / self.num_attn_heads)
    }
    fn v_head_dim(&self) -> usize {
        self.value_length
            .unwrap_or(self.hidden_size / self.num_attn_heads)
    }
}

pub struct ContentMetadata<'a> {
    pub path_prefix: &'a str,
    pub metadata: &'a HashMap<String, gguf_file::Value>,
}

impl ContentMetadata<'_> {
    // Retrieve a prop the struct needs by querying the metadata content:
    pub fn get_value<T: TryFromValue>(&self, field_name: &str) -> Result<T, anyhow::Error> {
        let prop_key = format!("{prefix}.{field_name}", prefix = self.path_prefix);
        let value = self.metadata.get(&prop_key).cloned();

        // Unwrap the inner value of the `Value` enum via trait method,
        // otherwise format error with prop key as context:
        value
            .try_value_into()
            .or_else(|e| anyhow::bail!("`{prop_key}` `{e}`"))
    }

    // Retrieve an optional prop, returning `Ok(None)` when the key is absent:
    pub fn get_option_value<T: TryFromValue>(
        &self,
        field_name: &str,
    ) -> Result<Option<T>, anyhow::Error> {
        let prop_key = format!("{prefix}.{field_name}", prefix = self.path_prefix);
        let value = self.metadata.get(&prop_key).cloned();

        // Unwrap the inner value of the `Value` enum via trait method,
        // otherwise format error with prop key as context:
        value
            .map(|v| {
                v.try_value_into()
                    .or_else(|e| anyhow::bail!("`{prop_key}` `{e}`"))
            })
            .map_or(Ok(None), |res| res.map(Some))
    }

    // Fail early - catch all missing mandatory keys upfront:
    pub fn has_required_keys(&self, fields: &[&str]) -> Result<()> {
        let mut all_props_are_present = true;

        for field_name in fields {
            let prop_key = format!("{prefix}.{field_name}", prefix = self.path_prefix);

            if !self.metadata.contains_key(&prop_key) {
                all_props_are_present = false;
                warn!("Expected GGUF metadata to have key: `{prop_key}`");
            }
        }
        ensure!(all_props_are_present, "GGUF metadata is missing required keys");
        Ok(())
    }

    // Reference: https://github.com/ggerganov/ggml/blob/master/docs/gguf.md#required
    pub fn verify_arch(&self, expected_arch: &str) -> Result<()> {
        let actual_arch: String = self
            .metadata
            .get("general.architecture")
            .cloned()
            .try_value_into()?;

        anyhow::ensure!(
            actual_arch == expected_arch,
            "Expected `{expected_arch}` architecture, got `{actual_arch}`."
        );

        Ok(())
    }
}
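
// Usage sketch (hypothetical `metadata` borrowed from a parsed GGUF file; the
// key names are standard GGUF ones):
//
//     let c = ContentMetadata { path_prefix: "llama", metadata: &metadata };
//     c.has_required_keys(&["block_count", "context_length"])?;
//     let n_layers: u32 = c.get_value("block_count")?; // reads `llama.block_count`
//     let freq_base: Option<f32> = c.get_option_value("rope.freq_base")?;
//     c.verify_arch("llama")?;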

// The traits below are a workaround for converting candle's GGUF `Value` enum type wrapper.
// A better upstream approach might be to provide serialize/deserialize support.
pub trait TryFromValue {
    fn try_from_value(value: gguf_file::Value) -> Result<Self, candle_core::Error>
    where
        Self: Sized;
}

// Value-wrapped types, each with a different conversion method:
// NOTE: Type conversion methods internally bail with "not a <into type> <input value>"
// https://docs.rs/candle-core/latest/candle_core/quantized/gguf_file/enum.Value.html#variants
akin! {
    let &types = [String, bool, f32, f64, i8, i16, i32, i64, u8, u16, u32, u64];
    let &to_type = [
        value.to_string().cloned(),
        value.to_bool(),
        value.to_f32(),
        value.to_f64(),
        value.to_i8(),
        value.to_i16(),
        value.to_i32(),
        value.to_i64(),
        value.to_u8(),
        value.to_u16(),
        value.to_u32(),
        value.to_u64(),
    ];

    impl TryFromValue for *types {
        fn try_from_value(value: gguf_file::Value) -> Result<Self, candle_core::Error> {
            *to_type.or_else(|_| candle_core::bail!("value is not a `*types`"))
        }
    }
}
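
// The macro above expands to one `TryFromValue` impl per listed type. A
// minimal sketch of a direct conversion (`Value::U32` is one of the variants
// of candle's `gguf_file::Value`):
//
//     let heads = u32::try_from_value(gguf_file::Value::U32(32))?;  // Ok(32)
//     let name = String::try_from_value(gguf_file::Value::U32(32)); // Err: value is not a `String`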

// Vec<Value> to Vec<T> from above types:
impl<T: TryFromValue> TryFromValue for Vec<T> {
    fn try_from_value(value_vec: gguf_file::Value) -> Result<Self, candle_core::Error> {
        value_vec
            .to_vec()
            .or_else(|_| candle_core::bail!("value is not a `Vec`"))?
            .clone()
            .into_iter()
            .map(|item| T::try_from_value(item))
            .collect()
    }
}
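
// E.g. an array value converts element-wise, failing on the first element of
// the wrong type (via the `TryValueInto` helper below):
//
//     let dims: Vec<u32> =
//         gguf_file::Value::Array(vec![gguf_file::Value::U32(2), gguf_file::Value::U32(3)])
//             .try_value_into()?;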

pub trait TryValueInto<T>: Sized {
    fn try_value_into(self) -> Result<T, candle_core::Error>;
}

impl<T: TryFromValue> TryValueInto<T> for gguf_file::Value {
    fn try_value_into(self) -> Result<T, candle_core::Error> {
        T::try_from_value(self)
    }
}

impl<T: TryFromValue> TryValueInto<T> for Option<gguf_file::Value> {
    fn try_value_into(self) -> Result<T, candle_core::Error> {
        match self {
            Some(value) => value.try_value_into(),
            None => candle_core::bail!("Expected `Option<gguf_file::Value>` to contain a value"),
        }
    }
}

macro_rules! tensor_info_size_in_bytes {
    ($t:expr) => {
        $t.shape.elem_count() / $t.ggml_dtype.block_size() * $t.ggml_dtype.type_size()
    };
    ($t:expr, $ty:expr) => {
        $t.shape.elem_count() * $ty.size_in_bytes()
    };
}
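
// One-arm form: bytes for the tensor in its stored GGML quantization; two-arm
// form: bytes if the tensor were held as `$ty` instead. Worked example for a
// hypothetical 4096 x 4096 tensor (16_777_216 elements) stored as Q4_0, which
// packs 32 elements into 18-byte blocks:
//
//     16_777_216 / 32 * 18 = 9_437_184 bytes   (one-arm form)
//     16_777_216 * 4       = 67_108_864 bytes  (two-arm form with DType::F32)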

pub struct GgufDeviceMapLoaderInner<'a, 'f> {
    pub model: &'a Content<'f, fs::File>,
    pub arch: GGUFArchitecture,
}

impl DeviceMappedModelLoader for GgufDeviceMapLoaderInner<'_, '_> {
    fn mapped_max_act_size_elems(
        &self,
        _config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Text {
            max_seq_len,
            max_batch_size,
        } = params
        else {
            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
        };
        let num_heads = self.model.get_metadata()[&format!("{}.attention.head_count", self.arch)]
            .to_u32()? as usize;
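        // Estimate of mapped activation elements, assuming attention runs over
        // at most ATTENTION_CHUNK_SIZE tokens at a time: e.g. batch 1, 32
        // heads, seq 4096, chunk 1024 gives 1 * 32 * 1024 = 32_768 elements.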
        Ok(max_batch_size * num_heads * max_seq_len.min(&ATTENTION_CHUNK_SIZE))
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }

    fn non_mapped_size_in_bytes(
        &self,
        _config: &str,
        _dtype: DType,
        _weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
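        // The non-mapped weights are the same for every supported arch: token
        // embeddings, the final norm, and the output head (which falls back to
        // the tied embedding matrix when `output.weight` is absent).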
        let size_in_bytes = match self.arch {
            GGUFArchitecture::Llama
            | GGUFArchitecture::Phi3
            | GGUFArchitecture::Qwen2
            | GGUFArchitecture::Qwen3
            | GGUFArchitecture::Qwen3MoE => {
                let token_embd = tensor_info_size_in_bytes!(
                    self.model.tensor_info("token_embd.weight")?,
                    DType::F32
                );
                let output_norm = tensor_info_size_in_bytes!(
                    self.model.tensor_info("output_norm.weight")?,
                    DType::F32
                );
                let output = if !self.model.has_tensor("output.weight") {
                    tensor_info_size_in_bytes!(self.model.tensor_info("token_embd.weight")?)
                } else {
                    tensor_info_size_in_bytes!(self.model.tensor_info("output.weight")?)
                };
                token_embd + output_norm + output
            }
            // Phi2 and Starcoder2 additionally carry a bias on the output norm.
            GGUFArchitecture::Phi2 | GGUFArchitecture::Starcoder2 => {
                let token_embd = tensor_info_size_in_bytes!(
                    self.model.tensor_info("token_embd.weight")?,
                    DType::F32
                );
                let output_norm = tensor_info_size_in_bytes!(
                    self.model.tensor_info("output_norm.weight")?,
                    DType::F32
                ) + tensor_info_size_in_bytes!(self.model.tensor_info("output_norm.bias")?);
                let output = if !self.model.has_tensor("output.weight") {
                    tensor_info_size_in_bytes!(self.model.tensor_info("token_embd.weight")?)
                } else {
                    tensor_info_size_in_bytes!(self.model.tensor_info("output.weight")?)
                };
                token_embd + output_norm + output
            }
            _ => unimplemented!(),
        };
        Ok(size_in_bytes)
    }
    fn num_layers(&self, _config: &str) -> Result<usize> {
        Ok(self.model.get_metadata()[&format!("{}.block_count", self.arch)].to_u32()? as usize)
    }
    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        _dtype: DType,
        _weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let size_in_bytes = match self.arch {
            GGUFArchitecture::Llama => {
                let attn_norm = tensor_info_size_in_bytes!(
                    self.model.tensor_info("blk.0.attn_norm.weight")?,
                    DType::F32
                );
                let ffn_norm = tensor_info_size_in_bytes!(
                    self.model.tensor_info("blk.0.ffn_norm.weight")?,
                    DType::F32
                );

                let attn_q =
                    tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.attn_q.weight")?);
                let attn_k =
                    tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.attn_k.weight")?);
                let attn_v =
                    tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.attn_v.weight")?);
                let attn_output =
                    tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.attn_output.weight")?);

                // MoE or MLP
                #[allow(clippy::cast_possible_truncation)]
                let n_expert = self
                    .model
                    .get_metadata()
                    .get(&format!("{}.expert_count", self.arch))
                    .map(|x| x.to_u64().unwrap() as usize)
                    .unwrap_or(0);
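                // Newer GGUFs pack all experts into single `ffn_*_exps` tensors;
                // older ones store one `ffn_*.{i}` tensor per expert, so fall
                // back to summing those when the packed form is absent.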
                let moe_or_mlp = if n_expert <= 1 {
                    let ffn_gate =
                        tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.ffn_gate.weight")?);
                    let ffn_up =
                        tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.ffn_up.weight")?);
                    let ffn_down =
                        tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.ffn_down.weight")?);
                    ffn_gate + ffn_up + ffn_down
                } else {
                    let mut moe_count = 0;
                    moe_count += tensor_info_size_in_bytes!(
                        self.model.tensor_info("blk.0.ffn_gate_inp.weight")?
                    );
                    match self.model.tensor_info("blk.0.ffn_gate_exps.weight") {
                        Ok(ffn_gate_exps) => {
                            moe_count += tensor_info_size_in_bytes!(ffn_gate_exps);
                            moe_count += tensor_info_size_in_bytes!(
                                self.model.tensor_info("blk.0.ffn_down_exps.weight")?
                            );
                            moe_count += tensor_info_size_in_bytes!(
                                self.model.tensor_info("blk.0.ffn_up_exps.weight")?
                            );
                        }
                        Err(_) => {
                            for i in 0..n_expert {
                                moe_count += tensor_info_size_in_bytes!(
                                    self.model.tensor_info(&format!("blk.0.ffn_gate.{i}.weight"))?
                                );
                                moe_count += tensor_info_size_in_bytes!(
                                    self.model.tensor_info(&format!("blk.0.ffn_down.{i}.weight"))?
                                );
                                moe_count += tensor_info_size_in_bytes!(
                                    self.model.tensor_info(&format!("blk.0.ffn_up.{i}.weight"))?
                                );
                            }
                        }
                    }

                    moe_count
                };
                attn_norm + ffn_norm + attn_q + attn_k + attn_v + attn_output + moe_or_mlp
            }
            GGUFArchitecture::Phi2 => {
                let attn_norm = tensor_info_size_in_bytes!(
                    self.model.tensor_info("blk.0.attn_norm.weight")?,
                    DType::F32
                ) + tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.attn_norm.bias")?);

                let attn_qkv =
                    tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.attn_qkv.weight")?);
                let attn_output =
                    tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.attn_output.weight")?);

                let ffn_up =
                    tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.ffn_up.weight")?);
                let ffn_down =
                    tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.ffn_down.weight")?);

                attn_norm + attn_qkv + attn_output + ffn_up + ffn_down
            }
            GGUFArchitecture::Phi3 => {
                let attn_norm = tensor_info_size_in_bytes!(
                    self.model.tensor_info("blk.0.attn_norm.weight")?,
                    DType::F32
                );
                let ffn_norm = tensor_info_size_in_bytes!(
                    self.model.tensor_info("blk.0.ffn_norm.weight")?,
                    DType::F32
                );

                let attn_qkv =
                    tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.attn_qkv.weight")?);
                let attn_output =
                    tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.attn_output.weight")?);

                let ffn_up =
                    tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.ffn_up.weight")?);
                let ffn_down =
                    tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.ffn_down.weight")?);

                attn_norm + ffn_norm + attn_qkv + attn_output + ffn_up + ffn_down
            }
            GGUFArchitecture::Qwen2 | GGUFArchitecture::Qwen3 | GGUFArchitecture::Qwen3MoE => {
                let attn_norm = tensor_info_size_in_bytes!(
                    self.model.tensor_info("blk.0.attn_norm.weight")?,
                    DType::F32
                );
                let ffn_norm = tensor_info_size_in_bytes!(
                    self.model.tensor_info("blk.0.ffn_norm.weight")?,
                    DType::F32
                );

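                // Only Qwen2 GGUFs carry Q/K/V projection biases; the Qwen3
                // variants do not, hence the arch checks below.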
                let mut attn_q =
                    tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.attn_q.weight")?);
                if let GGUFArchitecture::Qwen2 = self.arch {
                    attn_q +=
                        tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.attn_q.bias")?);
                }
                let mut attn_k =
                    tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.attn_k.weight")?);
                if let GGUFArchitecture::Qwen2 = self.arch {
                    attn_k +=
                        tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.attn_k.bias")?);
                }
                let mut attn_v =
                    tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.attn_v.weight")?);
                if let GGUFArchitecture::Qwen2 = self.arch {
                    attn_v +=
                        tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.attn_v.bias")?);
                }

                let attn_output =
                    tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.attn_output.weight")?);

                let ffn_gate = if let GGUFArchitecture::Qwen3MoE = self.arch {
                    tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.ffn_gate_exps.weight")?)
                } else {
                    tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.ffn_gate.weight")?)
                };

                let ffn_up = if let GGUFArchitecture::Qwen3MoE = self.arch {
                    tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.ffn_up_exps.weight")?)
                } else {
                    tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.ffn_up.weight")?)
                };

                let ffn_down = if let GGUFArchitecture::Qwen3MoE = self.arch {
                    tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.ffn_down_exps.weight")?)
                } else {
                    tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.ffn_down.weight")?)
                };

                attn_norm
                    + ffn_norm
                    + attn_q
                    + attn_k
                    + attn_v
                    + attn_output
                    + ffn_gate
                    + ffn_up
                    + ffn_down
            }
            GGUFArchitecture::Starcoder2 => {
                let attn_norm = tensor_info_size_in_bytes!(
                    self.model.tensor_info("blk.0.attn_norm.weight")?,
                    DType::F32
                ) + tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.attn_norm.bias")?);
                let ffn_norm = tensor_info_size_in_bytes!(
                    self.model.tensor_info("blk.0.ffn_norm.weight")?,
                    DType::F32
                ) + tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.ffn_norm.bias")?);

                let attn_q =
                    tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.attn_q.weight")?)
                        + tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.attn_q.bias")?);
                let attn_k =
                    tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.attn_k.weight")?)
                        + tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.attn_k.bias")?);
                let attn_v =
                    tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.attn_v.weight")?)
                        + tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.attn_v.bias")?);
                let attn_output =
                    tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.attn_output.weight")?)
                        + tensor_info_size_in_bytes!(
                            self.model.tensor_info("blk.0.attn_output.bias")?
                        );

                let ffn_up =
                    tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.ffn_up.weight")?)
                        + tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.ffn_up.bias")?);
                let ffn_down =
                    tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.ffn_down.weight")?)
                        + tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.ffn_down.bias")?);

                attn_norm + ffn_norm + attn_q + attn_k + attn_v + attn_output + ffn_up + ffn_down
            }
            _ => unimplemented!(),
        };
        Ok(vec![size_in_bytes; self.num_layers(config)?])
    }
    fn model_config(&self, _config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let model_config_metadata: ContentConfig = self.model.into();
        Ok(Box::new(model_config_metadata))
    }
}