#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::collections::HashMap;
use std::sync::Arc;

use candle_core::quantized::ggml_file;
use candle_core::quantized::QTensor;
use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{Embedding, Module};
use indicatif::MultiProgress;
use mistralrs_quant::{GgufMatMul, QuantMethod, QuantMethodConfig};

use crate::attention::SdpaParams;
use crate::device_map::DeviceMapper;
use crate::gguf::Content;
use crate::layers::{CausalMasker, MatMul, QRmsNorm, RotaryEmbedding, Sdpa};
use crate::layers_masker::PastKvLenCache;
use crate::paged_attention::{AttentionImplementation, PagedAttention};
use crate::pipeline::extract_logits;
use crate::pipeline::text_models_inputs_processor::PagedAttentionInputMetadata;
use crate::pipeline::EitherCache;
use crate::pipeline::KvCache;
use crate::pipeline::NormalCache;
use crate::utils::gguf_metadata::ContentMetadata;
use crate::utils::model_config as ModelConfig;
use crate::utils::progress::NiceProgressBar;

const MAX_SEQ_LEN: u32 = 4096;

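/// Gated (SwiGLU-style) feed-forward block over quantized weights:
/// `w2(silu(w1(x)) * w3(x))`, using GGML/llama.cpp tensor naming.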
struct Mlp {
    feed_forward_w1: Arc<dyn QuantMethod>,
    feed_forward_w2: Arc<dyn QuantMethod>,
    feed_forward_w3: Arc<dyn QuantMethod>,
}

impl Mlp {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let w1 = MatMul.qmethod_matmul(xs, &*self.feed_forward_w1)?;
        let w3 = MatMul.qmethod_matmul(xs, &*self.feed_forward_w3)?;
        let y = &(candle_nn::ops::silu(&w1)? * w3)?;
        MatMul.qmethod_matmul(y, &*self.feed_forward_w2)
    }
}

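/// Either a dense feed-forward block or a mixture-of-experts block.
/// The MoE variant routes each token to the `n_expert_used` experts with the
/// highest gate scores and mixes their outputs with normalized routing weights.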
enum MlpOrMoe {
    Mlp(Mlp),
    MoE {
        n_expert_used: usize,
        feed_forward_gate_inp: Arc<dyn QuantMethod>,
        experts: Vec<Mlp>,
    },
}

impl MlpOrMoe {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        match self {
            Self::MoE {
                feed_forward_gate_inp,
                experts,
                n_expert_used,
            } => {
                let (b_size, seq_len, hidden_dim) = xs.dims3()?;
                let xs = xs.reshape(((), hidden_dim))?;
                let router_logits = MatMul.qmethod_matmul(&xs, &**feed_forward_gate_inp)?;
                let routing_weights = candle_nn::ops::softmax_last_dim(&router_logits)?;

                let routing_weights = routing_weights.to_dtype(DType::F32)?.to_vec2::<f32>()?;

                // Extract the top-k experts per token on the CPU: `top_x[e]` holds the
                // row (token) indices routed to expert `e`, and `selected_rws[e]` the
                // corresponding routing weights, normalized over the selected experts.
                let mut top_x = vec![vec![]; experts.len()];
                let mut selected_rws = vec![vec![]; experts.len()];
                for (row_idx, rw) in routing_weights.iter().enumerate() {
                    let mut dst = (0..rw.len() as u32).collect::<Vec<u32>>();
                    dst.sort_by(|&i, &j| rw[j as usize].total_cmp(&rw[i as usize]));
                    let mut sum_routing_weights = 0f32;
                    for &expert_idx in dst.iter().take(*n_expert_used) {
                        let expert_idx = expert_idx as usize;
                        let routing_weight = rw[expert_idx];
                        sum_routing_weights += routing_weight;
                        top_x[expert_idx].push(row_idx as u32);
                    }
                    for &expert_idx in dst.iter().take(*n_expert_used) {
                        let expert_idx = expert_idx as usize;
                        let routing_weight = rw[expert_idx];
                        selected_rws[expert_idx].push(routing_weight / sum_routing_weights);
                    }
                }

                // Evaluate each expert only on the tokens routed to it, then scatter
                // the weighted results back into the output rows.
                let mut ys = xs.zeros_like()?;
                for (expert_idx, expert_layer) in experts.iter().enumerate() {
                    let top_x = &top_x[expert_idx];
                    if top_x.is_empty() {
                        continue;
                    }
                    let top_x = Tensor::new(top_x.as_slice(), xs.device())?;
                    let selected_rws =
                        Tensor::new(selected_rws[expert_idx].as_slice(), xs.device())?
                            .reshape(((), 1))?;
                    let current_state = xs.index_select(&top_x, 0)?.reshape(((), hidden_dim))?;
                    let current_hidden_states = expert_layer.forward(&current_state)?;
                    let current_hidden_states =
                        current_hidden_states.broadcast_mul(&selected_rws)?;
                    ys = ys.index_add(&top_x, &current_hidden_states, 0)?;
                }

                let ys = ys.reshape((b_size, seq_len, hidden_dim))?;
                Ok(ys)
            }
            Self::Mlp(mlp) => mlp.forward(xs),
        }
    }
}

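/// Weights for a single transformer block: attention projections, RMS norms,
/// the feed-forward (or MoE) block, rotary embeddings, and an optional
/// paged-attention implementation.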
struct LayerWeights {
    attention_wq: Arc<dyn QuantMethod>,
    attention_wk: Arc<dyn QuantMethod>,
    attention_wv: Arc<dyn QuantMethod>,
    attention_wo: Arc<dyn QuantMethod>,
    attention_norm: QRmsNorm,
    mlp_or_moe: MlpOrMoe,
    ffn_norm: QRmsNorm,
    n_head: usize,
    n_kv_head: usize,
    head_dim: usize,
    rotary: Arc<RotaryEmbedding>,
    paged_attn: Option<PagedAttention>,
    sdpa_params: SdpaParams,
    dtype: DType,
}

impl LayerWeights {
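    /// Project to Q/K/V, apply rotary embeddings, then run either paged
    /// attention (when cache metadata is provided) or SDPA over the KV cache,
    /// and project the result back through `wo`.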
    fn forward_attn(
        &self,
        x: &Tensor,
        mask: Option<&Tensor>,
        start_offsets: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
    ) -> Result<Tensor> {
        let (b_sz, seq_len, _) = x.dims3()?;

        let q = MatMul
            .qmethod_matmul(x, &*self.attention_wq)?
            .to_dtype(self.dtype)?;
        let k = MatMul
            .qmethod_matmul(x, &*self.attention_wk)?
            .to_dtype(self.dtype)?;
        let v = MatMul
            .qmethod_matmul(x, &*self.attention_wv)?
            .to_dtype(self.dtype)?;

        let (q, k, v) = if seq_len != 1 {
            let q = q
                .reshape((b_sz, seq_len, self.n_head, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            // Single-token decode: reshape straight into the transposed layout,
            // since a length-1 sequence dimension needs no transpose.
            let q = q.reshape((b_sz, self.n_head, seq_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.n_kv_head, seq_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.n_kv_head, seq_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, k) = self.rotary.forward(&q, &k, start_offsets)?;

        let y = match &self.paged_attn {
            Some(paged_attn) => {
                let ((key_cache, value_cache), input_metadata) = metadata.unwrap();
                paged_attn.forward(
                    &q,
                    &k,
                    &v,
                    mask,
                    Some(key_cache),
                    Some(value_cache),
                    input_metadata,
                    &self.sdpa_params,
                    None,
                )?
            }
            None => {
                let (k, v) = kv_cache.append(&k, &v)?;

                Sdpa.run_attention(&q, &k, &v, mask, None, &self.sdpa_params)?
            }
        };

        let y = if mask.is_some() {
            y.transpose(1, 2)?.reshape((b_sz, seq_len, ()))?
        } else {
            y.reshape((b_sz, seq_len, ()))?
        };

        let y = MatMul.qmethod_matmul(&y.to_dtype(x.dtype())?, &*self.attention_wo)?;
        Ok(y)
    }
}

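/// A quantized Llama-architecture model loaded from GGML or GGUF weights.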
pub struct ModelWeights {
    tok_embeddings: Embedding,
    layers: Vec<LayerWeights>,
    norm: QRmsNorm,
    output: Arc<dyn QuantMethod>,
    pub device: Device,
    pub cache: EitherCache,
    pub max_seq_len: usize,
    mapper: Option<Box<dyn DeviceMapper + Send + Sync>>,
    dtype: DType,
}

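/// Loading from the legacy GGML file format. Hyperparameters come from the file
/// header; the grouped-query attention factor `gqa` must be supplied by the
/// caller, since legacy GGML files do not record the KV head count.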
impl ModelConfig::FromGGML for ModelWeights {
    fn from_ggml(mut ct: ggml_file::Content, gqa: usize, dtype: DType) -> Result<Self> {
        let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize;
        let rotary = RotaryEmbedding::new_partial(
            10000.,
            ct.hparams.n_rot as usize,
            MAX_SEQ_LEN as usize,
            &ct.device,
            false,
            dtype,
        )?;
        let tok_embeddings = ct.remove("tok_embeddings.weight")?;
        let tok_embeddings = tok_embeddings.dequantize(&ct.device)?;
        let norm = QRmsNorm::new(ct.remove("norm.weight")?, 1e-5)?;
        let output = ct.remove("output.weight")?;
        let mut layers = Vec::with_capacity(ct.hparams.n_layer as usize);
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..ct.hparams.n_layer,
            "Loading repeating layers",
            &MultiProgress::new(),
        ) {
            let prefix = format!("layers.{layer_idx}");
            let attention_wq = ct.remove(&format!("{prefix}.attention.wq.weight"))?;
            let attention_wk = ct.remove(&format!("{prefix}.attention.wk.weight"))?;
            let attention_wv = ct.remove(&format!("{prefix}.attention.wv.weight"))?;
            let attention_wo = ct.remove(&format!("{prefix}.attention.wo.weight"))?;
            let mlp_or_moe = {
                let feed_forward_w1 = ct.remove(&format!("{prefix}.feed_forward.w1.weight"))?;
                let feed_forward_w2 = ct.remove(&format!("{prefix}.feed_forward.w2.weight"))?;
                let feed_forward_w3 = ct.remove(&format!("{prefix}.feed_forward.w3.weight"))?;
                MlpOrMoe::Mlp(Mlp {
                    feed_forward_w1: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                        q_weight: Arc::new(feed_forward_w1),
                        b: None,
                    })?),
                    feed_forward_w2: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                        q_weight: Arc::new(feed_forward_w2),
                        b: None,
                    })?),
                    feed_forward_w3: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                        q_weight: Arc::new(feed_forward_w3),
                        b: None,
                    })?),
                })
            };
            let attention_norm = ct.remove(&format!("{prefix}.attention_norm.weight"))?;
            let ffn_norm = ct.remove(&format!("{prefix}.ffn_norm.weight"))?;
            let n_kv_head = ct.hparams.n_head as usize / gqa;
            layers.push(LayerWeights {
                attention_wq: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: Arc::new(attention_wq),
                    b: None,
                })?),
                attention_wk: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: Arc::new(attention_wk),
                    b: None,
                })?),
                attention_wv: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: Arc::new(attention_wv),
                    b: None,
                })?),
                attention_wo: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: Arc::new(attention_wo),
                    b: None,
                })?),
                attention_norm: QRmsNorm::new(attention_norm, 1e-5)?,
                mlp_or_moe,
                ffn_norm: QRmsNorm::new(ffn_norm, 1e-5)?,
                n_head: ct.hparams.n_head as usize,
                n_kv_head,
                head_dim,
                rotary: rotary.clone().into(),
                paged_attn: None,
                sdpa_params: SdpaParams {
                    n_kv_groups: ct.hparams.n_head as usize / n_kv_head,
                    use_flash_attn: false,
                    softcap: None,
                    softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                    sliding_window: None,
                },
                dtype,
            })
        }
        Ok(Self {
            tok_embeddings: Embedding::new(tok_embeddings, ct.hparams.n_embd as usize),
            layers,
            norm,
            output: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                q_weight: Arc::new(output),
                b: None,
            })?),
            device: ct.device.clone(),
            cache: EitherCache::Normal(NormalCache::new(
                ct.hparams.n_layer as usize,
                MAX_SEQ_LEN as usize,
            )),
            max_seq_len: MAX_SEQ_LEN as usize,
            mapper: None,
            dtype,
        })
    }
}

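/// Model hyperparameters read from GGUF metadata (keys under the `llama.` prefix).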
pub(crate) struct PropsGGUF {
    pub n_expert: usize,
    pub n_expert_used: usize,
    pub head_count: usize,
    pub head_count_kv: usize,
    pub block_count: usize,
    pub embedding_length: usize,
    pub rope_dim: usize,
    pub rms_norm_eps: f32,
    pub max_seq_len: usize,
    pub rope_freq_base: f32,
    pub key_length: usize,
    pub value_length: usize,
}

impl TryFrom<ContentMetadata<'_>> for PropsGGUF {
    type Error = anyhow::Error;

    fn try_from(c: ContentMetadata) -> std::result::Result<Self, Self::Error> {
        c.verify_arch("llama")?;

        let required = [
            "attention.head_count",
            "attention.head_count_kv",
            "block_count",
            "embedding_length",
            "rope.dimension_count",
            "attention.layer_norm_rms_epsilon",
        ];
        c.has_required_keys(&required)?;

        let embed_len = c.get_value::<u32>("embedding_length")? as usize;
        let head_count = c.get_value::<u32>("attention.head_count")? as usize;

        // Optional keys fall back to defaults: no experts, a 4096-token context,
        // a RoPE base of 10_000, and head dims derived from the embedding length.
        let props = Self {
            n_expert: c.get_value::<u32>("expert_count").ok().unwrap_or(0) as usize,
            n_expert_used: c.get_value::<u32>("expert_used_count").ok().unwrap_or(0) as usize,
            head_count,
            head_count_kv: c.get_value::<u32>("attention.head_count_kv")? as usize,
            block_count: c.get_value::<u32>("block_count")? as usize,
            embedding_length: embed_len,
            rope_dim: c.get_value::<u32>("rope.dimension_count")? as usize,
            rms_norm_eps: c.get_value("attention.layer_norm_rms_epsilon")?,
            max_seq_len: c
                .get_value::<u64>("context_length")
                .ok()
                .unwrap_or(MAX_SEQ_LEN as u64) as usize,
            rope_freq_base: c.get_value("rope.freq_base").ok().unwrap_or(10_000_f32),
            key_length: c
                .get_value::<u32>("attention.key_length")
                .ok()
                .map(|x| x as usize)
                .unwrap_or(embed_len / head_count),
            value_length: c
                .get_value::<u32>("attention.value_length")
                .ok()
                .map(|x| x as usize)
                .unwrap_or(embed_len / head_count),
        };

        Ok(props)
    }
}

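/// Loading from the GGUF format, with per-layer device mapping and support for
/// both dense and mixture-of-experts checkpoints.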
impl ModelConfig::FromGGUF for ModelWeights {
    fn from_gguf<R: std::io::Seek + std::io::Read>(
        mut ct: Content<'_, R>,
        device: &Device,
        mapper: Box<dyn DeviceMapper + Send + Sync>,
        attention_mechanism: AttentionImplementation,
        dtype: DType,
    ) -> Result<Self> {
        let metadata = ContentMetadata {
            path_prefix: "llama",
            metadata: ct.get_metadata(),
        };
        let PropsGGUF {
            n_expert,
            n_expert_used,
            head_count,
            head_count_kv,
            block_count,
            embedding_length,
            rope_dim,
            rms_norm_eps,
            max_seq_len,
            rope_freq_base,
            key_length,
            value_length,
        } = PropsGGUF::try_from(metadata).or_else(|err| candle_core::bail!("{err}"))?;

        let qtok_embeddings = ct.tensor("token_embd.weight", device)?;
        let tok_embeddings = qtok_embeddings.dequantize(device)?;
        let norm = QRmsNorm::new(ct.tensor("output_norm.weight", device)?, rms_norm_eps)?;
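        // Some checkpoints tie the output head to the token embeddings; fall back
        // to `token_embd.weight` when no separate `output.weight` is present.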
        let output = if !ct.has_tensor("output.weight") {
            ct.tensor("token_embd.weight", device)?
        } else {
            ct.tensor("output.weight", device)?
        };
        let mut layers = Vec::with_capacity(block_count);

        let head_dim = key_length;
        if key_length != value_length {
            candle_core::bail!(
                "Expected key_length == value_length, got {key_length} != {value_length}"
            );
        }

        // Build one RotaryEmbedding per device location so mapped layers share it.
        let mut ropes = HashMap::new();
        for layer_idx in 0..block_count {
            let device = mapper.device_for(layer_idx, false).unwrap_or(device);
            ropes.insert(
                device.location(),
                Arc::new(RotaryEmbedding::new(
                    rope_freq_base,
                    rope_dim,
                    max_seq_len,
                    device,
                    false,
                    dtype,
                )?),
            );
        }

        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..block_count,
            "Loading repeating layers",
            &MultiProgress::new(),
        ) {
            let prefix = format!("blk.{layer_idx}");
            let device = mapper.device_for(layer_idx, false).unwrap_or(device);
            let rotary = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();

            let attention_wq = ct.tensor(&format!("{prefix}.attn_q.weight"), device)?;
            let attention_wk = ct.tensor(&format!("{prefix}.attn_k.weight"), device)?;
            let attention_wv = ct.tensor(&format!("{prefix}.attn_v.weight"), device)?;
            let attention_wo = ct.tensor(&format!("{prefix}.attn_output.weight"), device)?;
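            // Dense checkpoints store single ffn_{gate,down,up} tensors; MoE
            // checkpoints store per-expert weights, either fused into one
            // `ffn_*_exps` tensor or split as one tensor per expert.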
            let mlp_or_moe = if n_expert <= 1 {
                let feed_forward_w1 = ct.tensor(&format!("{prefix}.ffn_gate.weight"), device)?;
                let feed_forward_w2 = ct.tensor(&format!("{prefix}.ffn_down.weight"), device)?;
                let feed_forward_w3 = ct.tensor(&format!("{prefix}.ffn_up.weight"), device)?;
                MlpOrMoe::Mlp(Mlp {
                    feed_forward_w1: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                        q_weight: Arc::new(feed_forward_w1),
                        b: None,
                    })?),
                    feed_forward_w2: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                        q_weight: Arc::new(feed_forward_w2),
                        b: None,
                    })?),
                    feed_forward_w3: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                        q_weight: Arc::new(feed_forward_w3),
                        b: None,
                    })?),
                })
            } else {
                let feed_forward_gate_inp =
                    ct.tensor(&format!("{prefix}.ffn_gate_inp.weight"), device)?;
                let mut experts = Vec::with_capacity(n_expert);
                match ct.tensor(&format!("{prefix}.ffn_gate_exps.weight"), device) {
                    Ok(feed_forward_gate_exps) => {
                        let feed_forward_down_exps =
                            ct.tensor(&format!("{prefix}.ffn_down_exps.weight"), device)?;
                        let feed_forward_up_exps =
                            ct.tensor(&format!("{prefix}.ffn_up_exps.weight"), device)?;

                        // Fused layout: dequantize each stacked tensor, split it
                        // into one chunk per expert, and re-quantize the chunks to
                        // the original quantization type.
                        let dequant_ffn_gate = feed_forward_gate_exps
                            .dequantize(device)?
                            .chunk(n_expert, 0)?;
                        let dequant_ffn_down = feed_forward_down_exps
                            .dequantize(device)?
                            .chunk(n_expert, 0)?;
                        let dequant_ffn_up = feed_forward_up_exps
                            .dequantize(device)?
                            .chunk(n_expert, 0)?;

                        assert_eq!(dequant_ffn_up.len(), dequant_ffn_down.len());
                        assert_eq!(dequant_ffn_gate.len(), dequant_ffn_down.len());
                        assert_eq!(dequant_ffn_gate.len(), n_expert);

                        let gate_type = feed_forward_gate_exps.dtype();
                        let down_type = feed_forward_down_exps.dtype();
                        let up_type = feed_forward_up_exps.dtype();

                        for (ff_w1, (ff_w2, ff_w3)) in dequant_ffn_gate
                            .into_iter()
                            .zip(dequant_ffn_down.into_iter().zip(dequant_ffn_up))
                        {
                            experts.push(Mlp {
                                feed_forward_w1: Arc::new(GgufMatMul::new(
                                    QuantMethodConfig::Gguf {
                                        q_weight: Arc::new(QTensor::quantize(&ff_w1, gate_type)?),
                                        b: None,
                                    },
                                )?),
                                feed_forward_w2: Arc::new(GgufMatMul::new(
                                    QuantMethodConfig::Gguf {
                                        q_weight: Arc::new(QTensor::quantize(&ff_w2, down_type)?),
                                        b: None,
                                    },
                                )?),
                                feed_forward_w3: Arc::new(GgufMatMul::new(
                                    QuantMethodConfig::Gguf {
                                        q_weight: Arc::new(QTensor::quantize(&ff_w3, up_type)?),
                                        b: None,
                                    },
                                )?),
                            })
                        }
                    }
                    Err(_) => {
                        // Split layout: one ffn_{gate,down,up}.{i} tensor per expert.
                        for i in 0..n_expert {
                            let feed_forward_w1 =
                                ct.tensor(&format!("{prefix}.ffn_gate.{i}.weight"), device)?;
                            let feed_forward_w2 =
                                ct.tensor(&format!("{prefix}.ffn_down.{i}.weight"), device)?;
                            let feed_forward_w3 =
                                ct.tensor(&format!("{prefix}.ffn_up.{i}.weight"), device)?;
                            experts.push(Mlp {
                                feed_forward_w1: Arc::new(GgufMatMul::new(
                                    QuantMethodConfig::Gguf {
                                        q_weight: Arc::new(feed_forward_w1),
                                        b: None,
                                    },
                                )?),
                                feed_forward_w2: Arc::new(GgufMatMul::new(
                                    QuantMethodConfig::Gguf {
                                        q_weight: Arc::new(feed_forward_w2),
                                        b: None,
                                    },
                                )?),
                                feed_forward_w3: Arc::new(GgufMatMul::new(
                                    QuantMethodConfig::Gguf {
                                        q_weight: Arc::new(feed_forward_w3),
                                        b: None,
                                    },
                                )?),
                            })
                        }
                    }
                }
                MlpOrMoe::MoE {
                    n_expert_used,
                    feed_forward_gate_inp: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                        q_weight: Arc::new(feed_forward_gate_inp),
                        b: None,
                    })?),
                    experts,
                }
            };
            let attention_norm = ct.tensor(&format!("{prefix}.attn_norm.weight"), device)?;
            let ffn_norm = ct.tensor(&format!("{prefix}.ffn_norm.weight"), device)?;
            let paged_attn = match &attention_mechanism {
                AttentionImplementation::Eager => None,
                AttentionImplementation::PagedAttention => {
                    Some(PagedAttention::new(head_dim, device, None)?)
                }
            };
            layers.push(LayerWeights {
                attention_wq: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: Arc::new(attention_wq),
                    b: None,
                })?),
                attention_wk: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: Arc::new(attention_wk),
                    b: None,
                })?),
                attention_wv: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: Arc::new(attention_wv),
                    b: None,
                })?),
                attention_wo: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: Arc::new(attention_wo),
                    b: None,
                })?),
                attention_norm: QRmsNorm::new(attention_norm, rms_norm_eps)?,
                mlp_or_moe,
                ffn_norm: QRmsNorm::new(ffn_norm, rms_norm_eps)?,
                n_head: head_count,
                n_kv_head: head_count_kv,
                head_dim,
                rotary: rotary.clone(),
                paged_attn,
                sdpa_params: SdpaParams {
                    n_kv_groups: head_count / head_count_kv,
                    use_flash_attn: false,
                    softcap: None,
                    softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                    sliding_window: None,
                },
                dtype,
            })
        }
        Ok(Self {
            tok_embeddings: Embedding::new(tok_embeddings, embedding_length),
            layers,
            norm,
            output: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                q_weight: Arc::new(output),
                b: None,
            })?),
            device: device.clone(),
            cache: EitherCache::Normal(NormalCache::new(block_count, max_seq_len)),
            max_seq_len,
            mapper: Some(mapper),
            dtype,
        })
    }
}

impl ModelWeights {
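    /// Run the full model: embed tokens, apply each transformer block (moving
    /// activations between devices via the mapper when present), then apply the
    /// final norm and output head, extracting logits for the given context lengths.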
    pub fn forward(
        &self,
        x: &Tensor,
        start_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
    ) -> Result<Tensor> {
        let mut layer_in = self.tok_embeddings.forward(x)?;
        let cache = &mut self.cache.normal().0;
        let mask = CausalMasker.make_causal_mask_matrix(
            x,
            metadata
                .as_ref()
                .map(|(_, _)| &start_offsets as &dyn PastKvLenCache)
                .unwrap_or(cache as &dyn PastKvLenCache),
            self.dtype,
            self.layers[0].n_head,
        )?;
        // Retain the mask only for the first chunk of a chunked prompt (or when
        // no paged-attention metadata is supplied).
        let mask = mask.filter(|_| {
            metadata
                .as_ref()
                .map(|(_, meta)| meta.is_first_prompt_chunk)
                .unwrap_or(true)
        });
        for (i, layer) in self.layers.iter().enumerate() {
            if let Some(ref mapper) = self.mapper {
                layer_in = mapper.map(layer_in, i)?;
            }
            let x = layer_in;
            let residual = &x;
            let x = layer.attention_norm.forward(&x)?;
            let attn = layer.forward_attn(
                &x,
                mask.as_ref()
                    .map(|m| m.to_device(x.device()).unwrap())
                    .as_ref(),
                start_offsets,
                &mut cache[i],
                metadata
                    .as_ref()
                    .map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
            )?;
            let x = (attn + residual)?;

            let residual = &x;
            let x = layer.ffn_norm.forward(&x)?;
            let x = layer.mlp_or_moe.forward(&x)?;
            let x = (x + residual)?;
            layer_in = x;
        }
        let layer_in = layer_in.to_device(&self.device)?;
        let x = self.norm.forward(&layer_in)?;
        extract_logits(
            &MatMul.qmethod_matmul(&x.contiguous()?, &*self.output)?,
            context_lens,
        )
    }
}