mistralrs_core/models/quantized_llama.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::collections::HashMap;
use std::sync::Arc;

use candle_core::quantized::ggml_file;
use candle_core::quantized::QTensor;
use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{Embedding, Module};
use mistralrs_quant::{GgufMatMul, QuantMethod, QuantMethodConfig};

use crate::attention::SdpaParams;
use crate::device_map::DeviceMapper;
use crate::gguf::Content;
use crate::layers::{CausalMasker, MatMul, QRmsNorm, RotaryEmbedding, Sdpa};
use crate::layers_masker::PastKvLenCache;
use crate::paged_attention::{AttentionImplementation, PagedAttention};
use crate::pipeline::extract_logits;
use crate::pipeline::text_models_inputs_processor::PagedAttentionInputMetadata;
use crate::pipeline::EitherCache;
use crate::pipeline::KvCache;
use crate::pipeline::NormalCache;
use crate::utils::gguf_metadata::ContentMetadata;
use crate::utils::model_config as ModelConfig;
use crate::utils::progress::{new_multi_progress, NiceProgressBar};

// Default fallback for models that don't specify context_length
const DEFAULT_MAX_SEQ_LEN: u32 = 4096;

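/// Gated SwiGLU feed-forward block: `w2(silu(w1(x)) * w3(x))`.
/// `w1` is the gate projection, `w3` the up projection, and `w2` the down projection,
/// all stored as quantized GGUF matmuls.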
struct Mlp {
    feed_forward_w1: Arc<dyn QuantMethod>,
    feed_forward_w2: Arc<dyn QuantMethod>,
    feed_forward_w3: Arc<dyn QuantMethod>,
}

impl Mlp {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let w1 = MatMul.qmethod_matmul(xs, &*self.feed_forward_w1)?;
        let w3 = MatMul.qmethod_matmul(xs, &*self.feed_forward_w3)?;
        let y = &(candle_nn::ops::silu(&w1)? * w3)?;
        MatMul.qmethod_matmul(y, &*self.feed_forward_w2)
    }
}

enum MlpOrMoe {
    Mlp(Mlp),
    MoE {
        n_expert_used: usize,
        feed_forward_gate_inp: Arc<dyn QuantMethod>,
        experts: Vec<Mlp>,
    },
}

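// Mixture-of-experts routing: the gate projection produces per-expert logits for every
// token, a softmax turns them into routing weights, and each token is dispatched to its
// top `n_expert_used` experts. The selected weights are renormalized so they sum to 1
// per token (e.g. top-2 weights 0.6 and 0.2 become 0.75 and 0.25) before the expert
// outputs are combined with `index_add`.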
impl MlpOrMoe {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        match self {
            Self::MoE {
                feed_forward_gate_inp,
                experts,
                n_expert_used,
            } => {
                let (b_size, seq_len, hidden_dim) = xs.dims3()?;
                let xs = xs.reshape(((), hidden_dim))?;
                let router_logits = MatMul.qmethod_matmul(&xs, &**feed_forward_gate_inp)?;
                let routing_weights = candle_nn::ops::softmax_last_dim(&router_logits)?;

                // In order to extract topk, we extract the data from the tensor and manipulate it
                // directly. Maybe we will want to use some custom ops instead at some point.
                let routing_weights = routing_weights.to_dtype(DType::F32)?.to_vec2::<f32>()?;

                // routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
                // top_x contains the row indexes to evaluate for each expert.
                let mut top_x = vec![vec![]; experts.len()];
                let mut selected_rws = vec![vec![]; experts.len()];
                for (row_idx, rw) in routing_weights.iter().enumerate() {
                    let mut dst = (0..rw.len() as u32).collect::<Vec<u32>>();
                    dst.sort_by(|&i, &j| rw[j as usize].total_cmp(&rw[i as usize]));
                    let mut sum_routing_weights = 0f32;
                    for &expert_idx in dst.iter().take(*n_expert_used) {
                        let expert_idx = expert_idx as usize;
                        let routing_weight = rw[expert_idx];
                        sum_routing_weights += routing_weight;
                        top_x[expert_idx].push(row_idx as u32);
                    }
                    for &expert_idx in dst.iter().take(*n_expert_used) {
                        let expert_idx = expert_idx as usize;
                        let routing_weight = rw[expert_idx];
                        selected_rws[expert_idx].push(routing_weight / sum_routing_weights)
                    }
                }

                // routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
                // expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)

                let mut ys = xs.zeros_like()?;
                for (expert_idx, expert_layer) in experts.iter().enumerate() {
                    let top_x = &top_x[expert_idx];
                    if top_x.is_empty() {
                        continue;
                    }
                    let top_x = Tensor::new(top_x.as_slice(), xs.device())?;
                    let selected_rws =
                        Tensor::new(selected_rws[expert_idx].as_slice(), xs.device())?
                            .reshape(((), 1))?;
                    // Index the correct hidden states and compute the expert hidden state for
                    // the current expert. We need to make sure to multiply the output hidden
                    // states by `routing_weights` on the corresponding tokens (top-1 and top-2)
                    let current_state = xs.index_select(&top_x, 0)?.reshape(((), hidden_dim))?;
                    // current_hidden_states = expert_layer(current_state, routing_weights[top_x_list, idx_list, None])
                    let current_hidden_states = expert_layer.forward(&current_state)?;
                    let current_hidden_states =
                        current_hidden_states.broadcast_mul(&selected_rws)?;
                    ys = ys.index_add(&top_x, &current_hidden_states, 0)?;
                }

                let ys = ys.reshape((b_size, seq_len, hidden_dim))?;
                Ok(ys)
            }
            Self::Mlp(mlp) => mlp.forward(xs),
        }
    }
}

struct LayerWeights {
    attention_wq: Arc<dyn QuantMethod>,
    attention_wk: Arc<dyn QuantMethod>,
    attention_wv: Arc<dyn QuantMethod>,
    attention_wo: Arc<dyn QuantMethod>,
    attention_norm: QRmsNorm,
    mlp_or_moe: MlpOrMoe,
    ffn_norm: QRmsNorm,
    n_head: usize,
    n_kv_head: usize,
    head_dim: usize,
    rotary: Arc<RotaryEmbedding>,
    paged_attn: Option<PagedAttention>,
    sdpa_params: SdpaParams,
    dtype: DType,
}

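// Attention uses grouped-query attention: `n_head` query heads share `n_kv_head`
// key/value heads, with `sdpa_params.n_kv_groups = n_head / n_kv_head`. The projections
// are quantized GGUF matmuls; activations are cast to `dtype` for the attention itself.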
impl LayerWeights {
    fn forward_attn(
        &self,
        x: &Tensor,
        mask: Option<&Tensor>,
        start_offsets: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
    ) -> Result<Tensor> {
        let (b_sz, seq_len, _) = x.dims3()?;

        let q = MatMul
            .qmethod_matmul(x, &*self.attention_wq)?
            .to_dtype(self.dtype)?;
        let k = MatMul
            .qmethod_matmul(x, &*self.attention_wk)?
            .to_dtype(self.dtype)?;
        let v = MatMul
            .qmethod_matmul(x, &*self.attention_wv)?
            .to_dtype(self.dtype)?;

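        // Reshape the flat projections to (batch, heads, seq, head_dim). For the
        // single-token decode case (seq_len == 1) the transpose would be a no-op, so the
        // target layout is produced directly by the reshape.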
        let (q, k, v) = if seq_len != 1 {
            let q = q
                .reshape((b_sz, seq_len, self.n_head, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.n_head, seq_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.n_kv_head, seq_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.n_kv_head, seq_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, k) = self.rotary.forward(&q, &k, start_offsets)?;

        let y = match &self.paged_attn {
            Some(paged_attn) => {
                let ((key_cache, value_cache), input_metadata) = metadata.unwrap();
                paged_attn.forward(
                    &q,
                    &k,
                    &v,
                    mask,
                    Some(key_cache),
                    Some(value_cache),
                    input_metadata,
                    &self.sdpa_params,
                    None,
                )?
            }
            None => {
                let (k, v) = kv_cache.append(&k, &v)?;

                Sdpa.run_attention(&q, &k, &v, mask, None, &self.sdpa_params)?
            }
        };

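        // With a mask (prompt processing) the attention output is (batch, n_head, seq,
        // head_dim) and is transposed back before flattening the heads; without a mask
        // the output can be flattened directly.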
        let y = if mask.is_some() {
            y.transpose(1, 2)?.reshape((b_sz, seq_len, ()))?
        } else {
            y.reshape((b_sz, seq_len, ()))?
        };

        let y = MatMul.qmethod_matmul(&y.to_dtype(x.dtype())?, &*self.attention_wo)?;
        Ok(y)
    }
}

pub struct ModelWeights {
    tok_embeddings: Embedding,
    layers: Vec<LayerWeights>,
    norm: QRmsNorm,
    output: Arc<dyn QuantMethod>,
    pub device: Device,
    pub cache: EitherCache,
    pub max_seq_len: usize,
    mapper: Option<Box<dyn DeviceMapper + Send + Sync>>,
    dtype: DType,
}

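// Loader for legacy GGML (pre-GGUF) llama.cpp files. The GGML header provides no
// context_length, RoPE frequency base, or RMS-norm epsilon, so DEFAULT_MAX_SEQ_LEN, a
// base of 10000, and 1e-5 are used, and the KV head count is derived from the `gqa`
// argument as `n_head / gqa`.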
impl ModelConfig::FromGGML for ModelWeights {
    fn from_ggml(mut ct: ggml_file::Content, gqa: usize, dtype: DType) -> Result<Self> {
        let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize;
        let rotary = RotaryEmbedding::new_partial(
            10000.,
            ct.hparams.n_rot as usize,
            DEFAULT_MAX_SEQ_LEN as usize,
            &ct.device,
            false,
            dtype,
        )?;
        let tok_embeddings = ct.remove("tok_embeddings.weight")?;
        let tok_embeddings = tok_embeddings.dequantize(&ct.device)?;
        let norm = QRmsNorm::new(ct.remove("norm.weight")?, 1e-5)?;
        let output = ct.remove("output.weight")?;
        let mut layers = Vec::with_capacity(ct.hparams.n_layer as usize);
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..ct.hparams.n_layer,
            "Loading repeating layers",
            &new_multi_progress(),
        ) {
            let prefix = format!("layers.{layer_idx}");
            let attention_wq = ct.remove(&format!("{prefix}.attention.wq.weight"))?;
            let attention_wk = ct.remove(&format!("{prefix}.attention.wk.weight"))?;
            let attention_wv = ct.remove(&format!("{prefix}.attention.wv.weight"))?;
            let attention_wo = ct.remove(&format!("{prefix}.attention.wo.weight"))?;
            let mlp_or_moe = {
                let feed_forward_w1 = ct.remove(&format!("{prefix}.feed_forward.w1.weight"))?;
                let feed_forward_w2 = ct.remove(&format!("{prefix}.feed_forward.w2.weight"))?;
                let feed_forward_w3 = ct.remove(&format!("{prefix}.feed_forward.w3.weight"))?;
                MlpOrMoe::Mlp(Mlp {
                    feed_forward_w1: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                        q_weight: Arc::new(feed_forward_w1),
                        b: None,
                    })?),
                    feed_forward_w2: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                        q_weight: Arc::new(feed_forward_w2),
                        b: None,
                    })?),
                    feed_forward_w3: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                        q_weight: Arc::new(feed_forward_w3),
                        b: None,
                    })?),
                })
            };
            let attention_norm = ct.remove(&format!("{prefix}.attention_norm.weight"))?;
            let ffn_norm = ct.remove(&format!("{prefix}.ffn_norm.weight"))?;
            let n_kv_head = ct.hparams.n_head as usize / gqa;
            layers.push(LayerWeights {
                attention_wq: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: Arc::new(attention_wq),
                    b: None,
                })?),
                attention_wk: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: Arc::new(attention_wk),
                    b: None,
                })?),
                attention_wv: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: Arc::new(attention_wv),
                    b: None,
                })?),
                attention_wo: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: Arc::new(attention_wo),
                    b: None,
                })?),
                attention_norm: QRmsNorm::new(attention_norm, 1e-5)?,
                mlp_or_moe,
                ffn_norm: QRmsNorm::new(ffn_norm, 1e-5)?,
                n_head: ct.hparams.n_head as usize,
                n_kv_head,
                head_dim,
                rotary: rotary.clone().into(),
                paged_attn: None, // TODO
                sdpa_params: SdpaParams {
                    n_kv_groups: ct.hparams.n_head as usize / n_kv_head,
                    softcap: None,
                    softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                    sliding_window: None,
                },
                dtype,
            })
        }
        Ok(Self {
            tok_embeddings: Embedding::new(tok_embeddings, ct.hparams.n_embd as usize),
            layers,
            norm,
            output: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                q_weight: Arc::new(output),
                b: None,
            })?),
            device: ct.device.clone(),
            cache: EitherCache::Normal(NormalCache::new(
                ct.hparams.n_layer as usize,
                DEFAULT_MAX_SEQ_LEN as usize,
            )),
            max_seq_len: DEFAULT_MAX_SEQ_LEN as usize, // Cannot determine from ggml.
            mapper: None,
            dtype,
        })
    }
}

// llama `llm` fields:
// https://github.com/ggerganov/ggml/blob/master/docs/gguf.md#llm
// NOTE: Types here do not match spec
pub(crate) struct PropsGGUF {
    pub n_expert: usize,
    pub n_expert_used: usize,
    pub head_count: usize,
    pub head_count_kv: usize,
    pub block_count: usize,
    pub embedding_length: usize,
    pub rope_dim: usize,
    pub rms_norm_eps: f32,
    pub max_seq_len: usize,
    pub rope_freq_base: f32,
    pub key_length: usize,
    pub value_length: usize,
}

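// Metadata extraction with defaults: expert counts fall back to 0 (dense model),
// context_length to DEFAULT_MAX_SEQ_LEN, rope.freq_base to 10000, and key/value lengths
// to embedding_length / head_count when the GGUF file does not provide them.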
impl TryFrom<ContentMetadata<'_>> for PropsGGUF {
    type Error = anyhow::Error;

    fn try_from(c: ContentMetadata) -> std::result::Result<Self, Self::Error> {
        c.verify_arch("llama")?;

        let required = [
            "attention.head_count",
            "attention.head_count_kv",
            "block_count",
            "embedding_length",
            "rope.dimension_count",
            "attention.layer_norm_rms_epsilon",
        ];
        c.has_required_keys(&required)?;

        let embed_len = c.get_value::<u32>("embedding_length")? as usize;
        let head_count = c.get_value::<u32>("attention.head_count")? as usize;

        // NOTE: Values are not aligned with GGUFv3 types
        // TODO: Normalize value types to spec
        let props = Self {
            n_expert: c.get_value::<u32>("expert_count").ok().unwrap_or(0) as usize,
            n_expert_used: c.get_value::<u32>("expert_used_count").ok().unwrap_or(0) as usize,
            head_count,
            head_count_kv: c.get_value::<u32>("attention.head_count_kv")? as usize,
            block_count: c.get_value::<u32>("block_count")? as usize,
            embedding_length: embed_len,
            rope_dim: c.get_value::<u32>("rope.dimension_count")? as usize,
            // This value is generally 1e-6 in GGUF files, although 1e-5 used to be the default.
            rms_norm_eps: c.get_value("attention.layer_norm_rms_epsilon")?,
            max_seq_len: c
                .get_value::<u64>("context_length")
                .ok()
                .unwrap_or(DEFAULT_MAX_SEQ_LEN as u64) as usize,
            rope_freq_base: c.get_value("rope.freq_base").ok().unwrap_or(10_000_f32),
            key_length: c
                .get_value::<u32>("attention.key_length")
                .ok()
                .map(|x| x as usize)
                .unwrap_or(embed_len / head_count),
            value_length: c
                .get_value::<u32>("attention.value_length")
                .ok()
                .map(|x| x as usize)
                .unwrap_or(embed_len / head_count),
        };

        Ok(props)
    }
}

impl ModelConfig::FromGGUF for ModelWeights {
    fn from_gguf<R: std::io::Seek + std::io::Read>(
        mut ct: Content<'_, R>,
        device: &Device,
        mapper: Box<dyn DeviceMapper + Send + Sync>,
        attention_mechanism: AttentionImplementation,
        dtype: DType,
    ) -> Result<Self> {
        // Parameter extraction from metadata.
        let metadata = ContentMetadata {
            path_prefix: "llama",
            metadata: ct.get_metadata(),
        };
        let PropsGGUF {
            n_expert,
            n_expert_used,
            head_count,
            head_count_kv,
            block_count,
            embedding_length,
            rope_dim,
            rms_norm_eps,
            max_seq_len,
            rope_freq_base,
            key_length,
            value_length,
        } = PropsGGUF::try_from(metadata).or_else(|err| candle_core::bail!("{err}"))?;

        let qtok_embeddings = ct.tensor("token_embd.weight", device)?;
        let tok_embeddings = qtok_embeddings.dequantize(device)?;
        let norm = QRmsNorm::new(ct.tensor("output_norm.weight", device)?, rms_norm_eps)?;
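        // Weight tying: if the GGUF file has no separate `output.weight`, reuse the token
        // embedding matrix as the LM head.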
        let output = if !ct.has_tensor("output.weight") {
            ct.tensor("token_embd.weight", device)?
        } else {
            ct.tensor("output.weight", device)?
        };
        let mut layers = Vec::with_capacity(block_count);

        let head_dim = key_length;
        if key_length != value_length {
            candle_core::bail!(
                "Expected key_length == value_length, got {key_length} != {value_length}"
            );
        }

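        // Build one RotaryEmbedding per device location so layers mapped to different
        // devices by the device mapper use cos/sin tables that already live on their
        // device.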
        let mut ropes = HashMap::new();
        for layer_idx in 0..block_count {
            let device = mapper.device_for(layer_idx, false).unwrap_or(device);
            ropes.insert(
                device.location(),
                Arc::new(RotaryEmbedding::new(
                    rope_freq_base,
                    rope_dim,
                    max_seq_len,
                    device,
                    false,
                    dtype,
                )?),
            );
        }

        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..block_count,
            "Loading repeating layers",
            &new_multi_progress(),
        ) {
            let prefix = format!("blk.{layer_idx}");
            let device = mapper.device_for(layer_idx, false).unwrap_or(device);
            let rotary = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();

            let attention_wq = ct.tensor(&format!("{prefix}.attn_q.weight"), device)?;
            let attention_wk = ct.tensor(&format!("{prefix}.attn_k.weight"), device)?;
            let attention_wv = ct.tensor(&format!("{prefix}.attn_v.weight"), device)?;
            let attention_wo = ct.tensor(&format!("{prefix}.attn_output.weight"), device)?;
            let mlp_or_moe = if n_expert <= 1 {
                let feed_forward_w1 = ct.tensor(&format!("{prefix}.ffn_gate.weight"), device)?;
                let feed_forward_w2 = ct.tensor(&format!("{prefix}.ffn_down.weight"), device)?;
                let feed_forward_w3 = ct.tensor(&format!("{prefix}.ffn_up.weight"), device)?;
                MlpOrMoe::Mlp(Mlp {
                    feed_forward_w1: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                        q_weight: Arc::new(feed_forward_w1),
                        b: None,
                    })?),
                    feed_forward_w2: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                        q_weight: Arc::new(feed_forward_w2),
                        b: None,
                    })?),
                    feed_forward_w3: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                        q_weight: Arc::new(feed_forward_w3),
                        b: None,
                    })?),
                })
            } else {
                let feed_forward_gate_inp =
                    ct.tensor(&format!("{prefix}.ffn_gate_inp.weight"), device)?;
                let mut experts = Vec::with_capacity(n_expert);
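                // Two GGUF layouts exist for expert weights: newer files store fused
                // `ffn_{gate,down,up}_exps` tensors that are dequantized, split per
                // expert, and re-quantized to their original types; older files store one
                // `ffn_{gate,down,up}.{i}` tensor per expert.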
                match ct.tensor(&format!("{prefix}.ffn_gate_exps.weight"), device) {
                    Ok(feed_forward_gate_exps) => {
                        let feed_forward_down_exps =
                            ct.tensor(&format!("{prefix}.ffn_down_exps.weight"), device)?;
                        let feed_forward_up_exps =
                            ct.tensor(&format!("{prefix}.ffn_up_exps.weight"), device)?;

                        let dequant_ffn_gate = feed_forward_gate_exps
                            .dequantize(device)?
                            .chunk(n_expert, 0)?;
                        let dequant_ffn_down = feed_forward_down_exps
                            .dequantize(device)?
                            .chunk(n_expert, 0)?;
                        let dequant_ffn_up = feed_forward_up_exps
                            .dequantize(device)?
                            .chunk(n_expert, 0)?;

                        assert_eq!(dequant_ffn_up.len(), dequant_ffn_down.len());
                        assert_eq!(dequant_ffn_gate.len(), dequant_ffn_down.len());
                        assert_eq!(dequant_ffn_gate.len(), n_expert);

                        let gate_type = feed_forward_gate_exps.dtype();
                        let down_type = feed_forward_down_exps.dtype();
                        let up_type = feed_forward_up_exps.dtype();

                        for (ff_w1, (ff_w2, ff_w3)) in dequant_ffn_gate
                            .into_iter()
                            .zip(dequant_ffn_down.into_iter().zip(dequant_ffn_up))
                        {
                            experts.push(Mlp {
                                feed_forward_w1: Arc::new(GgufMatMul::new(
                                    QuantMethodConfig::Gguf {
                                        q_weight: Arc::new(QTensor::quantize(&ff_w1, gate_type)?),
                                        b: None,
                                    },
                                )?),
                                feed_forward_w2: Arc::new(GgufMatMul::new(
                                    QuantMethodConfig::Gguf {
                                        q_weight: Arc::new(QTensor::quantize(&ff_w2, down_type)?),
                                        b: None,
                                    },
                                )?),
                                feed_forward_w3: Arc::new(GgufMatMul::new(
                                    QuantMethodConfig::Gguf {
                                        q_weight: Arc::new(QTensor::quantize(&ff_w3, up_type)?),
                                        b: None,
                                    },
                                )?),
                            })
                        }
                    }
                    Err(_) => {
                        for i in 0..n_expert {
                            let feed_forward_w1 =
                                ct.tensor(&format!("{prefix}.ffn_gate.{i}.weight"), device)?;
                            let feed_forward_w2 =
                                ct.tensor(&format!("{prefix}.ffn_down.{i}.weight"), device)?;
                            let feed_forward_w3 =
                                ct.tensor(&format!("{prefix}.ffn_up.{i}.weight"), device)?;
                            experts.push(Mlp {
                                feed_forward_w1: Arc::new(GgufMatMul::new(
                                    QuantMethodConfig::Gguf {
                                        q_weight: Arc::new(feed_forward_w1),
                                        b: None,
                                    },
                                )?),
                                feed_forward_w2: Arc::new(GgufMatMul::new(
                                    QuantMethodConfig::Gguf {
                                        q_weight: Arc::new(feed_forward_w2),
                                        b: None,
                                    },
                                )?),
                                feed_forward_w3: Arc::new(GgufMatMul::new(
                                    QuantMethodConfig::Gguf {
                                        q_weight: Arc::new(feed_forward_w3),
                                        b: None,
                                    },
                                )?),
                            })
                        }
                    }
                }
                MlpOrMoe::MoE {
                    n_expert_used,
                    feed_forward_gate_inp: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                        q_weight: Arc::new(feed_forward_gate_inp),
                        b: None,
                    })?),
                    experts,
                }
            };
            let attention_norm = ct.tensor(&format!("{prefix}.attn_norm.weight"), device)?;
            let ffn_norm = ct.tensor(&format!("{prefix}.ffn_norm.weight"), device)?;
            let paged_attn = match &attention_mechanism {
                AttentionImplementation::Eager => None,
                AttentionImplementation::PagedAttention => {
                    Some(PagedAttention::new(head_dim, device, None)?)
                }
            };
            layers.push(LayerWeights {
                attention_wq: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: Arc::new(attention_wq),
                    b: None,
                })?),
                attention_wk: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: Arc::new(attention_wk),
                    b: None,
                })?),
                attention_wv: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: Arc::new(attention_wv),
                    b: None,
                })?),
                attention_wo: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: Arc::new(attention_wo),
                    b: None,
                })?),
                attention_norm: QRmsNorm::new(attention_norm, rms_norm_eps)?,
                mlp_or_moe,
                ffn_norm: QRmsNorm::new(ffn_norm, rms_norm_eps)?,
                n_head: head_count,
                n_kv_head: head_count_kv,
                head_dim,
                rotary: rotary.clone(),
                paged_attn,
                sdpa_params: SdpaParams {
                    n_kv_groups: head_count / head_count_kv,
                    softcap: None,
                    softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                    sliding_window: None,
                },
                dtype,
            })
        }
        Ok(Self {
            tok_embeddings: Embedding::new(tok_embeddings, embedding_length),
            layers,
            norm,
            output: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                q_weight: Arc::new(output),
                b: None,
            })?),
            device: device.clone(),
            cache: EitherCache::Normal(NormalCache::new(block_count, max_seq_len)),
            max_seq_len,
            mapper: Some(mapper),
            dtype,
        })
    }
}

impl ModelWeights {
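    /// Run the model over token ids `x`. `start_offsets` gives the per-sequence KV cache
    /// position used by RoPE, `context_lens` selects which logits to extract, and
    /// `metadata` carries the per-layer paged-attention KV caches plus input metadata
    /// when paged attention is enabled.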
    pub fn forward(
        &self,
        x: &Tensor,
        start_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
    ) -> Result<Tensor> {
        let mut layer_in = self.tok_embeddings.forward(x)?;
        let cache = &mut self.cache.normal().0;
        let mask = CausalMasker.make_causal_mask_matrix(
            x,
            metadata
                .as_ref()
                .map(|(_, _)| &start_offsets as &dyn PastKvLenCache)
                .unwrap_or(cache as &dyn PastKvLenCache),
            self.dtype,
            self.layers[0].n_head,
        )?;
        // PagedAttention prompt chunking: only apply the causal mask on the first chunk
        // of the prompt.
        let mask = mask.filter(|_| {
            metadata
                .as_ref()
                .map(|(_, meta)| meta.is_first_prompt_chunk)
                .unwrap_or(true)
        });
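        // Pre-norm transformer blocks: move activations to each layer's device, then run
        // attention with a residual connection followed by the (MoE or dense) MLP with a
        // residual connection.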
        for (i, layer) in self.layers.iter().enumerate() {
            if let Some(ref mapper) = self.mapper {
                layer_in = mapper.map(layer_in, i)?;
            }
            let x = layer_in;
            let residual = &x;
            let x = layer.attention_norm.forward(&x)?;
            let attn = layer.forward_attn(
                &x,
                mask.as_ref()
                    .map(|m| m.to_device(x.device()).unwrap())
                    .as_ref(),
                start_offsets,
                &mut cache[i],
                metadata
                    .as_ref()
                    .map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
            )?;
            let x = (attn + residual)?;

            // MLP
            let residual = &x;
            let x = layer.ffn_norm.forward(&x)?;
            let x = layer.mlp_or_moe.forward(&x)?;
            let x = (x + residual)?;
            layer_in = x;
        }
        let layer_in = layer_in.to_device(&self.device)?;
        let x = self.norm.forward(&layer_in)?;
        extract_logits(
            &MatMul.qmethod_matmul(&x.contiguous()?, &*self.output)?,
            context_lens,
        )
    }
707}