mistralrs_core/models/quantized_llama.rs
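//! Quantized Llama-architecture model loaded from GGML or GGUF files, with support
//! for sparse mixture-of-experts (MoE) layers and optional PagedAttention.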

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::collections::HashMap;
use std::sync::Arc;

use candle_core::quantized::ggml_file;
use candle_core::quantized::QTensor;
use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{Embedding, Module};
use indicatif::MultiProgress;
use mistralrs_quant::{GgufMatMul, QuantMethod, QuantMethodConfig};

use crate::attention::SdpaParams;
use crate::device_map::DeviceMapper;
use crate::gguf::Content;
use crate::layers::{CausalMasker, MatMul, QRmsNorm, RotaryEmbedding, Sdpa};
use crate::layers_masker::PastKvLenCache;
use crate::paged_attention::{AttentionImplementation, PagedAttention};
use crate::pipeline::extract_logits;
use crate::pipeline::text_models_inputs_processor::PagedAttentionInputMetadata;
use crate::pipeline::EitherCache;
use crate::pipeline::KvCache;
use crate::pipeline::NormalCache;
use crate::utils::gguf_metadata::ContentMetadata;
use crate::utils::model_config as ModelConfig;
use crate::utils::progress::NiceProgressBar;

// Default fallback for models that don't specify context_length
const DEFAULT_MAX_SEQ_LEN: u32 = 4096;

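/// Gated feed-forward block (SwiGLU-style): computes `w2(silu(w1(x)) * w3(x))`
/// with quantized matmuls.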
struct Mlp {
    feed_forward_w1: Arc<dyn QuantMethod>,
    feed_forward_w2: Arc<dyn QuantMethod>,
    feed_forward_w3: Arc<dyn QuantMethod>,
}

impl Mlp {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let w1 = MatMul.qmethod_matmul(xs, &*self.feed_forward_w1)?;
        let w3 = MatMul.qmethod_matmul(xs, &*self.feed_forward_w3)?;
        let y = &(candle_nn::ops::silu(&w1)? * w3)?;
        MatMul.qmethod_matmul(y, &*self.feed_forward_w2)
    }
}

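/// Either a dense [`Mlp`] or a sparse mixture-of-experts block that routes each
/// token to its top `n_expert_used` experts via a learned gating projection.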
enum MlpOrMoe {
    Mlp(Mlp),
    MoE {
        n_expert_used: usize,
        feed_forward_gate_inp: Arc<dyn QuantMethod>,
        experts: Vec<Mlp>,
    },
}

impl MlpOrMoe {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        match self {
            Self::MoE {
                feed_forward_gate_inp,
                experts,
                n_expert_used,
            } => {
                let (b_size, seq_len, hidden_dim) = xs.dims3()?;
                let xs = xs.reshape(((), hidden_dim))?;
                let router_logits = MatMul.qmethod_matmul(&xs, &**feed_forward_gate_inp)?;
                let routing_weights = candle_nn::ops::softmax_last_dim(&router_logits)?;

                // To extract the top-k experts we pull the routing weights out of the tensor
                // and manipulate them directly; a custom op may replace this at some point.
                let routing_weights = routing_weights.to_dtype(DType::F32)?.to_vec2::<f32>()?;

                // routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
                // top_x contains the row indexes to evaluate for each expert.
                let mut top_x = vec![vec![]; experts.len()];
                let mut selected_rws = vec![vec![]; experts.len()];
                for (row_idx, rw) in routing_weights.iter().enumerate() {
                    let mut dst = (0..rw.len() as u32).collect::<Vec<u32>>();
                    dst.sort_by(|&i, &j| rw[j as usize].total_cmp(&rw[i as usize]));
                    let mut sum_routing_weights = 0f32;
                    for &expert_idx in dst.iter().take(*n_expert_used) {
                        let expert_idx = expert_idx as usize;
                        let routing_weight = rw[expert_idx];
                        sum_routing_weights += routing_weight;
                        top_x[expert_idx].push(row_idx as u32);
                    }
                    for &expert_idx in dst.iter().take(*n_expert_used) {
                        let expert_idx = expert_idx as usize;
                        let routing_weight = rw[expert_idx];
                        selected_rws[expert_idx].push(routing_weight / sum_routing_weights)
                    }
                }

                // routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
                // expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)

                let mut ys = xs.zeros_like()?;
                for (expert_idx, expert_layer) in experts.iter().enumerate() {
                    let top_x = &top_x[expert_idx];
                    if top_x.is_empty() {
                        continue;
                    }
                    let top_x = Tensor::new(top_x.as_slice(), xs.device())?;
                    let selected_rws =
                        Tensor::new(selected_rws[expert_idx].as_slice(), xs.device())?
                            .reshape(((), 1))?;
                    // Index the hidden states for the rows routed to this expert and run
                    // them through it, then scale the expert output by the corresponding
                    // normalized routing weights before accumulating into `ys`.
                    let current_state = xs.index_select(&top_x, 0)?.reshape(((), hidden_dim))?;
                    // current_hidden_states = expert_layer(current_state, routing_weights[top_x_list, idx_list, None])
                    let current_hidden_states = expert_layer.forward(&current_state)?;
                    let current_hidden_states =
                        current_hidden_states.broadcast_mul(&selected_rws)?;
                    ys = ys.index_add(&top_x, &current_hidden_states, 0)?;
                }

                let ys = ys.reshape((b_size, seq_len, hidden_dim))?;
                Ok(ys)
            }
            Self::Mlp(mlp) => mlp.forward(xs),
        }
    }
}

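/// Weights for one transformer block: attention projections and norm, the MLP/MoE
/// block and its norm, plus the per-layer RoPE, optional PagedAttention, and SDPA
/// parameters.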
struct LayerWeights {
    attention_wq: Arc<dyn QuantMethod>,
    attention_wk: Arc<dyn QuantMethod>,
    attention_wv: Arc<dyn QuantMethod>,
    attention_wo: Arc<dyn QuantMethod>,
    attention_norm: QRmsNorm,
    mlp_or_moe: MlpOrMoe,
    ffn_norm: QRmsNorm,
    n_head: usize,
    n_kv_head: usize,
    head_dim: usize,
    rotary: Arc<RotaryEmbedding>,
    paged_attn: Option<PagedAttention>,
    sdpa_params: SdpaParams,
    dtype: DType,
}

impl LayerWeights {
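    /// Self-attention for one block: QKV projections, RoPE, KV-cache append (or
    /// PagedAttention), scaled dot-product attention, then the output projection.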
    fn forward_attn(
        &self,
        x: &Tensor,
        mask: Option<&Tensor>,
        start_offsets: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
    ) -> Result<Tensor> {
        let (b_sz, seq_len, _) = x.dims3()?;

        let q = MatMul
            .qmethod_matmul(x, &*self.attention_wq)?
            .to_dtype(self.dtype)?;
        let k = MatMul
            .qmethod_matmul(x, &*self.attention_wk)?
            .to_dtype(self.dtype)?;
        let v = MatMul
            .qmethod_matmul(x, &*self.attention_wv)?
            .to_dtype(self.dtype)?;

        let (q, k, v) = if seq_len != 1 {
            let q = q
                .reshape((b_sz, seq_len, self.n_head, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.n_head, seq_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.n_kv_head, seq_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.n_kv_head, seq_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, k) = self.rotary.forward(&q, &k, start_offsets)?;

        let y = match &self.paged_attn {
            Some(paged_attn) => {
                let ((key_cache, value_cache), input_metadata) = metadata.unwrap();
                paged_attn.forward(
                    &q,
                    &k,
                    &v,
                    mask,
                    Some(key_cache),
                    Some(value_cache),
                    input_metadata,
                    &self.sdpa_params,
                    None,
                )?
            }
            None => {
                let (k, v) = kv_cache.append(&k, &v)?;

                Sdpa.run_attention(&q, &k, &v, mask, None, &self.sdpa_params)?
            }
        };

        let y = if mask.is_some() {
            y.transpose(1, 2)?.reshape((b_sz, seq_len, ()))?
        } else {
            y.reshape((b_sz, seq_len, ()))?
        };

        let y = MatMul.qmethod_matmul(&y.to_dtype(x.dtype())?, &*self.attention_wo)?;
        Ok(y)
    }
}

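/// The full quantized model: token embeddings, transformer blocks, final RMS norm,
/// and the (possibly tied) output projection, plus the KV cache and an optional
/// device mapper.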
pub struct ModelWeights {
    tok_embeddings: Embedding,
    layers: Vec<LayerWeights>,
    norm: QRmsNorm,
    output: Arc<dyn QuantMethod>,
    pub device: Device,
    pub cache: EitherCache,
    pub max_seq_len: usize,
    mapper: Option<Box<dyn DeviceMapper + Send + Sync>>,
    dtype: DType,
}

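// Legacy GGML loading path. GGML files do not store a context length, so
// DEFAULT_MAX_SEQ_LEN is used, and all weights are loaded onto a single device.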
impl ModelConfig::FromGGML for ModelWeights {
    fn from_ggml(mut ct: ggml_file::Content, gqa: usize, dtype: DType) -> Result<Self> {
        let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize;
        let rotary = RotaryEmbedding::new_partial(
            10000.,
            ct.hparams.n_rot as usize,
            DEFAULT_MAX_SEQ_LEN as usize,
            &ct.device,
            false,
            dtype,
        )?;
        let tok_embeddings = ct.remove("tok_embeddings.weight")?;
        let tok_embeddings = tok_embeddings.dequantize(&ct.device)?;
        let norm = QRmsNorm::new(ct.remove("norm.weight")?, 1e-5)?;
        let output = ct.remove("output.weight")?;
        let mut layers = Vec::with_capacity(ct.hparams.n_layer as usize);
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..ct.hparams.n_layer,
            "Loading repeating layers",
            &MultiProgress::new(),
        ) {
            let prefix = format!("layers.{layer_idx}");
            let attention_wq = ct.remove(&format!("{prefix}.attention.wq.weight"))?;
            let attention_wk = ct.remove(&format!("{prefix}.attention.wk.weight"))?;
            let attention_wv = ct.remove(&format!("{prefix}.attention.wv.weight"))?;
            let attention_wo = ct.remove(&format!("{prefix}.attention.wo.weight"))?;
            let mlp_or_moe = {
                let feed_forward_w1 = ct.remove(&format!("{prefix}.feed_forward.w1.weight"))?;
                let feed_forward_w2 = ct.remove(&format!("{prefix}.feed_forward.w2.weight"))?;
                let feed_forward_w3 = ct.remove(&format!("{prefix}.feed_forward.w3.weight"))?;
                MlpOrMoe::Mlp(Mlp {
                    feed_forward_w1: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                        q_weight: Arc::new(feed_forward_w1),
                        b: None,
                    })?),
                    feed_forward_w2: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                        q_weight: Arc::new(feed_forward_w2),
                        b: None,
                    })?),
                    feed_forward_w3: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                        q_weight: Arc::new(feed_forward_w3),
                        b: None,
                    })?),
                })
            };
            let attention_norm = ct.remove(&format!("{prefix}.attention_norm.weight"))?;
            let ffn_norm = ct.remove(&format!("{prefix}.ffn_norm.weight"))?;
            let n_kv_head = ct.hparams.n_head as usize / gqa;
            layers.push(LayerWeights {
                attention_wq: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: Arc::new(attention_wq),
                    b: None,
                })?),
                attention_wk: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: Arc::new(attention_wk),
                    b: None,
                })?),
                attention_wv: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: Arc::new(attention_wv),
                    b: None,
                })?),
                attention_wo: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: Arc::new(attention_wo),
                    b: None,
                })?),
                attention_norm: QRmsNorm::new(attention_norm, 1e-5)?,
                mlp_or_moe,
                ffn_norm: QRmsNorm::new(ffn_norm, 1e-5)?,
                n_head: ct.hparams.n_head as usize,
                n_kv_head,
                head_dim,
                rotary: rotary.clone().into(),
                paged_attn: None, // TODO
                sdpa_params: SdpaParams {
                    n_kv_groups: ct.hparams.n_head as usize / n_kv_head,
                    softcap: None,
                    softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                    sliding_window: None,
                },
                dtype,
            })
        }
        Ok(Self {
            tok_embeddings: Embedding::new(tok_embeddings, ct.hparams.n_embd as usize),
            layers,
            norm,
            output: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                q_weight: Arc::new(output),
                b: None,
            })?),
            device: ct.device.clone(),
            cache: EitherCache::Normal(NormalCache::new(
                ct.hparams.n_layer as usize,
                DEFAULT_MAX_SEQ_LEN as usize,
            )),
            max_seq_len: DEFAULT_MAX_SEQ_LEN as usize, // Cannot determine from ggml.
            mapper: None,
            dtype,
        })
    }
}

// llama `llm` fields:
// https://github.com/ggerganov/ggml/blob/master/docs/gguf.md#llm
// NOTE: Types here do not match spec
pub(crate) struct PropsGGUF {
    pub n_expert: usize,
    pub n_expert_used: usize,
    pub head_count: usize,
    pub head_count_kv: usize,
    pub block_count: usize,
    pub embedding_length: usize,
    pub rope_dim: usize,
    pub rms_norm_eps: f32,
    pub max_seq_len: usize,
    pub rope_freq_base: f32,
    pub key_length: usize,
    pub value_length: usize,
}

impl TryFrom<ContentMetadata<'_>> for PropsGGUF {
    type Error = anyhow::Error;

    fn try_from(c: ContentMetadata) -> std::result::Result<Self, Self::Error> {
        c.verify_arch("llama")?;

        let required = [
            "attention.head_count",
            "attention.head_count_kv",
            "block_count",
            "embedding_length",
            "rope.dimension_count",
            "attention.layer_norm_rms_epsilon",
        ];
        c.has_required_keys(&required)?;

        let embed_len = c.get_value::<u32>("embedding_length")? as usize;
        let head_count = c.get_value::<u32>("attention.head_count")? as usize;

        // NOTE: Values are not aligned with GGUFv3 types
        // TODO: Normalize value types to spec
        let props = Self {
            n_expert: c.get_value::<u32>("expert_count").ok().unwrap_or(0) as usize,
            n_expert_used: c.get_value::<u32>("expert_used_count").ok().unwrap_or(0) as usize,
            head_count,
            head_count_kv: c.get_value::<u32>("attention.head_count_kv")? as usize,
            block_count: c.get_value::<u32>("block_count")? as usize,
            embedding_length: embed_len,
            rope_dim: c.get_value::<u32>("rope.dimension_count")? as usize,
            // This value is typically 1e-6 in GGUF files, although the default used to be 1e-5.
            rms_norm_eps: c.get_value("attention.layer_norm_rms_epsilon")?,
            max_seq_len: c
                .get_value::<u64>("context_length")
                .ok()
                .unwrap_or(DEFAULT_MAX_SEQ_LEN as u64) as usize,
            rope_freq_base: c.get_value("rope.freq_base").ok().unwrap_or(10_000_f32),
            key_length: c
                .get_value::<u32>("attention.key_length")
                .ok()
                .map(|x| x as usize)
                .unwrap_or(embed_len / head_count),
            value_length: c
                .get_value::<u32>("attention.value_length")
                .ok()
                .map(|x| x as usize)
                .unwrap_or(embed_len / head_count),
        };

        Ok(props)
    }
}

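// GGUF loading path: hyperparameters come from the metadata parsed above, layers
// may be mapped across devices, and attention runs either eagerly or through
// PagedAttention.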
impl ModelConfig::FromGGUF for ModelWeights {
    fn from_gguf<R: std::io::Seek + std::io::Read>(
        mut ct: Content<'_, R>,
        device: &Device,
        mapper: Box<dyn DeviceMapper + Send + Sync>,
        attention_mechanism: AttentionImplementation,
        dtype: DType,
    ) -> Result<Self> {
        // Parameter extraction from metadata.
        let metadata = ContentMetadata {
            path_prefix: "llama",
            metadata: ct.get_metadata(),
        };
        let PropsGGUF {
            n_expert,
            n_expert_used,
            head_count,
            head_count_kv,
            block_count,
            embedding_length,
            rope_dim,
            rms_norm_eps,
            max_seq_len,
            rope_freq_base,
            key_length,
            value_length,
        } = PropsGGUF::try_from(metadata).or_else(|err| candle_core::bail!("{err}"))?;

        let qtok_embeddings = ct.tensor("token_embd.weight", device)?;
        let tok_embeddings = qtok_embeddings.dequantize(device)?;
        let norm = QRmsNorm::new(ct.tensor("output_norm.weight", device)?, rms_norm_eps)?;
        let output = if ct.has_tensor("output.weight") {
            ct.tensor("output.weight", device)?
        } else {
            ct.tensor("token_embd.weight", device)?
        };
        let mut layers = Vec::with_capacity(block_count);

        if key_length != value_length {
            candle_core::bail!(
                "Expected key_length == value_length, got {key_length} != {value_length}"
            );
        }
        let head_dim = key_length;

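        // Build one rotary-embedding table per device location so that all layers
        // mapped to the same device share a single RoPE cache.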
        let mut ropes = HashMap::new();
        for layer_idx in 0..block_count {
            let device = mapper.device_for(layer_idx, false).unwrap_or(device);
            ropes.insert(
                device.location(),
                Arc::new(RotaryEmbedding::new(
                    rope_freq_base,
                    rope_dim,
                    max_seq_len,
                    device,
                    false,
                    dtype,
                )?),
            );
        }

        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..block_count,
            "Loading repeating layers",
            &MultiProgress::new(),
        ) {
            let prefix = format!("blk.{layer_idx}");
            let device = mapper.device_for(layer_idx, false).unwrap_or(device);
            let rotary = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();

            let attention_wq = ct.tensor(&format!("{prefix}.attn_q.weight"), device)?;
            let attention_wk = ct.tensor(&format!("{prefix}.attn_k.weight"), device)?;
            let attention_wv = ct.tensor(&format!("{prefix}.attn_v.weight"), device)?;
            let attention_wo = ct.tensor(&format!("{prefix}.attn_output.weight"), device)?;
            let mlp_or_moe = if n_expert <= 1 {
                let feed_forward_w1 = ct.tensor(&format!("{prefix}.ffn_gate.weight"), device)?;
                let feed_forward_w2 = ct.tensor(&format!("{prefix}.ffn_down.weight"), device)?;
                let feed_forward_w3 = ct.tensor(&format!("{prefix}.ffn_up.weight"), device)?;
                MlpOrMoe::Mlp(Mlp {
                    feed_forward_w1: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                        q_weight: Arc::new(feed_forward_w1),
                        b: None,
                    })?),
                    feed_forward_w2: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                        q_weight: Arc::new(feed_forward_w2),
                        b: None,
                    })?),
                    feed_forward_w3: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                        q_weight: Arc::new(feed_forward_w3),
                        b: None,
                    })?),
                })
            } else {
                let feed_forward_gate_inp =
                    ct.tensor(&format!("{prefix}.ffn_gate_inp.weight"), device)?;
                let mut experts = Vec::with_capacity(n_expert);
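                // GGUF stores MoE experts either stacked in single `ffn_*_exps` tensors
                // or as one tensor per expert (`ffn_*.{i}`); handle both layouts. Stacked
                // experts are dequantized, split along dim 0, and re-quantized per expert.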
                match ct.tensor(&format!("{prefix}.ffn_gate_exps.weight"), device) {
                    Ok(feed_forward_gate_exps) => {
                        let feed_forward_down_exps =
                            ct.tensor(&format!("{prefix}.ffn_down_exps.weight"), device)?;
                        let feed_forward_up_exps =
                            ct.tensor(&format!("{prefix}.ffn_up_exps.weight"), device)?;

                        let dequant_ffn_gate = feed_forward_gate_exps
                            .dequantize(device)?
                            .chunk(n_expert, 0)?;
                        let dequant_ffn_down = feed_forward_down_exps
                            .dequantize(device)?
                            .chunk(n_expert, 0)?;
                        let dequant_ffn_up = feed_forward_up_exps
                            .dequantize(device)?
                            .chunk(n_expert, 0)?;

                        assert_eq!(dequant_ffn_up.len(), dequant_ffn_down.len());
                        assert_eq!(dequant_ffn_gate.len(), dequant_ffn_down.len());
                        assert_eq!(dequant_ffn_gate.len(), n_expert);

                        let gate_type = feed_forward_gate_exps.dtype();
                        let down_type = feed_forward_down_exps.dtype();
                        let up_type = feed_forward_up_exps.dtype();

                        for (ff_w1, (ff_w2, ff_w3)) in dequant_ffn_gate
                            .into_iter()
                            .zip(dequant_ffn_down.into_iter().zip(dequant_ffn_up))
                        {
                            experts.push(Mlp {
                                feed_forward_w1: Arc::new(GgufMatMul::new(
                                    QuantMethodConfig::Gguf {
                                        q_weight: Arc::new(QTensor::quantize(&ff_w1, gate_type)?),
                                        b: None,
                                    },
                                )?),
                                feed_forward_w2: Arc::new(GgufMatMul::new(
                                    QuantMethodConfig::Gguf {
                                        q_weight: Arc::new(QTensor::quantize(&ff_w2, down_type)?),
                                        b: None,
                                    },
                                )?),
                                feed_forward_w3: Arc::new(GgufMatMul::new(
                                    QuantMethodConfig::Gguf {
                                        q_weight: Arc::new(QTensor::quantize(&ff_w3, up_type)?),
                                        b: None,
                                    },
                                )?),
                            })
                        }
                    }
                    Err(_) => {
                        for i in 0..n_expert {
                            let feed_forward_w1 =
                                ct.tensor(&format!("{prefix}.ffn_gate.{i}.weight"), device)?;
                            let feed_forward_w2 =
                                ct.tensor(&format!("{prefix}.ffn_down.{i}.weight"), device)?;
                            let feed_forward_w3 =
                                ct.tensor(&format!("{prefix}.ffn_up.{i}.weight"), device)?;
                            experts.push(Mlp {
                                feed_forward_w1: Arc::new(GgufMatMul::new(
                                    QuantMethodConfig::Gguf {
                                        q_weight: Arc::new(feed_forward_w1),
                                        b: None,
                                    },
                                )?),
                                feed_forward_w2: Arc::new(GgufMatMul::new(
                                    QuantMethodConfig::Gguf {
                                        q_weight: Arc::new(feed_forward_w2),
                                        b: None,
                                    },
                                )?),
                                feed_forward_w3: Arc::new(GgufMatMul::new(
                                    QuantMethodConfig::Gguf {
                                        q_weight: Arc::new(feed_forward_w3),
                                        b: None,
                                    },
                                )?),
                            })
                        }
                    }
                }
                MlpOrMoe::MoE {
                    n_expert_used,
                    feed_forward_gate_inp: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                        q_weight: Arc::new(feed_forward_gate_inp),
                        b: None,
                    })?),
                    experts,
                }
            };
            let attention_norm = ct.tensor(&format!("{prefix}.attn_norm.weight"), device)?;
            let ffn_norm = ct.tensor(&format!("{prefix}.ffn_norm.weight"), device)?;
            let paged_attn = match &attention_mechanism {
                AttentionImplementation::Eager => None,
                AttentionImplementation::PagedAttention => {
                    Some(PagedAttention::new(head_dim, device, None)?)
                }
            };
            layers.push(LayerWeights {
                attention_wq: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: Arc::new(attention_wq),
                    b: None,
                })?),
                attention_wk: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: Arc::new(attention_wk),
                    b: None,
                })?),
                attention_wv: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: Arc::new(attention_wv),
                    b: None,
                })?),
                attention_wo: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: Arc::new(attention_wo),
                    b: None,
                })?),
                attention_norm: QRmsNorm::new(attention_norm, rms_norm_eps)?,
                mlp_or_moe,
                ffn_norm: QRmsNorm::new(ffn_norm, rms_norm_eps)?,
                n_head: head_count,
                n_kv_head: head_count_kv,
                head_dim,
                rotary: rotary.clone(),
                paged_attn,
                sdpa_params: SdpaParams {
                    n_kv_groups: head_count / head_count_kv,
                    softcap: None,
                    softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                    sliding_window: None,
                },
                dtype,
            })
        }
        Ok(Self {
            tok_embeddings: Embedding::new(tok_embeddings, embedding_length),
            layers,
            norm,
            output: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                q_weight: Arc::new(output),
                b: None,
            })?),
            device: device.clone(),
            cache: EitherCache::Normal(NormalCache::new(block_count, max_seq_len)),
            max_seq_len,
            mapper: Some(mapper),
            dtype,
        })
    }
}

impl ModelWeights {
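    /// Run the model over token ids `x` and return logits for the positions selected
    /// by `context_lens`. `start_offsets` holds each sequence's current KV offset
    /// (used for RoPE and causal masking); `metadata` carries the PagedAttention KV
    /// caches and input metadata when that backend is in use.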
    pub fn forward(
        &self,
        x: &Tensor,
        start_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
    ) -> Result<Tensor> {
        let mut layer_in = self.tok_embeddings.forward(x)?;
        let cache = &mut self.cache.normal().0;
        let mask = CausalMasker.make_causal_mask_matrix(
            x,
            metadata
                .as_ref()
                .map(|(_, _)| &start_offsets as &dyn PastKvLenCache)
                .unwrap_or(cache as &dyn PastKvLenCache),
            self.dtype,
            self.layers[0].n_head,
        )?;
        // PagedAttention prompt chunking
        let mask = mask.filter(|_| {
            metadata
                .as_ref()
                .map(|(_, meta)| meta.is_first_prompt_chunk)
                .unwrap_or(true)
        });
        for (i, layer) in self.layers.iter().enumerate() {
            if let Some(ref mapper) = self.mapper {
                layer_in = mapper.map(layer_in, i)?;
            }
            let x = layer_in;
            let residual = &x;
            let x = layer.attention_norm.forward(&x)?;
            let attn = layer.forward_attn(
                &x,
                mask.as_ref()
                    .map(|m| m.to_device(x.device()).unwrap())
                    .as_ref(),
                start_offsets,
                &mut cache[i],
                metadata
                    .as_ref()
                    .map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
            )?;
            let x = (attn + residual)?;

            // MLP
            let residual = &x;
            let x = layer.ffn_norm.forward(&x)?;
            let x = layer.mlp_or_moe.forward(&x)?;
            let x = (x + residual)?;
            layer_in = x;
        }
        let layer_in = layer_in.to_device(&self.device)?;
        let x = self.norm.forward(&layer_in)?;
        extract_logits(
            &MatMul.qmethod_matmul(&x.contiguous()?, &*self.output)?,
            context_lens,
        )
    }
}