mistralrs_core/models/quantized_llama.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::collections::HashMap;
use std::sync::Arc;

use candle_core::quantized::ggml_file;
use candle_core::quantized::QTensor;
use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{Embedding, Module};
use indicatif::MultiProgress;
use mistralrs_quant::{GgufMatMul, QuantMethod, QuantMethodConfig};

use crate::attention::SdpaParams;
use crate::device_map::DeviceMapper;
use crate::gguf::Content;
use crate::layers::{CausalMasker, MatMul, QRmsNorm, RotaryEmbedding, Sdpa};
use crate::layers_masker::PastKvLenCache;
use crate::paged_attention::{AttentionImplementation, PagedAttention};
use crate::pipeline::extract_logits;
use crate::pipeline::text_models_inputs_processor::PagedAttentionInputMetadata;
use crate::pipeline::EitherCache;
use crate::pipeline::KvCache;
use crate::pipeline::NormalCache;
use crate::utils::gguf_metadata::ContentMetadata;
use crate::utils::model_config as ModelConfig;
use crate::utils::progress::NiceProgressBar;

const MAX_SEQ_LEN: u32 = 4096;

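/// Llama feed-forward block. In GGUF terms: `w1` = `ffn_gate`, `w2` = `ffn_down`,
/// `w3` = `ffn_up` (see the tensor names used by the loaders below).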
struct Mlp {
    feed_forward_w1: Arc<dyn QuantMethod>,
    feed_forward_w2: Arc<dyn QuantMethod>,
    feed_forward_w3: Arc<dyn QuantMethod>,
}

impl Mlp {
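    /// SwiGLU: compute silu(w1(x)) * w3(x), then project back down through w2.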
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let w1 = MatMul.qmethod_matmul(xs, &*self.feed_forward_w1)?;
        let w3 = MatMul.qmethod_matmul(xs, &*self.feed_forward_w3)?;
        let y = &(candle_nn::ops::silu(&w1)? * w3)?;
        MatMul.qmethod_matmul(y, &*self.feed_forward_w2)
    }
}

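/// Either a dense MLP or a Mixtral-style sparse MoE layer that routes each
/// token to `n_expert_used` of the available experts.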
enum MlpOrMoe {
    Mlp(Mlp),
    MoE {
        n_expert_used: usize,
        feed_forward_gate_inp: Arc<dyn QuantMethod>,
        experts: Vec<Mlp>,
    },
}

impl MlpOrMoe {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        match self {
            Self::MoE {
                feed_forward_gate_inp,
                experts,
                n_expert_used,
            } => {
                let (b_size, seq_len, hidden_dim) = xs.dims3()?;
                let xs = xs.reshape(((), hidden_dim))?;
                let router_logits = MatMul.qmethod_matmul(&xs, &**feed_forward_gate_inp)?;
                let routing_weights = candle_nn::ops::softmax_last_dim(&router_logits)?;

                // To extract the top-k experts we copy the routing weights to the
                // host and manipulate them directly; a custom op could replace
                // this at some point.
                let routing_weights = routing_weights.to_dtype(DType::F32)?.to_vec2::<f32>()?;

                // routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
                // top_x contains the row indices to evaluate for each expert.
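                // E.g. for a row of routing weights [0.1, 0.6, 0.3] with
                // n_expert_used = 2: experts 1 and 2 are selected, this row index
                // is pushed onto their `top_x` lists, and their weights are
                // renormalized to 0.6/0.9 and 0.3/0.9.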
                let mut top_x = vec![vec![]; experts.len()];
                let mut selected_rws = vec![vec![]; experts.len()];
                for (row_idx, rw) in routing_weights.iter().enumerate() {
                    let mut dst = (0..rw.len() as u32).collect::<Vec<u32>>();
                    dst.sort_by(|&i, &j| rw[j as usize].total_cmp(&rw[i as usize]));
                    let mut sum_routing_weights = 0f32;
                    for &expert_idx in dst.iter().take(*n_expert_used) {
                        let expert_idx = expert_idx as usize;
                        let routing_weight = rw[expert_idx];
                        sum_routing_weights += routing_weight;
                        top_x[expert_idx].push(row_idx as u32);
                    }
                    for &expert_idx in dst.iter().take(*n_expert_used) {
                        let expert_idx = expert_idx as usize;
                        let routing_weight = rw[expert_idx];
                        selected_rws[expert_idx].push(routing_weight / sum_routing_weights)
                    }
                }

                // routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
                // expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)

                let mut ys = xs.zeros_like()?;
                for (expert_idx, expert_layer) in experts.iter().enumerate() {
                    let top_x = &top_x[expert_idx];
                    if top_x.is_empty() {
                        continue;
                    }
                    let top_x = Tensor::new(top_x.as_slice(), xs.device())?;
                    let selected_rws =
                        Tensor::new(selected_rws[expert_idx].as_slice(), xs.device())?
                            .reshape(((), 1))?;
                    // Index the correct hidden states and compute the expert hidden state for
                    // the current expert. We need to make sure to multiply the output hidden
                    // states by `routing_weights` on the corresponding tokens (top-1 and top-2)
                    let current_state = xs.index_select(&top_x, 0)?.reshape(((), hidden_dim))?;
                    // current_hidden_states = expert_layer(current_state, routing_weights[top_x_list, idx_list, None])
                    let current_hidden_states = expert_layer.forward(&current_state)?;
                    let current_hidden_states =
                        current_hidden_states.broadcast_mul(&selected_rws)?;
                    ys = ys.index_add(&top_x, &current_hidden_states, 0)?;
                }

                let ys = ys.reshape((b_size, seq_len, hidden_dim))?;
                Ok(ys)
            }
            Self::Mlp(mlp) => mlp.forward(xs),
        }
    }
}

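/// One transformer block: quantized attention projections, RMSNorms, the
/// (possibly MoE) feed-forward, and per-layer RoPE / attention state.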
struct LayerWeights {
    attention_wq: Arc<dyn QuantMethod>,
    attention_wk: Arc<dyn QuantMethod>,
    attention_wv: Arc<dyn QuantMethod>,
    attention_wo: Arc<dyn QuantMethod>,
    attention_norm: QRmsNorm,
    mlp_or_moe: MlpOrMoe,
    ffn_norm: QRmsNorm,
    n_head: usize,
    n_kv_head: usize,
    head_dim: usize,
    rotary: Arc<RotaryEmbedding>,
    paged_attn: Option<PagedAttention>,
    sdpa_params: SdpaParams,
    dtype: DType,
}

impl LayerWeights {
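    /// Attention for one block: quantized q/k/v projections, RoPE, then either
    /// PagedAttention (when cache metadata is supplied) or SDPA over this
    /// layer's KV cache.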
    fn forward_attn(
        &self,
        x: &Tensor,
        mask: Option<&Tensor>,
        start_offsets: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
    ) -> Result<Tensor> {
        let (b_sz, seq_len, _) = x.dims3()?;

        let q = MatMul
            .qmethod_matmul(x, &*self.attention_wq)?
            .to_dtype(self.dtype)?;
        let k = MatMul
            .qmethod_matmul(x, &*self.attention_wk)?
            .to_dtype(self.dtype)?;
        let v = MatMul
            .qmethod_matmul(x, &*self.attention_wv)?
            .to_dtype(self.dtype)?;

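        // With seq_len == 1 (a decode step) the (b, s, h, d) -> (b, h, s, d)
        // transpose is a no-op on the data, so reshape directly and skip it.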
        let (q, k, v) = if seq_len != 1 {
            let q = q
                .reshape((b_sz, seq_len, self.n_head, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.n_head, seq_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.n_kv_head, seq_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.n_kv_head, seq_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, k) = self.rotary.forward(&q, &k, start_offsets)?;

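        // PagedAttention manages the KV cache externally through the supplied
        // key/value cache tensors and metadata; the eager path appends to this
        // layer's KvCache instead.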
        let y = match &self.paged_attn {
            Some(paged_attn) => {
                let ((key_cache, value_cache), input_metadata) = metadata.unwrap();
                paged_attn.forward(
                    &q,
                    &k,
                    &v,
                    mask,
                    Some(key_cache),
                    Some(value_cache),
                    input_metadata,
                    &self.sdpa_params,
                    None,
                )?
            }
            None => {
                let (k, v) = kv_cache.append(&k, &v)?;

                Sdpa.run_attention(&q, &k, &v, mask, None, &self.sdpa_params)?
            }
        };

        let y = if mask.is_some() {
            y.transpose(1, 2)?.reshape((b_sz, seq_len, ()))?
        } else {
            y.reshape((b_sz, seq_len, ()))?
        };

        let y = MatMul.qmethod_matmul(&y.to_dtype(x.dtype())?, &*self.attention_wo)?;
        Ok(y)
    }
}

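/// The full quantized llama model, as loaded from either legacy GGML or GGUF files.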
pub struct ModelWeights {
    tok_embeddings: Embedding,
    layers: Vec<LayerWeights>,
    norm: QRmsNorm,
    output: Arc<dyn QuantMethod>,
    pub device: Device,
    pub cache: EitherCache,
    pub max_seq_len: usize,
    mapper: Option<Box<dyn DeviceMapper + Send + Sync>>,
    dtype: DType,
}

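// Legacy GGML files do not record the KV-head count, so the caller supplies
// `gqa` (= n_head / n_kv_head) explicitly.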
impl ModelConfig::FromGGML for ModelWeights {
    fn from_ggml(mut ct: ggml_file::Content, gqa: usize, dtype: DType) -> Result<Self> {
        let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize;
        let rotary = RotaryEmbedding::new_partial(
            10000.,
            ct.hparams.n_rot as usize,
            MAX_SEQ_LEN as usize,
            &ct.device,
            false,
            dtype,
        )?;
        let tok_embeddings = ct.remove("tok_embeddings.weight")?;
        let tok_embeddings = tok_embeddings.dequantize(&ct.device)?;
        let norm = QRmsNorm::new(ct.remove("norm.weight")?, 1e-5)?;
        let output = ct.remove("output.weight")?;
        let mut layers = Vec::with_capacity(ct.hparams.n_layer as usize);
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..ct.hparams.n_layer,
            "Loading repeating layers",
            &MultiProgress::new(),
        ) {
            let prefix = format!("layers.{layer_idx}");
            let attention_wq = ct.remove(&format!("{prefix}.attention.wq.weight"))?;
            let attention_wk = ct.remove(&format!("{prefix}.attention.wk.weight"))?;
            let attention_wv = ct.remove(&format!("{prefix}.attention.wv.weight"))?;
            let attention_wo = ct.remove(&format!("{prefix}.attention.wo.weight"))?;
            let mlp_or_moe = {
                let feed_forward_w1 = ct.remove(&format!("{prefix}.feed_forward.w1.weight"))?;
                let feed_forward_w2 = ct.remove(&format!("{prefix}.feed_forward.w2.weight"))?;
                let feed_forward_w3 = ct.remove(&format!("{prefix}.feed_forward.w3.weight"))?;
                MlpOrMoe::Mlp(Mlp {
                    feed_forward_w1: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                        q_weight: Arc::new(feed_forward_w1),
                        b: None,
                    })?),
                    feed_forward_w2: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                        q_weight: Arc::new(feed_forward_w2),
                        b: None,
                    })?),
                    feed_forward_w3: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                        q_weight: Arc::new(feed_forward_w3),
                        b: None,
                    })?),
                })
            };
            let attention_norm = ct.remove(&format!("{prefix}.attention_norm.weight"))?;
            let ffn_norm = ct.remove(&format!("{prefix}.ffn_norm.weight"))?;
            let n_kv_head = ct.hparams.n_head as usize / gqa;
            layers.push(LayerWeights {
                attention_wq: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: Arc::new(attention_wq),
                    b: None,
                })?),
                attention_wk: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: Arc::new(attention_wk),
                    b: None,
                })?),
                attention_wv: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: Arc::new(attention_wv),
                    b: None,
                })?),
                attention_wo: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: Arc::new(attention_wo),
                    b: None,
                })?),
                attention_norm: QRmsNorm::new(attention_norm, 1e-5)?,
                mlp_or_moe,
                ffn_norm: QRmsNorm::new(ffn_norm, 1e-5)?,
                n_head: ct.hparams.n_head as usize,
                n_kv_head,
                head_dim,
                rotary: rotary.clone().into(),
                paged_attn: None, // TODO
                sdpa_params: SdpaParams {
                    n_kv_groups: ct.hparams.n_head as usize / n_kv_head,
                    use_flash_attn: false,
                    softcap: None,
                    softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                    sliding_window: None,
                },
                dtype,
            })
        }
        Ok(Self {
            tok_embeddings: Embedding::new(tok_embeddings, ct.hparams.n_embd as usize),
            layers,
            norm,
            output: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                q_weight: Arc::new(output),
                b: None,
            })?),
            device: ct.device.clone(),
            cache: EitherCache::Normal(NormalCache::new(
                ct.hparams.n_layer as usize,
                MAX_SEQ_LEN as usize,
            )),
            max_seq_len: MAX_SEQ_LEN as usize, // Cannot determine from ggml.
            mapper: None,
            dtype,
        })
    }
}

// llama `llm` fields:
// https://github.com/ggerganov/ggml/blob/master/docs/gguf.md#llm
// NOTE: Types here do not match the GGUF spec.
pub(crate) struct PropsGGUF {
    pub n_expert: usize,
    pub n_expert_used: usize,
    pub head_count: usize,
    pub head_count_kv: usize,
    pub block_count: usize,
    pub embedding_length: usize,
    pub rope_dim: usize,
    pub rms_norm_eps: f32,
    pub max_seq_len: usize,
    pub rope_freq_base: f32,
    pub key_length: usize,
    pub value_length: usize,
}

impl TryFrom<ContentMetadata<'_>> for PropsGGUF {
    type Error = anyhow::Error;

    fn try_from(c: ContentMetadata) -> std::result::Result<Self, Self::Error> {
        c.verify_arch("llama")?;

        let required = [
            "attention.head_count",
            "attention.head_count_kv",
            "block_count",
            "embedding_length",
            "rope.dimension_count",
            "attention.layer_norm_rms_epsilon",
        ];
        c.has_required_keys(&required)?;

        let embed_len = c.get_value::<u32>("embedding_length")? as usize;
        let head_count = c.get_value::<u32>("attention.head_count")? as usize;

        // NOTE: Values are not aligned with GGUFv3 types
        // TODO: Normalize value types to spec
        let props = Self {
            n_expert: c.get_value::<u32>("expert_count").ok().unwrap_or(0) as usize,
            n_expert_used: c.get_value::<u32>("expert_used_count").ok().unwrap_or(0) as usize,
            head_count,
            head_count_kv: c.get_value::<u32>("attention.head_count_kv")? as usize,
            block_count: c.get_value::<u32>("block_count")? as usize,
            embedding_length: embed_len,
            rope_dim: c.get_value::<u32>("rope.dimension_count")? as usize,
            // This value is generally 1e-6 in GGUF files, although the historical
            // default was 1e-5.
            rms_norm_eps: c.get_value("attention.layer_norm_rms_epsilon")?,
            max_seq_len: c
                .get_value::<u64>("context_length")
                .ok()
                .unwrap_or(MAX_SEQ_LEN as u64) as usize,
            rope_freq_base: c.get_value("rope.freq_base").ok().unwrap_or(10_000_f32),
            key_length: c
                .get_value::<u32>("attention.key_length")
                .ok()
                .map(|x| x as usize)
                .unwrap_or(embed_len / head_count),
            value_length: c
                .get_value::<u32>("attention.value_length")
                .ok()
                .map(|x| x as usize)
                .unwrap_or(embed_len / head_count),
        };

        Ok(props)
    }
}

impl ModelConfig::FromGGUF for ModelWeights {
    fn from_gguf<R: std::io::Seek + std::io::Read>(
        mut ct: Content<'_, R>,
        device: &Device,
        mapper: Box<dyn DeviceMapper + Send + Sync>,
        attention_mechanism: AttentionImplementation,
        dtype: DType,
    ) -> Result<Self> {
        // Parameter extraction from metadata.
        let metadata = ContentMetadata {
            path_prefix: "llama",
            metadata: ct.get_metadata(),
        };
        let PropsGGUF {
            n_expert,
            n_expert_used,
            head_count,
            head_count_kv,
            block_count,
            embedding_length,
            rope_dim,
            rms_norm_eps,
            max_seq_len,
            rope_freq_base,
            key_length,
            value_length,
        } = PropsGGUF::try_from(metadata).or_else(|err| candle_core::bail!("{err}"))?;

        let qtok_embeddings = ct.tensor("token_embd.weight", device)?;
        let tok_embeddings = qtok_embeddings.dequantize(device)?;
        let norm = QRmsNorm::new(ct.tensor("output_norm.weight", device)?, rms_norm_eps)?;
        let output = if !ct.has_tensor("output.weight") {
            ct.tensor("token_embd.weight", device)?
        } else {
            ct.tensor("output.weight", device)?
        };
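        // Some GGUFs omit `output.weight`; in that case the output projection is
        // tied to the token embeddings.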
        let mut layers = Vec::with_capacity(block_count);

        let head_dim = key_length;
        if key_length != value_length {
            candle_core::bail!(
                "Expected key_length == value_length, got {key_length} != {value_length}"
            );
        }

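        // Precompute one RotaryEmbedding per device so that layers mapped onto
        // the same device share a single cos/sin cache.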
        let mut ropes = HashMap::new();
        for layer_idx in 0..block_count {
            let device = mapper.device_for(layer_idx, false).unwrap_or(device);
            ropes.insert(
                device.location(),
                Arc::new(RotaryEmbedding::new(
                    rope_freq_base,
                    rope_dim,
                    max_seq_len,
                    device,
                    false,
                    dtype,
                )?),
            );
        }

        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..block_count,
            "Loading repeating layers",
            &MultiProgress::new(),
        ) {
            let prefix = format!("blk.{layer_idx}");
            let device = mapper.device_for(layer_idx, false).unwrap_or(device);
            let rotary = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();

            let attention_wq = ct.tensor(&format!("{prefix}.attn_q.weight"), device)?;
            let attention_wk = ct.tensor(&format!("{prefix}.attn_k.weight"), device)?;
            let attention_wv = ct.tensor(&format!("{prefix}.attn_v.weight"), device)?;
            let attention_wo = ct.tensor(&format!("{prefix}.attn_output.weight"), device)?;
            let mlp_or_moe = if n_expert <= 1 {
                let feed_forward_w1 = ct.tensor(&format!("{prefix}.ffn_gate.weight"), device)?;
                let feed_forward_w2 = ct.tensor(&format!("{prefix}.ffn_down.weight"), device)?;
                let feed_forward_w3 = ct.tensor(&format!("{prefix}.ffn_up.weight"), device)?;
                MlpOrMoe::Mlp(Mlp {
                    feed_forward_w1: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                        q_weight: Arc::new(feed_forward_w1),
                        b: None,
                    })?),
                    feed_forward_w2: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                        q_weight: Arc::new(feed_forward_w2),
                        b: None,
                    })?),
                    feed_forward_w3: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                        q_weight: Arc::new(feed_forward_w3),
                        b: None,
                    })?),
                })
            } else {
                let feed_forward_gate_inp =
                    ct.tensor(&format!("{prefix}.ffn_gate_inp.weight"), device)?;
                let mut experts = Vec::with_capacity(n_expert);
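                // Newer MoE GGUFs stack all experts into single `ffn_*_exps`
                // tensors; older exports store one `ffn_*.{i}.weight` tensor per
                // expert, handled in the Err arm below.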
                match ct.tensor(&format!("{prefix}.ffn_gate_exps.weight"), device) {
                    Ok(feed_forward_gate_exps) => {
                        let feed_forward_down_exps =
                            ct.tensor(&format!("{prefix}.ffn_down_exps.weight"), device)?;
                        let feed_forward_up_exps =
                            ct.tensor(&format!("{prefix}.ffn_up_exps.weight"), device)?;

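                        // Split the stacked expert tensors: dequantize, chunk per
                        // expert along dim 0, then requantize each chunk to its
                        // original GGML dtype so every expert gets an independent
                        // quantized matmul.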
                        let dequant_ffn_gate = feed_forward_gate_exps
                            .dequantize(device)?
                            .chunk(n_expert, 0)?;
                        let dequant_ffn_down = feed_forward_down_exps
                            .dequantize(device)?
                            .chunk(n_expert, 0)?;
                        let dequant_ffn_up = feed_forward_up_exps
                            .dequantize(device)?
                            .chunk(n_expert, 0)?;

                        assert_eq!(dequant_ffn_up.len(), dequant_ffn_down.len());
                        assert_eq!(dequant_ffn_gate.len(), dequant_ffn_down.len());
                        assert_eq!(dequant_ffn_gate.len(), n_expert);

                        let gate_type = feed_forward_gate_exps.dtype();
                        let down_type = feed_forward_down_exps.dtype();
                        let up_type = feed_forward_up_exps.dtype();

                        for (ff_w1, (ff_w2, ff_w3)) in dequant_ffn_gate
                            .into_iter()
                            .zip(dequant_ffn_down.into_iter().zip(dequant_ffn_up))
                        {
                            experts.push(Mlp {
                                feed_forward_w1: Arc::new(GgufMatMul::new(
                                    QuantMethodConfig::Gguf {
                                        q_weight: Arc::new(QTensor::quantize(&ff_w1, gate_type)?),
                                        b: None,
                                    },
                                )?),
                                feed_forward_w2: Arc::new(GgufMatMul::new(
                                    QuantMethodConfig::Gguf {
                                        q_weight: Arc::new(QTensor::quantize(&ff_w2, down_type)?),
                                        b: None,
                                    },
                                )?),
                                feed_forward_w3: Arc::new(GgufMatMul::new(
                                    QuantMethodConfig::Gguf {
                                        q_weight: Arc::new(QTensor::quantize(&ff_w3, up_type)?),
                                        b: None,
                                    },
                                )?),
                            })
                        }
                    }
                    Err(_) => {
                        for i in 0..n_expert {
                            let feed_forward_w1 =
                                ct.tensor(&format!("{prefix}.ffn_gate.{i}.weight"), device)?;
                            let feed_forward_w2 =
                                ct.tensor(&format!("{prefix}.ffn_down.{i}.weight"), device)?;
                            let feed_forward_w3 =
                                ct.tensor(&format!("{prefix}.ffn_up.{i}.weight"), device)?;
                            experts.push(Mlp {
                                feed_forward_w1: Arc::new(GgufMatMul::new(
                                    QuantMethodConfig::Gguf {
                                        q_weight: Arc::new(feed_forward_w1),
                                        b: None,
                                    },
                                )?),
                                feed_forward_w2: Arc::new(GgufMatMul::new(
                                    QuantMethodConfig::Gguf {
                                        q_weight: Arc::new(feed_forward_w2),
                                        b: None,
                                    },
                                )?),
                                feed_forward_w3: Arc::new(GgufMatMul::new(
                                    QuantMethodConfig::Gguf {
                                        q_weight: Arc::new(feed_forward_w3),
                                        b: None,
                                    },
                                )?),
                            })
                        }
                    }
                }
                MlpOrMoe::MoE {
                    n_expert_used,
                    feed_forward_gate_inp: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                        q_weight: Arc::new(feed_forward_gate_inp),
                        b: None,
                    })?),
                    experts,
                }
            };
            let attention_norm = ct.tensor(&format!("{prefix}.attn_norm.weight"), device)?;
            let ffn_norm = ct.tensor(&format!("{prefix}.ffn_norm.weight"), device)?;
            let paged_attn = match &attention_mechanism {
                AttentionImplementation::Eager => None,
                AttentionImplementation::PagedAttention => {
                    Some(PagedAttention::new(head_dim, device, None)?)
                }
            };
            layers.push(LayerWeights {
                attention_wq: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: Arc::new(attention_wq),
                    b: None,
                })?),
                attention_wk: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: Arc::new(attention_wk),
                    b: None,
                })?),
                attention_wv: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: Arc::new(attention_wv),
                    b: None,
                })?),
                attention_wo: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: Arc::new(attention_wo),
                    b: None,
                })?),
                attention_norm: QRmsNorm::new(attention_norm, rms_norm_eps)?,
                mlp_or_moe,
                ffn_norm: QRmsNorm::new(ffn_norm, rms_norm_eps)?,
                n_head: head_count,
                n_kv_head: head_count_kv,
                head_dim,
                rotary: rotary.clone(),
                paged_attn,
                sdpa_params: SdpaParams {
                    n_kv_groups: head_count / head_count_kv,
                    use_flash_attn: false,
                    softcap: None,
                    softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                    sliding_window: None,
                },
                dtype,
            })
        }
        Ok(Self {
            tok_embeddings: Embedding::new(tok_embeddings, embedding_length),
            layers,
            norm,
            output: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                q_weight: Arc::new(output),
                b: None,
            })?),
            device: device.clone(),
            cache: EitherCache::Normal(NormalCache::new(block_count, max_seq_len)),
            max_seq_len,
            mapper: Some(mapper),
            dtype,
        })
    }
}

impl ModelWeights {
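    /// Run the model over the token ids in `x`, returning logits for the
    /// positions selected by `context_lens`. `metadata` carries the
    /// PagedAttention KV caches and input metadata when that backend is active.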
    pub fn forward(
        &self,
        x: &Tensor,
        start_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
    ) -> Result<Tensor> {
        let mut layer_in = self.tok_embeddings.forward(x)?;
        let cache = &mut self.cache.normal().0;
        let mask = CausalMasker.make_causal_mask_matrix(
            x,
            metadata
                .as_ref()
                .map(|(_, _)| &start_offsets as &dyn PastKvLenCache)
                .unwrap_or(cache as &dyn PastKvLenCache),
            self.dtype,
            self.layers[0].n_head,
        )?;
        // PagedAttention prompt chunking
        let mask = mask.filter(|_| {
            metadata
                .as_ref()
                .map(|(_, meta)| meta.is_first_prompt_chunk)
                .unwrap_or(true)
        });
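        // When a long prompt is processed in chunks with PagedAttention, only
        // the first chunk applies the causal mask here.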
        for (i, layer) in self.layers.iter().enumerate() {
            if let Some(ref mapper) = self.mapper {
                layer_in = mapper.map(layer_in, i)?;
            }
            let x = layer_in;
            let residual = &x;
            let x = layer.attention_norm.forward(&x)?;
            let attn = layer.forward_attn(
                &x,
                mask.as_ref()
                    .map(|m| m.to_device(x.device()).unwrap())
                    .as_ref(),
                start_offsets,
                &mut cache[i],
                metadata
                    .as_ref()
                    .map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
            )?;
            let x = (attn + residual)?;

            // MLP
            let residual = &x;
            let x = layer.ffn_norm.forward(&x)?;
            let x = layer.mlp_or_moe.forward(&x)?;
            let x = (x + residual)?;
            layer_in = x;
        }
        let layer_in = layer_in.to_device(&self.device)?;
        let x = self.norm.forward(&layer_in)?;
        extract_logits(
            &MatMul.qmethod_matmul(&x.contiguous()?, &*self.output)?,
            context_lens,
        )
    }
}