use akin::akin;
use anyhow::ensure;
use anyhow::Result;
use candle_core::quantized::gguf_file;
use candle_core::DType;
use std::collections::HashMap;
use std::fs;
use tracing::warn;

use crate::attention::ATTENTION_CHUNK_SIZE;
use crate::gguf::Content;
use crate::matformer::MatformerSliceConfig;
use crate::paged_attention::ModelConfigLike;
use crate::pipeline::AutoDeviceMapParams;
use crate::pipeline::DeviceMappedModelLoader;
use crate::GGUFArchitecture;

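/// Core model shape parameters read from GGUF metadata, used for device
/// mapping and paged attention.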
#[derive(Debug)]
pub struct ContentConfig {
    max_seq_len: usize,
    hidden_size: usize,
    num_attn_heads: usize,
    num_kv_heads: usize,
    num_layers: usize,
    key_length: Option<usize>,
    value_length: Option<usize>,
}

#[allow(clippy::cast_possible_truncation)]
impl<'a, R: std::io::Seek + std::io::Read> From<&Content<'a, R>> for ContentConfig {
    fn from(value: &Content<'a, R>) -> Self {
        let metadata = value.get_metadata();
        let arch = metadata["general.architecture"].to_string().unwrap();
        Self {
            max_seq_len: metadata[&format!("{arch}.context_length")]
                .to_u64()
                .unwrap() as usize,
            hidden_size: metadata[&format!("{arch}.embedding_length")]
                .to_u64()
                .unwrap() as usize,
            num_attn_heads: metadata[&format!("{arch}.attention.head_count")]
                .to_u64()
                .unwrap() as usize,
            num_kv_heads: metadata[&format!("{arch}.attention.head_count_kv")]
                .to_u64()
                .unwrap() as usize,
            num_layers: metadata[&format!("{arch}.block_count")].to_u64().unwrap() as usize,
            key_length: metadata
                .get(&format!("{arch}.attention.key_length"))
                .map(|x| x.to_u64().unwrap() as usize),
            value_length: metadata
                .get(&format!("{arch}.attention.value_length"))
                .map(|x| x.to_u64().unwrap() as usize),
        }
    }
}

impl ModelConfigLike for ContentConfig {
    fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }
    fn hidden_size(&self) -> usize {
        self.hidden_size
    }
    fn num_attn_heads(&self) -> usize {
        self.num_attn_heads
    }
    fn num_kv_heads(&self) -> usize {
        self.num_kv_heads
    }
    fn num_layers(&self) -> usize {
        self.num_layers
    }
    fn k_head_dim(&self) -> usize {
        self.key_length
            .unwrap_or(self.hidden_size / self.num_attn_heads)
    }
    fn v_head_dim(&self) -> usize {
        self.value_length
            .unwrap_or(self.hidden_size / self.num_attn_heads)
    }
}

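/// Helper for reading typed values from GGUF metadata under a common key
/// prefix (e.g. `llama.attention.head_count`).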
pub struct ContentMetadata<'a> {
    pub path_prefix: &'a str,
    pub metadata: &'a HashMap<String, gguf_file::Value>,
}

impl ContentMetadata<'_> {
    pub fn get_value<T: TryFromValue>(&self, field_name: &str) -> Result<T, anyhow::Error> {
        let prop_key = format!("{prefix}.{field_name}", prefix = self.path_prefix);
        let value = self.metadata.get(&prop_key).cloned();

        value
            .try_value_into()
            .or_else(|e| anyhow::bail!("Failed to read `{prop_key}`: {e}"))
    }

    pub fn get_option_value<T: TryFromValue>(
        &self,
        field_name: &str,
    ) -> Result<Option<T>, anyhow::Error> {
        let prop_key = format!("{prefix}.{field_name}", prefix = self.path_prefix);
        let value = self.metadata.get(&prop_key).cloned();

        value
            .map(|v| {
                v.try_value_into()
                    .or_else(|e| anyhow::bail!("Failed to read `{prop_key}`: {e}"))
            })
            .map_or(Ok(None), |res| res.map(Some))
    }

    pub fn has_required_keys(&self, fields: &[&str]) -> Result<()> {
        let mut all_props_are_present = true;

        for field_name in fields {
            let prop_key = format!("{prefix}.{field_name}", prefix = self.path_prefix);

            if !self.metadata.contains_key(&prop_key) {
                all_props_are_present = false;
                warn!("Expected GGUF metadata to have key: `{prop_key}`");
            }
        }

        ensure!(all_props_are_present, "GGUF metadata is missing required keys");
        Ok(())
    }

    pub fn verify_arch(&self, expected_arch: &str) -> Result<()> {
        let actual_arch: String = self
            .metadata
            .get("general.architecture")
            .cloned()
            .try_value_into()?;

        anyhow::ensure!(
            actual_arch == expected_arch,
            "Expected `{expected_arch}` architecture, got `{actual_arch}`."
        );

        Ok(())
    }
}

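/// Fallible conversion from a raw `gguf_file::Value` into a concrete Rust type.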
pub trait TryFromValue {
    fn try_from_value(value: gguf_file::Value) -> Result<Self, candle_core::Error>
    where
        Self: Sized;
}

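// `akin` stamps out one `TryFromValue` impl per scalar type, pairing each
// entry of `&types` with the matching accessor in `&to_type`.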
akin! {
    let &types = [String, bool, f32, f64, i8, i16, i32, i64, u8, u16, u32, u64];
    let &to_type = [
        value.to_string().cloned(),
        value.to_bool(),
        value.to_f32(),
        value.to_f64(),
        value.to_i8(),
        value.to_i16(),
        value.to_i32(),
        value.to_i64(),
        value.to_u8(),
        value.to_u16(),
        value.to_u32(),
        value.to_u64(),
    ];

    impl TryFromValue for *types {
        fn try_from_value(value: gguf_file::Value) -> Result<Self, candle_core::Error> {
            *to_type.or_else(|_| candle_core::bail!("value is not a `*types`"))
        }
    }
}

impl<T: TryFromValue> TryFromValue for Vec<T> {
    fn try_from_value(value_vec: gguf_file::Value) -> Result<Self, candle_core::Error> {
        value_vec
            .to_vec()
            .or_else(|_| candle_core::bail!("value is not a `Vec`"))?
            .clone()
            .into_iter()
            .map(T::try_from_value)
            .collect()
    }
}

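/// Mirror of `TryFromValue` for call-site ergonomics (`value.try_value_into()`).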
pub trait TryValueInto<T>: Sized {
    fn try_value_into(self) -> Result<T, candle_core::Error>;
}

impl<T: TryFromValue> TryValueInto<T> for gguf_file::Value {
    fn try_value_into(self) -> Result<T, candle_core::Error> {
        T::try_from_value(self)
    }
}

impl<T: TryFromValue> TryValueInto<T> for Option<gguf_file::Value> {
    fn try_value_into(self) -> Result<T, candle_core::Error> {
        match self {
            Some(value) => value.try_value_into(),
            None => candle_core::bail!("Expected `Option<gguf_file::Value>` to contain a value"),
        }
    }
}

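/// Size in bytes of a tensor: computed from its quantized GGML dtype by
/// default, or at an explicit `DType` when one is given.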
macro_rules! tensor_info_size_in_bytes {
    ($t:expr) => {
        $t.shape.elem_count() / $t.ggml_dtype.block_size() * $t.ggml_dtype.type_size()
    };
    ($t:expr, $ty:expr) => {
        $t.shape.elem_count() * $ty.size_in_bytes()
    };
}

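/// Device-mapping loader that sizes layers directly from a GGUF file's tensor
/// infos instead of a parsed config.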
pub struct GgufDeviceMapLoaderInner<'a, 'f> {
    pub model: &'a Content<'f, fs::File>,
    pub arch: GGUFArchitecture,
}

impl DeviceMappedModelLoader for GgufDeviceMapLoaderInner<'_, '_> {
    fn mapped_max_act_size_elems(
        &self,
        _config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Text {
            max_seq_len,
            max_batch_size,
        } = params
        else {
            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
        };
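        // Peak activation elements scale with batch size, head count, and
        // sequence length, with the sequence capped at `ATTENTION_CHUNK_SIZE`.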
        let num_heads = self.model.get_metadata()[&format!("{}.attention.head_count", self.arch)]
            .to_u32()? as usize;
        Ok(max_batch_size * num_heads * max_seq_len.min(&ATTENTION_CHUNK_SIZE))
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }

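    // Bytes for weights that are never device-mapped: the token embeddings,
    // the final norm, and the output head.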
    fn non_mapped_size_in_bytes(
        &self,
        _config: &str,
        _dtype: DType,
        _weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let size_in_bytes = match self.arch {
            GGUFArchitecture::Llama
            | GGUFArchitecture::Phi3
            | GGUFArchitecture::Qwen2
            | GGUFArchitecture::Qwen3
            | GGUFArchitecture::Qwen3MoE => {
                let token_embd = tensor_info_size_in_bytes!(
                    self.model.tensor_info("token_embd.weight")?,
                    DType::F32
                );
                let output_norm = tensor_info_size_in_bytes!(
                    self.model.tensor_info("output_norm.weight")?,
                    DType::F32
                );
                // With tied embeddings there is no `output.weight`; the head
                // shares the token embedding tensor.
                let output = if !self.model.has_tensor("output.weight") {
                    tensor_info_size_in_bytes!(self.model.tensor_info("token_embd.weight")?)
                } else {
                    tensor_info_size_in_bytes!(self.model.tensor_info("output.weight")?)
                };
                token_embd + output_norm + output
            }
            // Phi2 and Starcoder2 additionally carry a bias on the output norm.
            GGUFArchitecture::Phi2 | GGUFArchitecture::Starcoder2 => {
                let token_embd = tensor_info_size_in_bytes!(
                    self.model.tensor_info("token_embd.weight")?,
                    DType::F32
                );
                let output_norm = tensor_info_size_in_bytes!(
                    self.model.tensor_info("output_norm.weight")?,
                    DType::F32
                ) + tensor_info_size_in_bytes!(self.model.tensor_info("output_norm.bias")?);
                let output = if !self.model.has_tensor("output.weight") {
                    tensor_info_size_in_bytes!(self.model.tensor_info("token_embd.weight")?)
                } else {
                    tensor_info_size_in_bytes!(self.model.tensor_info("output.weight")?)
                };
                token_embd + output_norm + output
            }
            _ => unimplemented!(),
        };
        Ok(size_in_bytes)
    }
    fn num_layers(&self, _config: &str) -> Result<usize> {
        Ok(self.model.get_metadata()[&format!("{}.block_count", self.arch)].to_u32()? as usize)
    }
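    // Per-layer cost is measured from layer 0's tensors and assumed uniform
    // across all blocks.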
    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        _dtype: DType,
        _weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let size_in_bytes = match self.arch {
            GGUFArchitecture::Llama => {
                let attn_norm = tensor_info_size_in_bytes!(
                    self.model.tensor_info("blk.0.attn_norm.weight")?,
                    DType::F32
                );
                let ffn_norm = tensor_info_size_in_bytes!(
                    self.model.tensor_info("blk.0.ffn_norm.weight")?,
                    DType::F32
                );

                let attn_q =
                    tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.attn_q.weight")?);
                let attn_k =
                    tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.attn_k.weight")?);
                let attn_v =
                    tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.attn_v.weight")?);
                let attn_output = tensor_info_size_in_bytes!(self
                    .model
                    .tensor_info("blk.0.attn_output.weight")?);

                // GGUF metadata keys are architecture-prefixed, e.g. `llama.expert_count`.
                #[allow(clippy::cast_possible_truncation)]
                let n_expert = self
                    .model
                    .get_metadata()
                    .get(&format!("{}.expert_count", self.arch))
                    .map(|x| x.to_u64().unwrap() as usize)
                    .unwrap_or(0);
                let moe_or_mlp = if n_expert <= 1 {
                    let ffn_gate = tensor_info_size_in_bytes!(self
                        .model
                        .tensor_info("blk.0.ffn_gate.weight")?);
                    let ffn_up = tensor_info_size_in_bytes!(self
                        .model
                        .tensor_info("blk.0.ffn_up.weight")?);
                    let ffn_down = tensor_info_size_in_bytes!(self
                        .model
                        .tensor_info("blk.0.ffn_down.weight")?);
                    ffn_gate + ffn_up + ffn_down
                } else {
                    // Experts are stored either stacked (`ffn_*_exps`) or as
                    // one tensor per expert (`ffn_*.{i}`).
                    let mut moe_count = 0;
                    moe_count += tensor_info_size_in_bytes!(self
                        .model
                        .tensor_info("blk.0.ffn_gate_inp.weight")?);
                    match self.model.tensor_info("blk.0.ffn_gate_exps.weight") {
                        Ok(feed_forward_gate_exps) => {
                            moe_count += tensor_info_size_in_bytes!(feed_forward_gate_exps);
                            moe_count += tensor_info_size_in_bytes!(self
                                .model
                                .tensor_info("blk.0.ffn_down_exps.weight")?);
                            moe_count += tensor_info_size_in_bytes!(self
                                .model
                                .tensor_info("blk.0.ffn_up_exps.weight")?);
                        }
                        Err(_) => {
                            for i in 0..n_expert {
                                moe_count += tensor_info_size_in_bytes!(self
                                    .model
                                    .tensor_info(&format!("blk.0.ffn_gate.{i}.weight"))?);
                                moe_count += tensor_info_size_in_bytes!(self
                                    .model
                                    .tensor_info(&format!("blk.0.ffn_down.{i}.weight"))?);
                                moe_count += tensor_info_size_in_bytes!(self
                                    .model
                                    .tensor_info(&format!("blk.0.ffn_up.{i}.weight"))?);
                            }
                        }
                    }

                    moe_count
                };
                attn_norm + ffn_norm + attn_q + attn_k + attn_v + attn_output + moe_or_mlp
            }
            GGUFArchitecture::Phi2 => {
                let attn_norm = tensor_info_size_in_bytes!(
                    self.model.tensor_info("blk.0.attn_norm.weight")?,
                    DType::F32
                ) + tensor_info_size_in_bytes!(self
                    .model
                    .tensor_info("blk.0.attn_norm.bias")?);

                let attn_qkv =
                    tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.attn_qkv.weight")?);
                let attn_output = tensor_info_size_in_bytes!(self
                    .model
                    .tensor_info("blk.0.attn_output.weight")?);

                let ffn_up =
                    tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.ffn_up.weight")?);
                let ffn_down =
                    tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.ffn_down.weight")?);

                attn_norm + attn_qkv + attn_output + ffn_up + ffn_down
            }
            GGUFArchitecture::Phi3 => {
                let attn_norm = tensor_info_size_in_bytes!(
                    self.model.tensor_info("blk.0.attn_norm.weight")?,
                    DType::F32
                );
                let ffn_norm = tensor_info_size_in_bytes!(
                    self.model.tensor_info("blk.0.ffn_norm.weight")?,
                    DType::F32
                );

                let attn_qkv =
                    tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.attn_qkv.weight")?);
                let attn_output = tensor_info_size_in_bytes!(self
                    .model
                    .tensor_info("blk.0.attn_output.weight")?);

                let ffn_up =
                    tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.ffn_up.weight")?);
                let ffn_down =
                    tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.ffn_down.weight")?);

                attn_norm + ffn_norm + attn_qkv + attn_output + ffn_up + ffn_down
            }
            GGUFArchitecture::Qwen2 | GGUFArchitecture::Qwen3 | GGUFArchitecture::Qwen3MoE => {
                let attn_norm = tensor_info_size_in_bytes!(
                    self.model.tensor_info("blk.0.attn_norm.weight")?,
                    DType::F32
                );
                let ffn_norm = tensor_info_size_in_bytes!(
                    self.model.tensor_info("blk.0.ffn_norm.weight")?,
                    DType::F32
                );

                // Qwen2 attaches biases to the Q/K/V projections; Qwen3 does not.
                let mut attn_q =
                    tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.attn_q.weight")?);
                if let GGUFArchitecture::Qwen2 = self.arch {
                    attn_q +=
                        tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.attn_q.bias")?);
                }
                let mut attn_k =
                    tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.attn_k.weight")?);
                if let GGUFArchitecture::Qwen2 = self.arch {
                    attn_k +=
                        tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.attn_k.bias")?);
                }

                let mut attn_v =
                    tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.attn_v.weight")?);
                if let GGUFArchitecture::Qwen2 = self.arch {
                    attn_v +=
                        tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.attn_v.bias")?);
                }

                let attn_output = tensor_info_size_in_bytes!(self
                    .model
                    .tensor_info("blk.0.attn_output.weight")?);

                // The MoE variant stores its experts stacked in `ffn_*_exps` tensors.
                let ffn_gate = if let GGUFArchitecture::Qwen3MoE = self.arch {
                    tensor_info_size_in_bytes!(self
                        .model
                        .tensor_info("blk.0.ffn_gate_exps.weight")?)
                } else {
                    tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.ffn_gate.weight")?)
                };

                let ffn_up = if let GGUFArchitecture::Qwen3MoE = self.arch {
                    tensor_info_size_in_bytes!(self
                        .model
                        .tensor_info("blk.0.ffn_up_exps.weight")?)
                } else {
                    tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.ffn_up.weight")?)
                };

                let ffn_down = if let GGUFArchitecture::Qwen3MoE = self.arch {
                    tensor_info_size_in_bytes!(self
                        .model
                        .tensor_info("blk.0.ffn_down_exps.weight")?)
                } else {
                    tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.ffn_down.weight")?)
                };

                attn_norm
                    + ffn_norm
                    + attn_q
                    + attn_k
                    + attn_v
                    + attn_output
                    + ffn_gate
                    + ffn_up
                    + ffn_down
            }
            GGUFArchitecture::Starcoder2 => {
                let attn_norm = tensor_info_size_in_bytes!(
                    self.model.tensor_info("blk.0.attn_norm.weight")?,
                    DType::F32
                ) + tensor_info_size_in_bytes!(self
                    .model
                    .tensor_info("blk.0.attn_norm.bias")?);
                let ffn_norm = tensor_info_size_in_bytes!(
                    self.model.tensor_info("blk.0.ffn_norm.weight")?,
                    DType::F32
                ) + tensor_info_size_in_bytes!(self
                    .model
                    .tensor_info("blk.0.ffn_norm.bias")?);

                let attn_q = tensor_info_size_in_bytes!(self
                    .model
                    .tensor_info("blk.0.attn_q.weight")?)
                    + tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.attn_q.bias")?);
                let attn_k = tensor_info_size_in_bytes!(self
                    .model
                    .tensor_info("blk.0.attn_k.weight")?)
                    + tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.attn_k.bias")?);
                let attn_v = tensor_info_size_in_bytes!(self
                    .model
                    .tensor_info("blk.0.attn_v.weight")?)
                    + tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.attn_v.bias")?);
                let attn_output = tensor_info_size_in_bytes!(self
                    .model
                    .tensor_info("blk.0.attn_output.weight")?)
                    + tensor_info_size_in_bytes!(self
                        .model
                        .tensor_info("blk.0.attn_output.bias")?);

                let ffn_up = tensor_info_size_in_bytes!(self
                    .model
                    .tensor_info("blk.0.ffn_up.weight")?)
                    + tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.ffn_up.bias")?);
                let ffn_down = tensor_info_size_in_bytes!(self
                    .model
                    .tensor_info("blk.0.ffn_down.weight")?)
                    + tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.ffn_down.bias")?);

                attn_norm + ffn_norm + attn_q + attn_k + attn_v + attn_output + ffn_up + ffn_down
            }
            _ => unimplemented!(),
        };
        Ok(vec![size_in_bytes; self.num_layers(config)?])
    }
    fn model_config(&self, _config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let model_config_metadata: ContentConfig = self.model.into();
        Ok(Box::new(model_config_metadata))
    }
}