#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::collections::HashMap;
use std::sync::Arc;

use candle_core::quantized::ggml_file;
use candle_core::quantized::QTensor;
use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{Embedding, Module};
use mistralrs_quant::{GgufMatMul, QuantMethod, QuantMethodConfig};

use crate::attention::SdpaParams;
use crate::device_map::DeviceMapper;
use crate::gguf::Content;
use crate::layers::{CausalMasker, MatMul, QRmsNorm, RotaryEmbedding, Sdpa};
use crate::layers_masker::PastKvLenCache;
use crate::paged_attention::{AttentionImplementation, PagedAttention};
use crate::pipeline::extract_logits;
use crate::pipeline::text_models_inputs_processor::PagedAttentionInputMetadata;
use crate::pipeline::EitherCache;
use crate::pipeline::KvCache;
use crate::pipeline::NormalCache;
use crate::utils::gguf_metadata::ContentMetadata;
use crate::utils::model_config as ModelConfig;
use crate::utils::progress::{new_multi_progress, NiceProgressBar};

const DEFAULT_MAX_SEQ_LEN: u32 = 4096;

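/// Gated feed-forward (SwiGLU-style) block over quantized weights:
/// computes `w2(silu(w1(x)) * w3(x))`.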
struct Mlp {
    feed_forward_w1: Arc<dyn QuantMethod>,
    feed_forward_w2: Arc<dyn QuantMethod>,
    feed_forward_w3: Arc<dyn QuantMethod>,
}

impl Mlp {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let w1 = MatMul.qmethod_matmul(xs, &*self.feed_forward_w1)?;
        let w3 = MatMul.qmethod_matmul(xs, &*self.feed_forward_w3)?;
        let y = &(candle_nn::ops::silu(&w1)? * w3)?;
        MatMul.qmethod_matmul(y, &*self.feed_forward_w2)
    }
}

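/// Feed-forward dispatch: either a dense `Mlp`, or a mixture-of-experts layer
/// that routes each token to its `n_expert_used` top experts through a gate.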
enum MlpOrMoe {
    Mlp(Mlp),
    MoE {
        n_expert_used: usize,
        feed_forward_gate_inp: Arc<dyn QuantMethod>,
        experts: Vec<Mlp>,
    },
}

impl MlpOrMoe {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        match self {
            Self::MoE {
                feed_forward_gate_inp,
                experts,
                n_expert_used,
            } => {
                let (b_size, seq_len, hidden_dim) = xs.dims3()?;
                let xs = xs.reshape(((), hidden_dim))?;
                let router_logits = MatMul.qmethod_matmul(&xs, &**feed_forward_gate_inp)?;
                let routing_weights = candle_nn::ops::softmax_last_dim(&router_logits)?;

                // Pull the routing weights back to host memory so the top-k
                // selection can be done directly on the raw values.
                let routing_weights = routing_weights.to_dtype(DType::F32)?.to_vec2::<f32>()?;

                // `top_x` holds, per expert, the row (token) indices routed to it;
                // `selected_rws` holds the matching normalized routing weights.
                let mut top_x = vec![vec![]; experts.len()];
                let mut selected_rws = vec![vec![]; experts.len()];
                for (row_idx, rw) in routing_weights.iter().enumerate() {
                    // Sort expert indices by descending routing weight to get the top-k.
                    let mut dst = (0..rw.len() as u32).collect::<Vec<u32>>();
                    dst.sort_by(|&i, &j| rw[j as usize].total_cmp(&rw[i as usize]));
                    let mut sum_routing_weights = 0f32;
                    for &expert_idx in dst.iter().take(*n_expert_used) {
                        let expert_idx = expert_idx as usize;
                        let routing_weight = rw[expert_idx];
                        sum_routing_weights += routing_weight;
                        top_x[expert_idx].push(row_idx as u32);
                    }
                    // Normalize the selected weights so they sum to 1 per token.
                    for &expert_idx in dst.iter().take(*n_expert_used) {
                        let expert_idx = expert_idx as usize;
                        let routing_weight = rw[expert_idx];
                        selected_rws[expert_idx].push(routing_weight / sum_routing_weights)
                    }
                }

                // Accumulate each expert's contribution into `ys`, visiting every
                // expert once and only the rows that were routed to it.
                let mut ys = xs.zeros_like()?;
                for (expert_idx, expert_layer) in experts.iter().enumerate() {
                    let top_x = &top_x[expert_idx];
                    if top_x.is_empty() {
                        continue;
                    }
                    let top_x = Tensor::new(top_x.as_slice(), xs.device())?;
                    let selected_rws =
                        Tensor::new(selected_rws[expert_idx].as_slice(), xs.device())?
                            .reshape(((), 1))?;
                    // Gather this expert's rows, run the expert, scale by the
                    // routing weights, then scatter-add back into the output.
                    let current_state = xs.index_select(&top_x, 0)?.reshape(((), hidden_dim))?;
                    let current_hidden_states = expert_layer.forward(&current_state)?;
                    let current_hidden_states =
                        current_hidden_states.broadcast_mul(&selected_rws)?;
                    ys = ys.index_add(&top_x, &current_hidden_states, 0)?;
                }

                let ys = ys.reshape((b_size, seq_len, hidden_dim))?;
                Ok(ys)
            }
            Self::Mlp(mlp) => mlp.forward(xs),
        }
    }
}

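/// Per-layer weights plus the attention configuration for one transformer block.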
struct LayerWeights {
    attention_wq: Arc<dyn QuantMethod>,
    attention_wk: Arc<dyn QuantMethod>,
    attention_wv: Arc<dyn QuantMethod>,
    attention_wo: Arc<dyn QuantMethod>,
    attention_norm: QRmsNorm,
    mlp_or_moe: MlpOrMoe,
    ffn_norm: QRmsNorm,
    n_head: usize,
    n_kv_head: usize,
    head_dim: usize,
    rotary: Arc<RotaryEmbedding>,
    paged_attn: Option<PagedAttention>,
    sdpa_params: SdpaParams,
    dtype: DType,
}

impl LayerWeights {
    /// Attention for one block: project q/k/v, apply RoPE, then run either
    /// paged attention or SDPA over the KV cache, followed by the output
    /// projection.
    fn forward_attn(
        &self,
        x: &Tensor,
        mask: Option<&Tensor>,
        start_offsets: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
    ) -> Result<Tensor> {
        let (b_sz, seq_len, _) = x.dims3()?;

        let q = MatMul
            .qmethod_matmul(x, &*self.attention_wq)?
            .to_dtype(self.dtype)?;
        let k = MatMul
            .qmethod_matmul(x, &*self.attention_wk)?
            .to_dtype(self.dtype)?;
        let v = MatMul
            .qmethod_matmul(x, &*self.attention_wv)?
            .to_dtype(self.dtype)?;

        // Reshape to (b, heads, seq, head_dim). For single-token decoding the
        // transpose would be a no-op on the data, so reshape directly.
        let (q, k, v) = if seq_len != 1 {
            let q = q
                .reshape((b_sz, seq_len, self.n_head, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.n_head, seq_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.n_kv_head, seq_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.n_kv_head, seq_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, k) = self.rotary.forward(&q, &k, start_offsets)?;

        let y = match &self.paged_attn {
            Some(paged_attn) => {
                let ((key_cache, value_cache), input_metadata) = metadata.unwrap();
                paged_attn.forward(
                    &q,
                    &k,
                    &v,
                    mask,
                    Some(key_cache),
                    Some(value_cache),
                    input_metadata,
                    &self.sdpa_params,
                    None,
                )?
            }
            None => {
                // Append the new keys/values to this layer's KV cache and attend
                // over the full cached sequence.
                let (k, v) = kv_cache.append(&k, &v)?;

                Sdpa.run_attention(&q, &k, &v, mask, None, &self.sdpa_params)?
            }
        };

        let y = if mask.is_some() {
            y.transpose(1, 2)?.reshape((b_sz, seq_len, ()))?
        } else {
            y.reshape((b_sz, seq_len, ()))?
        };

        let y = MatMul.qmethod_matmul(&y.to_dtype(x.dtype())?, &*self.attention_wo)?;
        Ok(y)
    }
}

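/// The full quantized model: embeddings, transformer blocks, final norm, and
/// output (LM-head) projection, plus the KV cache and optional device mapping.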
pub struct ModelWeights {
    tok_embeddings: Embedding,
    layers: Vec<LayerWeights>,
    norm: QRmsNorm,
    output: Arc<dyn QuantMethod>,
    pub device: Device,
    pub cache: EitherCache,
    pub max_seq_len: usize,
    mapper: Option<Box<dyn DeviceMapper + Send + Sync>>,
    dtype: DType,
}

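// Loader for legacy GGML (pre-GGUF) checkpoints. These files carry no KV-head
// count, so the grouped-query ratio `gqa` is supplied by the caller.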
impl ModelConfig::FromGGML for ModelWeights {
    fn from_ggml(mut ct: ggml_file::Content, gqa: usize, dtype: DType) -> Result<Self> {
        let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize;
        let rotary = RotaryEmbedding::new_partial(
            10000.,
            ct.hparams.n_rot as usize,
            DEFAULT_MAX_SEQ_LEN as usize,
            &ct.device,
            false,
            dtype,
        )?;
        let tok_embeddings = ct.remove("tok_embeddings.weight")?;
        let tok_embeddings = tok_embeddings.dequantize(&ct.device)?;
        let norm = QRmsNorm::new(ct.remove("norm.weight")?, 1e-5)?;
        let output = ct.remove("output.weight")?;
        let mut layers = Vec::with_capacity(ct.hparams.n_layer as usize);
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..ct.hparams.n_layer,
            "Loading repeating layers",
            &new_multi_progress(),
        ) {
            let prefix = format!("layers.{layer_idx}");
            let attention_wq = ct.remove(&format!("{prefix}.attention.wq.weight"))?;
            let attention_wk = ct.remove(&format!("{prefix}.attention.wk.weight"))?;
            let attention_wv = ct.remove(&format!("{prefix}.attention.wv.weight"))?;
            let attention_wo = ct.remove(&format!("{prefix}.attention.wo.weight"))?;
            let mlp_or_moe = {
                let feed_forward_w1 = ct.remove(&format!("{prefix}.feed_forward.w1.weight"))?;
                let feed_forward_w2 = ct.remove(&format!("{prefix}.feed_forward.w2.weight"))?;
                let feed_forward_w3 = ct.remove(&format!("{prefix}.feed_forward.w3.weight"))?;
                MlpOrMoe::Mlp(Mlp {
                    feed_forward_w1: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                        q_weight: Arc::new(feed_forward_w1),
                        b: None,
                    })?),
                    feed_forward_w2: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                        q_weight: Arc::new(feed_forward_w2),
                        b: None,
                    })?),
                    feed_forward_w3: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                        q_weight: Arc::new(feed_forward_w3),
                        b: None,
                    })?),
                })
            };
            let attention_norm = ct.remove(&format!("{prefix}.attention_norm.weight"))?;
            let ffn_norm = ct.remove(&format!("{prefix}.ffn_norm.weight"))?;
            // Derive the KV-head count from the caller-supplied GQA ratio.
            let n_kv_head = ct.hparams.n_head as usize / gqa;
            layers.push(LayerWeights {
                attention_wq: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: Arc::new(attention_wq),
                    b: None,
                })?),
                attention_wk: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: Arc::new(attention_wk),
                    b: None,
                })?),
                attention_wv: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: Arc::new(attention_wv),
                    b: None,
                })?),
                attention_wo: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: Arc::new(attention_wo),
                    b: None,
                })?),
                attention_norm: QRmsNorm::new(attention_norm, 1e-5)?,
                mlp_or_moe,
                ffn_norm: QRmsNorm::new(ffn_norm, 1e-5)?,
                n_head: ct.hparams.n_head as usize,
                n_kv_head,
                head_dim,
                rotary: rotary.clone().into(),
                paged_attn: None,
                sdpa_params: SdpaParams {
                    n_kv_groups: ct.hparams.n_head as usize / n_kv_head,
                    softcap: None,
                    softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                    sliding_window: None,
                },
                dtype,
            })
        }
        Ok(Self {
            tok_embeddings: Embedding::new(tok_embeddings, ct.hparams.n_embd as usize),
            layers,
            norm,
            output: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                q_weight: Arc::new(output),
                b: None,
            })?),
            device: ct.device.clone(),
            cache: EitherCache::Normal(NormalCache::new(
                ct.hparams.n_layer as usize,
                DEFAULT_MAX_SEQ_LEN as usize,
            )),
            max_seq_len: DEFAULT_MAX_SEQ_LEN as usize,
            mapper: None,
            dtype,
        })
    }
}

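/// Model hyperparameters read from GGUF metadata (`llama.*` keys).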
pub(crate) struct PropsGGUF {
    pub n_expert: usize,
    pub n_expert_used: usize,
    pub head_count: usize,
    pub head_count_kv: usize,
    pub block_count: usize,
    pub embedding_length: usize,
    pub rope_dim: usize,
    pub rms_norm_eps: f32,
    pub max_seq_len: usize,
    pub rope_freq_base: f32,
    pub key_length: usize,
    pub value_length: usize,
}

impl TryFrom<ContentMetadata<'_>> for PropsGGUF {
    type Error = anyhow::Error;

    fn try_from(c: ContentMetadata) -> std::result::Result<Self, Self::Error> {
        c.verify_arch("llama")?;

        let required = [
            "attention.head_count",
            "attention.head_count_kv",
            "block_count",
            "embedding_length",
            "rope.dimension_count",
            "attention.layer_norm_rms_epsilon",
        ];
        c.has_required_keys(&required)?;

        let embed_len = c.get_value::<u32>("embedding_length")? as usize;
        let head_count = c.get_value::<u32>("attention.head_count")? as usize;

        // The remaining keys are optional; fall back to sensible defaults.
        let props = Self {
            n_expert: c.get_value::<u32>("expert_count").ok().unwrap_or(0) as usize,
            n_expert_used: c.get_value::<u32>("expert_used_count").ok().unwrap_or(0) as usize,
            head_count,
            head_count_kv: c.get_value::<u32>("attention.head_count_kv")? as usize,
            block_count: c.get_value::<u32>("block_count")? as usize,
            embedding_length: embed_len,
            rope_dim: c.get_value::<u32>("rope.dimension_count")? as usize,
            rms_norm_eps: c.get_value("attention.layer_norm_rms_epsilon")?,
            max_seq_len: c
                .get_value::<u64>("context_length")
                .ok()
                .unwrap_or(DEFAULT_MAX_SEQ_LEN as u64) as usize,
            rope_freq_base: c.get_value("rope.freq_base").ok().unwrap_or(10_000_f32),
            key_length: c
                .get_value::<u32>("attention.key_length")
                .ok()
                .map(|x| x as usize)
                .unwrap_or(embed_len / head_count),
            value_length: c
                .get_value::<u32>("attention.value_length")
                .ok()
                .map(|x| x as usize)
                .unwrap_or(embed_len / head_count),
        };

        Ok(props)
    }
}

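// GGUF loading: tensor names follow the llama.cpp GGUF convention
// (`token_embd`, `output_norm`, `blk.{i}.*`).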
impl ModelConfig::FromGGUF for ModelWeights {
    fn from_gguf<R: std::io::Seek + std::io::Read>(
        mut ct: Content<'_, R>,
        device: &Device,
        mapper: Box<dyn DeviceMapper + Send + Sync>,
        attention_mechanism: AttentionImplementation,
        dtype: DType,
    ) -> Result<Self> {
        // Parameter extraction from metadata.
        let metadata = ContentMetadata {
            path_prefix: "llama",
            metadata: ct.get_metadata(),
        };
        let PropsGGUF {
            n_expert,
            n_expert_used,
            head_count,
            head_count_kv,
            block_count,
            embedding_length,
            rope_dim,
            rms_norm_eps,
            max_seq_len,
            rope_freq_base,
            key_length,
            value_length,
        } = PropsGGUF::try_from(metadata).or_else(|err| candle_core::bail!("{err}"))?;

        let qtok_embeddings = ct.tensor("token_embd.weight", device)?;
        let tok_embeddings = qtok_embeddings.dequantize(device)?;
        let norm = QRmsNorm::new(ct.tensor("output_norm.weight", device)?, rms_norm_eps)?;
        // Models with tied embeddings reuse `token_embd.weight` as the LM head.
        let output = if ct.has_tensor("output.weight") {
            ct.tensor("output.weight", device)?
        } else {
            ct.tensor("token_embd.weight", device)?
        };
        let mut layers = Vec::with_capacity(block_count);

        let head_dim = key_length;
        if key_length != value_length {
            candle_core::bail!(
                "Expected key_length == value_length, got {key_length} != {value_length}"
            );
        }

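        // Precompute one RotaryEmbedding per device so that layers mapped to the
        // same device can share a single instance.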
        let mut ropes = HashMap::new();
        for layer_idx in 0..block_count {
            let device = mapper.device_for(layer_idx, false).unwrap_or(device);
            ropes.insert(
                device.location(),
                Arc::new(RotaryEmbedding::new(
                    rope_freq_base,
                    rope_dim,
                    max_seq_len,
                    device,
                    false,
                    dtype,
                )?),
            );
        }

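        // Load the repeating transformer blocks, placing each on its mapped device.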
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..block_count,
            "Loading repeating layers",
            &new_multi_progress(),
        ) {
            let prefix = format!("blk.{layer_idx}");
            let device = mapper.device_for(layer_idx, false).unwrap_or(device);
            let rotary = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();

            let attention_wq = ct.tensor(&format!("{prefix}.attn_q.weight"), device)?;
            let attention_wk = ct.tensor(&format!("{prefix}.attn_k.weight"), device)?;
            let attention_wv = ct.tensor(&format!("{prefix}.attn_v.weight"), device)?;
            let attention_wo = ct.tensor(&format!("{prefix}.attn_output.weight"), device)?;
            let mlp_or_moe = if n_expert <= 1 {
                let feed_forward_w1 = ct.tensor(&format!("{prefix}.ffn_gate.weight"), device)?;
                let feed_forward_w2 = ct.tensor(&format!("{prefix}.ffn_down.weight"), device)?;
                let feed_forward_w3 = ct.tensor(&format!("{prefix}.ffn_up.weight"), device)?;
                MlpOrMoe::Mlp(Mlp {
                    feed_forward_w1: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                        q_weight: Arc::new(feed_forward_w1),
                        b: None,
                    })?),
                    feed_forward_w2: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                        q_weight: Arc::new(feed_forward_w2),
                        b: None,
                    })?),
                    feed_forward_w3: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                        q_weight: Arc::new(feed_forward_w3),
                        b: None,
                    })?),
                })
            } else {
                let feed_forward_gate_inp =
                    ct.tensor(&format!("{prefix}.ffn_gate_inp.weight"), device)?;
                let mut experts = Vec::with_capacity(n_expert);
                // Some GGUF MoE files pack all experts into single `ffn_*_exps`
                // tensors; others store one tensor per expert.
                match ct.tensor(&format!("{prefix}.ffn_gate_exps.weight"), device) {
                    Ok(feed_forward_gate_exps) => {
                        let feed_forward_down_exps =
                            ct.tensor(&format!("{prefix}.ffn_down_exps.weight"), device)?;
                        let feed_forward_up_exps =
                            ct.tensor(&format!("{prefix}.ffn_up_exps.weight"), device)?;

                        // Dequantize, split along the expert dimension, then
                        // requantize each chunk back to its original GGML dtype.
                        let dequant_ffn_gate = feed_forward_gate_exps
                            .dequantize(device)?
                            .chunk(n_expert, 0)?;
                        let dequant_ffn_down = feed_forward_down_exps
                            .dequantize(device)?
                            .chunk(n_expert, 0)?;
                        let dequant_ffn_up = feed_forward_up_exps
                            .dequantize(device)?
                            .chunk(n_expert, 0)?;

                        assert_eq!(dequant_ffn_up.len(), dequant_ffn_down.len());
                        assert_eq!(dequant_ffn_gate.len(), dequant_ffn_down.len());
                        assert_eq!(dequant_ffn_gate.len(), n_expert);

                        let gate_type = feed_forward_gate_exps.dtype();
                        let down_type = feed_forward_down_exps.dtype();
                        let up_type = feed_forward_up_exps.dtype();

                        for (ff_w1, (ff_w2, ff_w3)) in dequant_ffn_gate
                            .into_iter()
                            .zip(dequant_ffn_down.into_iter().zip(dequant_ffn_up))
                        {
                            experts.push(Mlp {
                                feed_forward_w1: Arc::new(GgufMatMul::new(
                                    QuantMethodConfig::Gguf {
                                        q_weight: Arc::new(QTensor::quantize(&ff_w1, gate_type)?),
                                        b: None,
                                    },
                                )?),
                                feed_forward_w2: Arc::new(GgufMatMul::new(
                                    QuantMethodConfig::Gguf {
                                        q_weight: Arc::new(QTensor::quantize(&ff_w2, down_type)?),
                                        b: None,
                                    },
                                )?),
                                feed_forward_w3: Arc::new(GgufMatMul::new(
                                    QuantMethodConfig::Gguf {
                                        q_weight: Arc::new(QTensor::quantize(&ff_w3, up_type)?),
                                        b: None,
                                    },
                                )?),
                            })
                        }
                    }
                    Err(_) => {
                        // Per-expert tensor layout.
                        for i in 0..n_expert {
                            let feed_forward_w1 =
                                ct.tensor(&format!("{prefix}.ffn_gate.{i}.weight"), device)?;
                            let feed_forward_w2 =
                                ct.tensor(&format!("{prefix}.ffn_down.{i}.weight"), device)?;
                            let feed_forward_w3 =
                                ct.tensor(&format!("{prefix}.ffn_up.{i}.weight"), device)?;
                            experts.push(Mlp {
                                feed_forward_w1: Arc::new(GgufMatMul::new(
                                    QuantMethodConfig::Gguf {
                                        q_weight: Arc::new(feed_forward_w1),
                                        b: None,
                                    },
                                )?),
                                feed_forward_w2: Arc::new(GgufMatMul::new(
                                    QuantMethodConfig::Gguf {
                                        q_weight: Arc::new(feed_forward_w2),
                                        b: None,
                                    },
                                )?),
                                feed_forward_w3: Arc::new(GgufMatMul::new(
                                    QuantMethodConfig::Gguf {
                                        q_weight: Arc::new(feed_forward_w3),
                                        b: None,
                                    },
                                )?),
                            })
                        }
                    }
                }
                MlpOrMoe::MoE {
                    n_expert_used,
                    feed_forward_gate_inp: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                        q_weight: Arc::new(feed_forward_gate_inp),
                        b: None,
                    })?),
                    experts,
                }
            };
            let attention_norm = ct.tensor(&format!("{prefix}.attn_norm.weight"), device)?;
            let ffn_norm = ct.tensor(&format!("{prefix}.ffn_norm.weight"), device)?;
            // Choose the attention backend for this layer: eager SDPA or PagedAttention.
            let paged_attn = match &attention_mechanism {
                AttentionImplementation::Eager => None,
                AttentionImplementation::PagedAttention => {
                    Some(PagedAttention::new(head_dim, device, None)?)
                }
            };
            layers.push(LayerWeights {
                attention_wq: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: Arc::new(attention_wq),
                    b: None,
                })?),
                attention_wk: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: Arc::new(attention_wk),
                    b: None,
                })?),
                attention_wv: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: Arc::new(attention_wv),
                    b: None,
                })?),
                attention_wo: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                    q_weight: Arc::new(attention_wo),
                    b: None,
                })?),
                attention_norm: QRmsNorm::new(attention_norm, rms_norm_eps)?,
                mlp_or_moe,
                ffn_norm: QRmsNorm::new(ffn_norm, rms_norm_eps)?,
                n_head: head_count,
                n_kv_head: head_count_kv,
                head_dim,
                rotary: rotary.clone(),
                paged_attn,
                sdpa_params: SdpaParams {
                    n_kv_groups: head_count / head_count_kv,
                    softcap: None,
                    softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                    sliding_window: None,
                },
                dtype,
            })
        }
        Ok(Self {
            tok_embeddings: Embedding::new(tok_embeddings, embedding_length),
            layers,
            norm,
            output: Arc::new(GgufMatMul::new(QuantMethodConfig::Gguf {
                q_weight: Arc::new(output),
                b: None,
            })?),
            device: device.clone(),
            cache: EitherCache::Normal(NormalCache::new(block_count, max_seq_len)),
            max_seq_len,
            mapper: Some(mapper),
            dtype,
        })
    }
}

impl ModelWeights {
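    /// Forward pass over the whole model: embed tokens, run every transformer
    /// block (with device mapping if configured), apply the final norm, and
    /// project with the LM head, extracting logits per `context_lens`.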
    pub fn forward(
        &self,
        x: &Tensor,
        start_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
    ) -> Result<Tensor> {
        let mut layer_in = self.tok_embeddings.forward(x)?;
        let cache = &mut self.cache.normal().0;
        let mask = CausalMasker.make_causal_mask_matrix(
            x,
            metadata
                .as_ref()
                .map(|(_, _)| &start_offsets as &dyn PastKvLenCache)
                .unwrap_or(cache as &dyn PastKvLenCache),
            self.dtype,
            self.layers[0].n_head,
        )?;
        // With PagedAttention prompt chunking, only the first prompt chunk is
        // given the causal mask.
        let mask = mask.filter(|_| {
            metadata
                .as_ref()
                .map(|(_, meta)| meta.is_first_prompt_chunk)
                .unwrap_or(true)
        });
        for (i, layer) in self.layers.iter().enumerate() {
            if let Some(ref mapper) = self.mapper {
                layer_in = mapper.map(layer_in, i)?;
            }
            let x = layer_in;
            // Attention block with pre-norm and residual connection.
            let residual = &x;
            let x = layer.attention_norm.forward(&x)?;
            let attn = layer.forward_attn(
                &x,
                mask.as_ref()
                    .map(|m| m.to_device(x.device()).unwrap())
                    .as_ref(),
                start_offsets,
                &mut cache[i],
                metadata
                    .as_ref()
                    .map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
            )?;
            let x = (attn + residual)?;

            // Feed-forward block with pre-norm and residual connection.
            let residual = &x;
            let x = layer.ffn_norm.forward(&x)?;
            let x = layer.mlp_or_moe.forward(&x)?;
            let x = (x + residual)?;
            layer_in = x;
        }
        let layer_in = layer_in.to_device(&self.device)?;
        let x = self.norm.forward(&layer_in)?;
        extract_logits(
            &MatMul.qmethod_matmul(&x.contiguous()?, &*self.output)?,
            context_lens,
        )
    }
}