1use akin::akin;
2use anyhow::ensure;
3use anyhow::Result;
4use candle_core::quantized::gguf_file;
5use candle_core::DType;
6use std::collections::HashMap;
7use std::fs;
8use tracing::warn;
9
10use crate::gguf::Content;
11use crate::paged_attention::ModelConfigLike;
12use crate::pipeline::AutoDeviceMapParams;
13use crate::pipeline::DeviceMappedModelLoader;
14use crate::GGUFArchitecture;
15
/// Core model hyperparameters extracted from GGUF metadata.
///
/// Populated by the `From<&Content>` impl below, which reads the
/// architecture-prefixed keys (e.g. `llama.context_length`).
#[derive(Debug)]
pub struct ContentConfig {
    max_seq_len: usize,
    hidden_size: usize,
    num_attn_heads: usize,
    num_kv_heads: usize,
    num_layers: usize,
    // Explicit per-head key dim (`{arch}.attention.key_length`);
    // `None` means derive it as hidden_size / num_attn_heads.
    key_length: Option<usize>,
    // Explicit per-head value dim (`{arch}.attention.value_length`);
    // `None` means derive it as hidden_size / num_attn_heads.
    value_length: Option<usize>,
}
26
27#[allow(clippy::cast_possible_truncation)]
28impl<'a, R: std::io::Seek + std::io::Read> From<&Content<'a, R>> for ContentConfig {
29 fn from(value: &Content<'a, R>) -> Self {
30 let metadata = value.get_metadata();
31 let arch = metadata["general.architecture"].to_string().unwrap();
32 Self {
33 max_seq_len: metadata[&format!("{arch}.context_length")]
34 .to_u64()
35 .unwrap() as usize,
36 hidden_size: metadata[&format!("{arch}.embedding_length")]
37 .to_u64()
38 .unwrap() as usize,
39 num_attn_heads: metadata[&format!("{arch}.attention.head_count")]
40 .to_u64()
41 .unwrap() as usize,
42 num_kv_heads: metadata[&format!("{arch}.attention.head_count_kv")]
43 .to_u64()
44 .unwrap() as usize,
45 num_layers: metadata[&format!("{arch}.block_count")].to_u64().unwrap() as usize,
46 key_length: metadata
47 .get(&format!("{arch}.attention.key_length"))
48 .map(|x| x.to_u64().unwrap() as usize),
49 value_length: metadata
50 .get(&format!("{arch}.attention.value_length"))
51 .map(|x| x.to_u64().unwrap() as usize),
52 }
53 }
54}
55
/// Expose the extracted hyperparameters through the generic `ModelConfigLike`
/// interface (used by the paged-attention machinery).
impl ModelConfigLike for ContentConfig {
    fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }
    fn hidden_size(&self) -> usize {
        self.hidden_size
    }
    fn num_attn_heads(&self) -> usize {
        self.num_attn_heads
    }
    fn num_kv_heads(&self) -> usize {
        self.num_kv_heads
    }
    fn num_layers(&self) -> usize {
        self.num_layers
    }
    fn k_head_dim(&self) -> usize {
        // Prefer the explicit `attention.key_length` metadata; otherwise fall
        // back to the conventional hidden_size / num_attn_heads split.
        self.key_length
            .unwrap_or(self.hidden_size / self.num_attn_heads)
    }
    fn v_head_dim(&self) -> usize {
        // Same fallback rule as `k_head_dim`, using `attention.value_length`.
        self.value_length
            .unwrap_or(self.hidden_size / self.num_attn_heads)
    }
}
81
/// Typed accessor over a borrowed GGUF metadata map.
///
/// All lookups are performed under `path_prefix` (e.g. an architecture name),
/// i.e. the effective key is `{path_prefix}.{field_name}`.
pub struct ContentMetadata<'a> {
    pub path_prefix: &'a str,
    pub metadata: &'a HashMap<String, gguf_file::Value>,
}
86
87impl ContentMetadata<'_> {
88 pub fn get_value<T: TryFromValue>(&self, field_name: &str) -> Result<T, anyhow::Error> {
90 let prop_key = format!("{prefix}.{field_name}", prefix = self.path_prefix);
91 let value = self.metadata.get(&prop_key).cloned();
92
93 value
96 .try_value_into()
97 .or_else(|e| anyhow::bail!("`{prop_key}` `{e}`"))
98 }
99
100 pub fn get_option_value<T: TryFromValue>(
102 &self,
103 field_name: &str,
104 ) -> Result<Option<T>, anyhow::Error> {
105 let prop_key = format!("{prefix}.{field_name}", prefix = self.path_prefix);
106 let value = self.metadata.get(&prop_key).cloned();
107
108 value
111 .map(|v| {
112 v.try_value_into()
113 .or_else(|e| anyhow::bail!("`{prop_key}` `{e}`"))
114 })
115 .map_or(Ok(None), |res| res.map(Some))
116 }
117
118 pub fn has_required_keys(&self, fields: &[&str]) -> Result<()> {
120 let mut all_props_are_present = true;
121
122 for field_name in fields {
123 let prop_key = format!("{prefix}.{field_name}", prefix = self.path_prefix);
124
125 if !self.metadata.contains_key(&prop_key) {
126 all_props_are_present = false;
127 warn!("Expected GGUF metadata to have key: `{prop_key}`");
128 }
129 }
130
131 ensure!(all_props_are_present, "Tokenizer is missing required props");
132 Ok(())
133 }
134
135 pub fn verify_arch(&self, expected_arch: &str) -> Result<()> {
137 let actual_arch: String = self
138 .metadata
139 .get("general.architecture")
140 .cloned()
141 .try_value_into()?;
142
143 anyhow::ensure!(
144 actual_arch == expected_arch,
145 "Expected `{expected_arch}` architecture, got `{actual_arch}`."
146 );
147
148 Ok(())
149 }
150}
151
/// Conversion from a dynamically-typed `gguf_file::Value` into a concrete
/// Rust type, failing when the value's variant does not match the target.
pub trait TryFromValue {
    fn try_from_value(value: gguf_file::Value) -> Result<Self, candle_core::Error>
    where
        Self: Sized;
}
159
// `akin!` stamps out one `TryFromValue` impl per list entry: the N-th element
// of `&types` is paired (positionally) with the N-th accessor in `&to_type`,
// substituted at the `*types` / `*to_type` sites below. This covers all the
// scalar GGUF value types in one place.
akin! {
    let &types = [String, bool, f32, f64, i8, i16, i32, i64, u8, u16, u32, u64];
    let &to_type = [
        value.to_string().cloned(),
        value.to_bool(),
        value.to_f32(),
        value.to_f64(),
        value.to_i8(),
        value.to_i16(),
        value.to_i32(),
        value.to_i64(),
        value.to_u8(),
        value.to_u16(),
        value.to_u32(),
        value.to_u64(),
    ];

    impl TryFromValue for *types {
        fn try_from_value(value: gguf_file::Value) -> Result<Self, candle_core::Error> {
            // The accessor errors on a wrong variant; remap to a uniform message.
            *to_type.or_else(|_| candle_core::bail!("value is not a `*types`"))
        }
    }
}
186
187impl<T: TryFromValue> TryFromValue for Vec<T> {
189 fn try_from_value(value_vec: gguf_file::Value) -> Result<Self, candle_core::Error> {
190 value_vec
191 .to_vec()
192 .or_else(|_| candle_core::bail!("value is not a `Vec`"))?
193 .clone()
194 .into_iter()
195 .map(|item| T::try_from_value(item))
196 .collect()
197 }
198}
199
/// Mirror of `TryFromValue` in the style of `TryInto`, so conversions can be
/// chained method-style on a `gguf_file::Value` (or an `Option` of one).
pub trait TryValueInto<T>: Sized {
    fn try_value_into(self) -> Result<T, candle_core::Error>;
}
203
impl<T: TryFromValue> TryValueInto<T> for gguf_file::Value {
    // Forward straight to the target type's `TryFromValue` impl.
    fn try_value_into(self) -> Result<T, candle_core::Error> {
        T::try_from_value(self)
    }
}
209
210impl<T: TryFromValue> TryValueInto<T> for Option<gguf_file::Value> {
211 fn try_value_into(self) -> Result<T, candle_core::Error> {
212 match self {
213 Some(value) => value.try_value_into(),
214 None => candle_core::bail!("Expected `Option<gguf_file::Value>` to contain a value"),
215 }
216 }
217}
218
/// Size in bytes of a tensor described by a GGUF tensor-info entry.
///
/// - One-arg form: on-disk (quantized) size, `elems / block_size * type_size`.
/// - Two-arg form: size after dequantizing to the given candle `DType`.
macro_rules! tensor_info_size_in_bytes {
    ($t:expr) => {
        $t.shape.elem_count() / $t.ggml_dtype.block_size() * $t.ggml_dtype.type_size()
    };
    ($t:expr, $ty:expr) => {
        $t.shape.elem_count() * $ty.size_in_bytes()
    };
}
227
/// Adapter that lets a parsed GGUF file (`Content`) drive automatic device
/// mapping via the `DeviceMappedModelLoader` trait.
pub struct GgufDeviceMapLoaderInner<'a, 'f> {
    pub model: &'a Content<'f, fs::File>,
    pub arch: GGUFArchitecture,
}
232
233impl DeviceMappedModelLoader for GgufDeviceMapLoaderInner<'_, '_> {
234 fn mapped_max_act_size_elems(
235 &self,
236 _config: &str,
237 params: &AutoDeviceMapParams,
238 prompt_chunksize: usize,
239 ) -> Result<usize> {
240 let AutoDeviceMapParams::Text {
241 max_seq_len: _,
242 max_batch_size,
243 } = params
244 else {
245 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
246 };
247 let num_heads = self.model.get_metadata()[&format!("{}.attention.head_count", self.arch)]
248 .to_u32()? as usize;
249 Ok(max_batch_size * num_heads * prompt_chunksize * prompt_chunksize)
250 }
251 fn non_mapped_max_act_size_elems(
252 &self,
253 _config: &str,
254 _params: &AutoDeviceMapParams,
255 ) -> Result<usize> {
256 Ok(0)
257 }
258
259 fn non_mapped_size_in_bytes(
260 &self,
261 _config: &str,
262 _dtype: DType,
263 _weight_pack_factor: usize,
264 ) -> Result<usize> {
265 let size_in_bytes = match self.arch {
266 GGUFArchitecture::Llama => {
267 let token_embd = tensor_info_size_in_bytes!(
268 self.model.tensor_info("token_embd.weight")?,
269 DType::F32
270 );
271 let output_norm = tensor_info_size_in_bytes!(
272 self.model.tensor_info("output_norm.weight")?,
273 DType::F32
274 );
275 let output = if !self.model.has_tensor("output.weight") {
276 tensor_info_size_in_bytes!(self.model.tensor_info("token_embd.weight")?)
277 } else {
278 tensor_info_size_in_bytes!(self.model.tensor_info("output.weight")?)
279 };
280 token_embd + output_norm + output
281 }
282 GGUFArchitecture::Phi2 => {
283 let token_embd = tensor_info_size_in_bytes!(
284 self.model.tensor_info("token_embd.weight")?,
285 DType::F32
286 );
287 let output_norm =
288 tensor_info_size_in_bytes!(
289 self.model.tensor_info("output_norm.weight")?,
290 DType::F32
291 ) + tensor_info_size_in_bytes!(self.model.tensor_info("output_norm.bias")?);
292 let output = if !self.model.has_tensor("output.weight") {
293 tensor_info_size_in_bytes!(self.model.tensor_info("token_embd.weight")?)
294 } else {
295 tensor_info_size_in_bytes!(self.model.tensor_info("output.weight")?)
296 };
297 token_embd + output_norm + output
298 }
299 GGUFArchitecture::Phi3 => {
300 let token_embd = tensor_info_size_in_bytes!(
301 self.model.tensor_info("token_embd.weight")?,
302 DType::F32
303 );
304 let output_norm = tensor_info_size_in_bytes!(
305 self.model.tensor_info("output_norm.weight")?,
306 DType::F32
307 );
308 let output = if !self.model.has_tensor("output.weight") {
309 tensor_info_size_in_bytes!(self.model.tensor_info("token_embd.weight")?)
310 } else {
311 tensor_info_size_in_bytes!(self.model.tensor_info("output.weight")?)
312 };
313 token_embd + output_norm + output
314 }
315 GGUFArchitecture::Qwen2 => {
316 let token_embd = tensor_info_size_in_bytes!(
317 self.model.tensor_info("token_embd.weight")?,
318 DType::F32
319 );
320 let output_norm = tensor_info_size_in_bytes!(
321 self.model.tensor_info("output_norm.weight")?,
322 DType::F32
323 );
324 let output = if !self.model.has_tensor("output.weight") {
325 tensor_info_size_in_bytes!(self.model.tensor_info("token_embd.weight")?)
326 } else {
327 tensor_info_size_in_bytes!(self.model.tensor_info("output.weight")?)
328 };
329 token_embd + output_norm + output
330 }
331 GGUFArchitecture::Starcoder2 => {
332 let token_embd = tensor_info_size_in_bytes!(
333 self.model.tensor_info("token_embd.weight")?,
334 DType::F32
335 );
336 let output_norm =
337 tensor_info_size_in_bytes!(
338 self.model.tensor_info("output_norm.weight")?,
339 DType::F32
340 ) + tensor_info_size_in_bytes!(self.model.tensor_info("output_norm.bias")?);
341 let output = if !self.model.has_tensor("output.weight") {
342 tensor_info_size_in_bytes!(self.model.tensor_info("token_embd.weight")?)
343 } else {
344 tensor_info_size_in_bytes!(self.model.tensor_info("output.weight")?)
345 };
346 token_embd + output_norm + output
347 }
348 _ => unimplemented!(),
349 };
350 Ok(size_in_bytes)
351 }
352 fn num_layers(&self, _config: &str) -> Result<usize> {
353 Ok(self.model.get_metadata()[&format!("{}.block_count", self.arch)].to_u32()? as usize)
354 }
355 fn layer_sizes_in_bytes(
356 &self,
357 config: &str,
358 _dtype: DType,
359 _weight_pack_factor: usize,
360 ) -> Result<Vec<usize>> {
361 let size_in_bytes = match self.arch {
362 GGUFArchitecture::Llama => {
363 let attn_norm = tensor_info_size_in_bytes!(
364 self.model.tensor_info("blk.0.attn_norm.weight")?,
365 DType::F32
366 );
367 let ffn_norm = tensor_info_size_in_bytes!(
368 self.model.tensor_info("blk.0.ffn_norm.weight")?,
369 DType::F32
370 );
371
372 let attn_q =
373 tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.attn_q.weight")?);
374 let attn_k =
375 tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.attn_k.weight")?);
376 let attn_v =
377 tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.attn_v.weight")?);
378 let attn_output = tensor_info_size_in_bytes!(self
379 .model
380 .tensor_info("blk.0.attn_output.weight")?);
381
382 #[allow(clippy::cast_possible_truncation)]
384 let n_expert = self
385 .model
386 .get_metadata()
387 .get("expert_count")
388 .map(|x| x.to_u64().unwrap() as usize)
389 .unwrap_or(0);
390 let moe_or_mlp = if n_expert <= 1 {
391 let ffn_gate = tensor_info_size_in_bytes!(self
392 .model
393 .tensor_info("blk.0.ffn_gate.weight")?);
394 let ffn_up = tensor_info_size_in_bytes!(self
395 .model
396 .tensor_info("blk.0.ffn_up.weight")?);
397 let ffn_down = tensor_info_size_in_bytes!(self
398 .model
399 .tensor_info("blk.0.ffn_down.weight")?);
400 ffn_gate + ffn_up + ffn_down
401 } else {
402 let mut moe_count = 0;
403 moe_count += tensor_info_size_in_bytes!(self
404 .model
405 .tensor_info("blk.0.ffn_gate_inp.weight")?);
406 match self.model.tensor_info("blk.0.ffn_gate_exps.weight") {
407 Ok(feed_forward_gate_exps) => {
408 moe_count += tensor_info_size_in_bytes!(feed_forward_gate_exps);
409 moe_count += tensor_info_size_in_bytes!(self
410 .model
411 .tensor_info("blk.0.ffn_down_exps.weight")?);
412 moe_count += tensor_info_size_in_bytes!(self
413 .model
414 .tensor_info("blk.0.ffn_up_exps.weight")?);
415 }
416 Err(_) => {
417 for i in 0..n_expert {
418 moe_count += tensor_info_size_in_bytes!(self
419 .model
420 .tensor_info(&format!("blk.0.ffn_gate.{i}.weight"),)?);
421 moe_count += tensor_info_size_in_bytes!(self
422 .model
423 .tensor_info(&format!("blk.0.ffn_down.{i}.weight"),)?);
424 moe_count += tensor_info_size_in_bytes!(self
425 .model
426 .tensor_info(&format!("blk.0.ffn_up.{i}.weight"))?);
427 }
428 }
429 }
430
431 moe_count
432 };
433 attn_norm + ffn_norm + attn_q + attn_k + attn_v + attn_output + moe_or_mlp
434 }
435 GGUFArchitecture::Phi2 => {
436 let attn_norm = tensor_info_size_in_bytes!(
437 self.model.tensor_info("blk.0.attn_norm.weight")?,
438 DType::F32
439 ) + tensor_info_size_in_bytes!(self
440 .model
441 .tensor_info("blk.0.attn_norm.bias")?);
442
443 let attn_qkv =
444 tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.attn_qkv.weight")?);
445 let attn_output = tensor_info_size_in_bytes!(self
446 .model
447 .tensor_info("blk.0.attn_output.weight")?);
448
449 let ffn_up =
450 tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.ffn_up.weight")?);
451 let ffn_down =
452 tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.ffn_down.weight")?);
453
454 attn_norm + attn_qkv + attn_output + ffn_up + ffn_down
455 }
456 GGUFArchitecture::Phi3 => {
457 let attn_norm = tensor_info_size_in_bytes!(
458 self.model.tensor_info("blk.0.attn_norm.weight")?,
459 DType::F32
460 );
461 let ffn_norm = tensor_info_size_in_bytes!(
462 self.model.tensor_info("blk.0.ffn_norm.weight")?,
463 DType::F32
464 );
465
466 let attn_qkv =
467 tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.attn_qkv.weight")?);
468 let attn_output = tensor_info_size_in_bytes!(self
469 .model
470 .tensor_info("blk.0.attn_output.weight")?);
471
472 let ffn_up =
473 tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.ffn_up.weight")?);
474 let ffn_down =
475 tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.ffn_down.weight")?);
476
477 attn_norm + ffn_norm + attn_qkv + attn_output + ffn_up + ffn_down
478 }
479 GGUFArchitecture::Qwen2 => {
480 let attn_norm = tensor_info_size_in_bytes!(
481 self.model.tensor_info("blk.0.attn_norm.weight")?,
482 DType::F32
483 );
484 let ffn_norm = tensor_info_size_in_bytes!(
485 self.model.tensor_info("blk.0.ffn_norm.weight")?,
486 DType::F32
487 );
488
489 let attn_q = tensor_info_size_in_bytes!(self
490 .model
491 .tensor_info("blk.0.attn_q.weight")?)
492 + tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.attn_q.bias")?);
493 let attn_k = tensor_info_size_in_bytes!(self
494 .model
495 .tensor_info("blk.0.attn_k.weight")?)
496 + tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.attn_k.bias")?);
497 let attn_v = tensor_info_size_in_bytes!(self
498 .model
499 .tensor_info("blk.0.attn_v.weight")?)
500 + tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.attn_v.bias")?);
501 let attn_output = tensor_info_size_in_bytes!(self
502 .model
503 .tensor_info("blk.0.attn_output.weight")?);
504
505 let ffn_gate =
506 tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.ffn_gate.weight")?);
507 let ffn_up =
508 tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.ffn_up.weight")?);
509 let ffn_down =
510 tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.ffn_down.weight")?);
511
512 attn_norm
513 + ffn_norm
514 + attn_q
515 + attn_k
516 + attn_v
517 + attn_output
518 + ffn_gate
519 + ffn_up
520 + ffn_down
521 }
522 GGUFArchitecture::Starcoder2 => {
523 let attn_norm = tensor_info_size_in_bytes!(
524 self.model.tensor_info("blk.0.attn_norm.weight")?,
525 DType::F32
526 ) + tensor_info_size_in_bytes!(self
527 .model
528 .tensor_info("blk.0.attn_norm.bias")?);
529 let ffn_norm = tensor_info_size_in_bytes!(
530 self.model.tensor_info("blk.0.ffn_norm.weight")?,
531 DType::F32
532 ) + tensor_info_size_in_bytes!(self
533 .model
534 .tensor_info("blk.0.ffn_norm.bias")?);
535
536 let attn_q = tensor_info_size_in_bytes!(self
537 .model
538 .tensor_info("blk.0.attn_q.weight")?)
539 + tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.attn_q.bias")?);
540 let attn_k = tensor_info_size_in_bytes!(self
541 .model
542 .tensor_info("blk.0.attn_k.weight")?)
543 + tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.attn_k.bias")?);
544 let attn_v = tensor_info_size_in_bytes!(self
545 .model
546 .tensor_info("blk.0.attn_v.weight")?)
547 + tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.attn_v.bias")?);
548 let attn_output = tensor_info_size_in_bytes!(self
549 .model
550 .tensor_info("blk.0.attn_output.weight")?)
551 + tensor_info_size_in_bytes!(self
552 .model
553 .tensor_info("blk.0.attn_output.bias")?);
554
555 let ffn_up = tensor_info_size_in_bytes!(self
556 .model
557 .tensor_info("blk.0.ffn_up.weight")?)
558 + tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.ffn_up.bias")?);
559 let ffn_down = tensor_info_size_in_bytes!(self
560 .model
561 .tensor_info("blk.0.ffn_down.weight")?)
562 + tensor_info_size_in_bytes!(self.model.tensor_info("blk.0.ffn_down.bias")?);
563
564 attn_norm + ffn_norm + attn_q + attn_k + attn_v + attn_output + ffn_up + ffn_down
565 }
566 _ => unimplemented!(),
567 };
568 Ok(vec![size_in_bytes; self.num_layers(config)?])
569 }
570 fn model_config(&self, _config: &str) -> Result<Box<dyn ModelConfigLike>> {
571 let model_config_metadata: ContentConfig = self.model.into();
572 Ok(Box::new(model_config_metadata))
573 }
574}