1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use std::{collections::HashMap, sync::Arc};
4
5use candle_core::{Context, DType, Device, IndexOp, Result, Tensor, D};
6use candle_nn::{Embedding, Module};
7use mistralrs_quant::{
8 distributed::DistributedOperation, ColumnParallelLayer, QuantMethod, QuantizedConfig,
9 ReplicatedLayer, RowParallelLayer, ShardedVarBuilder, SumAllReduce,
10};
11use serde::Deserialize;
12
13use crate::{
14 amoe::AnyMoeBaseModelMixin,
15 attention::SdpaParams,
16 device_map::DeviceMapper,
17 layers::{
18 embedding, Activation, CausalMasker, DeepSeekV2RopeConfig, DeepSeekV2RopeScaling,
19 DeepSeekV2RotaryEmbedding, Mlp, RmsNorm, Sdpa,
20 },
21 layers_masker::{masked_fill, PastKvLenCache},
22 ops::{BincountOp, NonZeroOp, SplitOp, TopKLastDimOp, TopKOutput},
23 paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
24 pipeline::{
25 extract_logits,
26 text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
27 EitherCache, IsqModel, KvCache, NormalCache, NormalLoadingMetadata, NormalModel,
28 },
29 serde_default_fn,
30 utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
31};
32
// Default-value provider fns referenced by `#[serde(default = "...")]` on
// `DeepSeekV2Config`, so these keys may be omitted from `config.json`.
serde_default_fn!(f64, routed_scaling_factor, 1.0);
serde_default_fn!(TopkMethod, topk_method, TopkMethod::Greedy);
serde_default_fn!(usize, moe_layer_freq, 1);
serde_default_fn!(usize, first_k_dense_replace, 0);
serde_default_fn!(bool, norm_topk_prob, false);
serde_default_fn!(ScoringFunc, scoring_func, ScoringFunc::Softmax);
serde_default_fn!(Activation, hidden_act, Activation::Silu);
serde_default_fn!(bool, tie_word_embeddings, false);
serde_default_fn!(bool, use_flash_attn_default, false);
42
43#[derive(Deserialize, Clone, Debug)]
44enum TopkMethod {
45 #[serde(rename = "greedy")]
46 Greedy,
47 #[serde(rename = "group_limited_greedy")]
48 GroupLimitedGreedy,
49}
50
51#[derive(Deserialize, Clone, Debug)]
52enum ScoringFunc {
53 #[serde(rename = "softmax")]
54 Softmax,
55}
56
/// Model configuration for DeepSeek-V2, deserialized from `config.json`.
#[derive(Deserialize, Clone, Debug)]
pub struct DeepSeekV2Config {
    pub(crate) vocab_size: usize,
    pub(crate) hidden_size: usize,
    // Intermediate size of the dense (non-MoE) MLPs.
    pub(crate) intermediate_size: usize,
    // Intermediate size of each routed/shared expert MLP.
    pub(crate) moe_intermediate_size: usize,
    pub(crate) num_hidden_layers: usize,
    pub(crate) num_attention_heads: usize,
    // Number of always-active shared experts; None = no shared experts.
    pub(crate) n_shared_experts: Option<usize>,
    // Number of routed experts; None = every layer uses a dense MLP.
    pub(crate) n_routed_experts: Option<usize>,
    // Multiplier applied to routed-expert gate weights when they are not renormalized.
    #[serde(default = "routed_scaling_factor")]
    pub(crate) routed_scaling_factor: f64,
    #[serde(default = "topk_method")]
    topk_method: TopkMethod,
    pub(crate) num_experts_per_tok: Option<usize>,
    // A layer is MoE when `layer_idx % moe_layer_freq == 0` (and past
    // `first_k_dense_replace`); see `DecoderLayer::new`.
    #[serde(default = "moe_layer_freq")]
    pub(crate) moe_layer_freq: usize,
    // The first `first_k_dense_replace` layers always use dense MLPs.
    #[serde(default = "first_k_dense_replace")]
    pub(crate) first_k_dense_replace: usize,
    // Renormalize the top-k gating weights to sum to 1 (when top_k > 1).
    #[serde(default = "norm_topk_prob")]
    pub(crate) norm_topk_prob: bool,
    #[serde(default = "scoring_func")]
    scoring_func: ScoringFunc,
    #[serde(default = "hidden_act")]
    pub(crate) hidden_act: Activation,
    pub(crate) max_position_embeddings: usize,
    pub(crate) rms_norm_eps: f64,
    #[serde(default = "tie_word_embeddings")]
    pub(crate) tie_word_embeddings: bool,
    pub(crate) rope_theta: f32,
    pub(crate) rope_scaling: Option<DeepSeekV2RopeScaling>,
    pub(crate) attention_bias: bool,
    // Low-rank dimension for the MLA query projection; None = plain q_proj.
    pub(crate) q_lora_rank: Option<usize>,
    // Rotary (RoPE) portion of each query/key head.
    pub(crate) qk_rope_head_dim: usize,
    // Latent dimension of the compressed KV representation.
    pub(crate) kv_lora_rank: usize,
    pub(crate) v_head_dim: usize,
    // Non-rotary portion of each query/key head.
    pub(crate) qk_nope_head_dim: usize,
    #[serde(default = "use_flash_attn_default")]
    pub(crate) use_flash_attn: bool,
    pub(crate) quantization_config: Option<QuantizedConfig>,
    // Number of expert groups for group-limited-greedy routing.
    pub(crate) n_group: usize,
    // Number of groups selected per token for group-limited-greedy routing.
    pub(crate) topk_group: usize,
}
101
102impl DeepSeekV2Config {
103 pub(crate) fn q_head_dim(&self) -> usize {
104 self.qk_rope_head_dim + self.qk_nope_head_dim
105 }
106
107 fn softmax_scale(&self) -> f32 {
108 let mut softmax_scale = 1.0 / (self.q_head_dim() as f32).sqrt();
109 if let Some(DeepSeekV2RopeScaling::Yarn {
110 mscale_all_dim,
111 factor,
112 ..
113 }) = self.rope_scaling
114 {
115 let mscale = DeepSeekV2RotaryEmbedding::yarn_get_mscale(factor, mscale_all_dim);
116 softmax_scale = softmax_scale * mscale * mscale;
117 }
118 softmax_scale
119 }
120}
121
/// Query projection: either a single linear layer, or the low-rank MLA
/// decomposition (`a` -> RmsNorm -> `b`) used when `q_lora_rank` is set.
enum QProj {
    Plain(Arc<dyn QuantMethod>),
    Lora {
        a: Arc<dyn QuantMethod>,
        norm: RmsNorm,
        b: Arc<dyn QuantMethod>,
    },
}
130
131impl QProj {
132 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
133 match self {
134 Self::Lora { a, norm, b } => {
135 b.forward_autocast(&norm.forward(&a.forward_autocast(xs)?)?)
136 }
137 Self::Plain(lin) => lin.forward_autocast(xs),
138 }
139 }
140}
141
/// Multi-head latent attention (MLA) weights for a single layer.
struct Attention {
    q: QProj,
    // Joint down-projection producing the compressed KV latent plus the
    // shared (MQA-style, one-per-token) rotary key.
    kv_a_proj_with_mqa: Arc<dyn QuantMethod>,
    kv_a_layernorm: RmsNorm,
    // Up-projection from the KV latent to per-head no-rope K and V.
    kv_b_proj: Arc<dyn QuantMethod>,
    o_proj: Arc<dyn QuantMethod>,
    rotary_emb: Arc<DeepSeekV2RotaryEmbedding>,
    cfg: DeepSeekV2Config,
    // qk_nope_head_dim + qk_rope_head_dim.
    q_head_dim: usize,
    paged_attn: Option<PagedAttention>,
    sdpa_params: SdpaParams,
    // Attention heads local to this rank (total heads / world size).
    num_attention_heads: usize,
}
155
impl Attention {
    /// Build the MLA attention block for one layer: query projection (plain or
    /// low-rank), compressed-KV projections, and the output projection.
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<DeepSeekV2RotaryEmbedding>,
        cfg: &DeepSeekV2Config,
        vb: ShardedVarBuilder,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        paged_attn: Option<PagedAttention>,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let q_head_dim = cfg.q_head_dim();
        // Query path: either q_a_proj -> q_a_layernorm -> q_b_proj (low-rank,
        // when q_lora_rank is set) or a single q_proj.
        let q = match cfg.q_lora_rank {
            Some(lora_rank) => {
                let a = ReplicatedLayer::new(
                    cfg.hidden_size,
                    lora_rank,
                    &cfg.quantization_config,
                    cfg.attention_bias,
                    mapper.set_device(layer_idx, vb.pp("q_a_proj"), loading_isq),
                )?;
                let norm = RmsNorm::new(
                    lora_rank,
                    cfg.rms_norm_eps,
                    mapper.set_device(layer_idx, vb.pp("q_a_layernorm"), false),
                )?;
                let b = ColumnParallelLayer::new(
                    lora_rank,
                    cfg.num_attention_heads * q_head_dim,
                    &cfg.quantization_config,
                    false,
                    comm,
                    mapper.set_device(layer_idx, vb.pp("q_b_proj"), loading_isq),
                )?;
                QProj::Lora { a, norm, b }
            }
            None => QProj::Plain(ColumnParallelLayer::new(
                cfg.hidden_size,
                cfg.num_attention_heads * q_head_dim,
                &cfg.quantization_config,
                false,
                comm,
                mapper.set_device(layer_idx, vb.pp("q_proj"), loading_isq),
            )?),
        };

        // KV down-projection: kv_lora_rank latent dims plus one shared
        // qk_rope_head_dim rotary key per token.
        let kv_a_proj_with_mqa = ReplicatedLayer::new(
            cfg.hidden_size,
            cfg.kv_lora_rank + cfg.qk_rope_head_dim,
            &cfg.quantization_config,
            cfg.attention_bias,
            mapper.set_device(layer_idx, vb.pp("kv_a_proj_with_mqa"), loading_isq),
        )?;
        let kv_a_layernorm = RmsNorm::new(
            cfg.kv_lora_rank,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("kv_a_layernorm"), false),
        )?;
        // Up-projection from the latent to per-head (no-rope K, V); note
        // q_head_dim - qk_rope_head_dim == qk_nope_head_dim.
        let kv_b_proj = ColumnParallelLayer::new(
            cfg.kv_lora_rank,
            cfg.num_attention_heads * (q_head_dim - cfg.qk_rope_head_dim + cfg.v_head_dim),
            &cfg.quantization_config,
            false,
            comm,
            mapper.set_device(layer_idx, vb.pp("kv_b_proj"), loading_isq),
        )?;

        let o_proj = RowParallelLayer::new(
            cfg.num_attention_heads * cfg.v_head_dim,
            cfg.hidden_size,
            &cfg.quantization_config,
            cfg.attention_bias,
            comm,
            mapper.set_device(layer_idx, vb.pp("o_proj"), loading_isq),
        )?;

        Ok(Self {
            q,
            kv_a_proj_with_mqa,
            kv_a_layernorm,
            kv_b_proj,
            o_proj,
            rotary_emb,
            cfg: cfg.clone(),
            q_head_dim,
            paged_attn,
            // Heads are split evenly across the tensor-parallel group.
            num_attention_heads: cfg.num_attention_heads / comm.world_size(),
            sdpa_params: SdpaParams {
                // The single rotary key is broadcast to every head in
                // `forward`, so SDPA sees a full set of KV heads.
                n_kv_groups: 1,
                use_flash_attn: cfg.use_flash_attn,
                softcap: None,
                softmax_scale: cfg.softmax_scale(),
                sliding_window: None,
            },
        })
    }

    /// Run attention for one layer.
    ///
    /// `metadata` carries the paged-attention KV cache tensors and block
    /// metadata; when absent on the eager path, `kv_cache` is used instead.
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (bs, seq_len, _) = xs.dims3()?;

        // Project queries and split into the no-rope / rope parts.
        let mut q = self.q.forward(xs)?;
        q = q
            .reshape((bs, seq_len, self.num_attention_heads, self.q_head_dim))?
            .transpose(1, 2)?;
        let q_split = q.split(
            &[self.cfg.qk_nope_head_dim, self.cfg.qk_rope_head_dim],
            D::Minus1,
        )?;
        let q_nope = q_split[0].clone();
        let mut q_pe = q_split[1].clone();

        // Compress KV and peel off the shared rotary key (k_pe).
        let mut compressed_kv = self.kv_a_proj_with_mqa.forward_autocast(xs)?;
        let ckv_split = compressed_kv.split(
            &[self.cfg.kv_lora_rank, self.cfg.qk_rope_head_dim],
            D::Minus1,
        )?;
        compressed_kv = ckv_split[0].clone();
        let mut k_pe = ckv_split[1].clone();
        k_pe = k_pe
            .reshape((bs, seq_len, 1, self.cfg.qk_rope_head_dim))?
            .transpose(1, 2)?;
        // Decompress the normalized latent into per-head no-rope K and V.
        let mut kv = self
            .kv_b_proj
            .forward_autocast(&self.kv_a_layernorm.forward(&compressed_kv)?)?;
        kv = kv
            .reshape((
                bs,
                seq_len,
                self.num_attention_heads,
                self.cfg.qk_nope_head_dim + self.cfg.v_head_dim,
            ))?
            .transpose(1, 2)?;

        let kv_split = kv.split(&[self.cfg.qk_nope_head_dim, self.cfg.v_head_dim], D::Minus1)?;
        let k_nope = kv_split[0].clone();
        let mut v = kv_split[1].clone();

        // RoPE is applied only to the *_pe halves.
        (q_pe, k_pe) = self.rotary_emb.forward(&q_pe, &k_pe, seqlen_offsets)?;

        let q = Tensor::cat(&[&q_nope, &q_pe], D::Minus1)?.contiguous()?;
        // Broadcast the single rotary key to all local heads.
        let mut k = Tensor::cat(
            &[&k_nope, &k_pe.repeat((1, self.num_attention_heads, 1, 1))?],
            D::Minus1,
        )?
        .contiguous()?;

        let mut attn_out = match &self.paged_attn {
            Some(paged_attn) => match metadata {
                Some(((key_cache, value_cache), input_metadata)) => {
                    // Paged attention stores K and V at the same head dim, so
                    // V is zero-padded up to q_head_dim; the padding is
                    // trimmed off the output with `narrow` below.
                    let v = v
                        .pad_with_zeros(D::Minus1, 0, self.q_head_dim - self.cfg.v_head_dim)?
                        .contiguous()?;
                    paged_attn
                        .forward(
                            &q,
                            &k,
                            &v,
                            attention_mask,
                            Some(key_cache),
                            Some(value_cache),
                            input_metadata,
                            &self.sdpa_params,
                            Some(flash_params),
                        )?
                        .narrow(D::Minus1, 0, self.cfg.v_head_dim)?
                }
                None => {
                    // No cache metadata supplied: run paged attention without
                    // touching the KV cache, using dummy input metadata.
                    let input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
                    // This path requires a causal mask (prompt processing).
                    assert!(attention_mask.is_some());
                    let v = v
                        .pad_with_zeros(D::Minus1, 0, self.q_head_dim - self.cfg.v_head_dim)?
                        .contiguous()?;
                    paged_attn
                        .forward(
                            &q,
                            &k,
                            &v,
                            attention_mask,
                            None,
                            None,
                            &input_metadata,
                            &self.sdpa_params,
                            Some(flash_params),
                        )?
                        .narrow(D::Minus1, 0, self.cfg.v_head_dim)?
                }
            },
            None => {
                // Eager path: append to the per-layer KV cache and run SDPA.
                (k, v) = kv_cache.append(&k, &v)?;

                Sdpa.run_attention(
                    &q,
                    &k,
                    &v,
                    attention_mask,
                    Some(flash_params),
                    &self.sdpa_params,
                )?
            }
        };

        attn_out = if attention_mask.is_some() {
            attn_out.transpose(1, 2)?.reshape((bs, seq_len, ()))?
        } else {
            attn_out.reshape((bs, seq_len, ()))?
        };

        self.o_proj.forward_autocast(&attn_out)
    }
}
378
/// One routed expert: a small gated MLP (gate/up/down projections).
struct Expert {
    gate: Arc<dyn QuantMethod>,
    up: Arc<dyn QuantMethod>,
    down: Arc<dyn QuantMethod>,
    act: Activation,
}
385
386impl Expert {
387 fn new(
388 cfg: &DeepSeekV2Config,
389 vb: ShardedVarBuilder,
390 hidden_size: Option<usize>,
391 intermediate_size: Option<usize>,
392 ) -> Result<Self> {
393 let hidden_size = hidden_size.unwrap_or(cfg.hidden_size);
394 let intermediate_size = intermediate_size.unwrap_or(cfg.intermediate_size);
395
396 Ok(Self {
397 gate: ReplicatedLayer::new(
398 hidden_size,
399 intermediate_size,
400 &cfg.quantization_config,
401 false,
402 vb.pp("gate_proj"),
403 )?,
404 up: ReplicatedLayer::new(
405 hidden_size,
406 intermediate_size,
407 &cfg.quantization_config,
408 false,
409 vb.pp("up_proj"),
410 )?,
411 down: ReplicatedLayer::new(
412 intermediate_size,
413 hidden_size,
414 &cfg.quantization_config,
415 false,
416 vb.pp("down_proj"),
417 )?,
418 act: cfg.hidden_act,
419 })
420 }
421
422 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
423 let original_dtype = xs.dtype();
424 let mut xs = xs.clone();
425 if let Some(t) = self.gate.quantized_act_type() {
426 xs = xs.to_dtype(t)?;
427 }
428 let lhs = self.gate.forward(&xs)?;
429 let rhs = self.up.forward(&xs)?;
430 let mut res = self.down.forward(&candle_nn::ops::mul_and_act(
431 &lhs,
432 &rhs,
433 self.act.try_into()?,
434 )?)?;
435 if self.gate.quantized_act_type().is_some() {
436 res = res.to_dtype(original_dtype)?;
437 }
438 Ok(res)
439 }
440}
441
/// Router that scores each token against every routed expert and picks top-k.
struct MoeGate {
    // Gating matrix, shape [n_routed_experts, hidden_size].
    weight: Tensor,
    cfg: DeepSeekV2Config,
    top_k: usize,
    n_routed_experts: usize,
}
448
449impl MoeGate {
450 fn new(cfg: &DeepSeekV2Config, vb: ShardedVarBuilder, n_routed_experts: usize) -> Result<Self> {
451 let weight = vb.get((n_routed_experts, cfg.hidden_size), "weight")?;
452 Ok(Self {
453 weight,
454 cfg: cfg.clone(),
455 top_k: cfg.num_experts_per_tok.unwrap(),
456 n_routed_experts,
457 })
458 }
459
460 fn forward(&self, xs: &Tensor) -> Result<(Tensor, Tensor)> {
462 let (bs, seq_len, h) = xs.dims3()?;
463 let xs = xs.reshape(((), h))?;
465 let logits = xs
466 .to_dtype(DType::F32)?
467 .broadcast_matmul(&self.weight.t()?.to_dtype(DType::F32)?)?;
468 let scores = match self.cfg.scoring_func {
469 ScoringFunc::Softmax => candle_nn::ops::softmax_last_dim(&logits)?,
470 };
471
472 let (mut topk_weight, topk_idx) = match self.cfg.topk_method {
474 TopkMethod::Greedy => {
475 let TopKOutput { values, indices } = scores.topk_unsorted(self.top_k)?;
476 (values, indices)
477 }
478 TopkMethod::GroupLimitedGreedy => {
479 let group_scores = scores
481 .reshape((bs * seq_len, self.cfg.n_group, ()))?
482 .max(D::Minus1)?;
483 let group_idx = scores.topk_unsorted(self.cfg.topk_group)?.indices;
485 let mut group_mask = group_scores.zeros_like()?;
487 group_mask = group_mask.scatter_add(
489 &group_idx,
490 &group_idx.ones_like()?.to_dtype(group_mask.dtype())?,
491 1,
492 )?;
493 let score_mask = group_mask
495 .unsqueeze(D::Minus1)?
496 .expand((
497 bs * seq_len,
498 self.cfg.n_group,
499 self.n_routed_experts / self.cfg.n_group,
500 ))?
501 .reshape((bs, seq_len, ()))?;
502 let tmp_scores = masked_fill(&score_mask, &(1. - &score_mask.ne(0.)?)?, 0.)?;
505 let TopKOutput { values, indices } = tmp_scores.topk_unsorted(self.top_k)?;
506 (values, indices)
507 }
508 };
509
510 if self.top_k > 1 && self.cfg.norm_topk_prob {
511 let denmoninator = (topk_weight.sum_keepdim(D::Minus1)? + 1e-20)?;
512 topk_weight = (topk_weight / denmoninator)?;
513 } else {
514 topk_weight = (topk_weight * self.cfg.routed_scaling_factor)?;
515 }
516 Ok((topk_idx, topk_weight))
517 }
518}
519
/// Mixture-of-experts block: gate + routed experts (sharded across ranks by a
/// contiguous index range) + optional always-active shared experts.
struct Moe {
    // One slot per routed expert; `None` for experts owned by other ranks.
    experts: Vec<Option<Expert>>,
    shared_experts: Option<Mlp>,
    gate: MoeGate,
    all_reduce: SumAllReduce,
    // [experts_start_idx, experts_end_idx) is this rank's expert range.
    experts_start_idx: usize,
    experts_end_idx: usize,
    world_size: usize,
}
529
impl Moe {
    /// Build a MoE block; only the contiguous slice of experts owned by this
    /// rank is materialized, the rest are stored as `None`.
    #[allow(clippy::too_many_arguments)]
    fn new(
        cfg: &DeepSeekV2Config,
        vb: ShardedVarBuilder,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        n_shared_experts: Option<usize>,
        n_routed_experts: usize,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let mut experts = Vec::with_capacity(n_routed_experts);
        // Experts are partitioned contiguously across the parallel group.
        let n_local_experts = n_routed_experts / comm.world_size();
        let experts_start_idx = comm.rank() * n_local_experts;
        let experts_end_idx = experts_start_idx + n_local_experts;
        for i in 0..n_routed_experts {
            if i >= experts_start_idx && i < experts_end_idx {
                let vb_e = vb.pp("experts").pp(i);
                experts.push(Some(Expert::new(
                    cfg,
                    mapper.set_device(layer_idx, vb_e, loading_isq),
                    None,
                    Some(cfg.moe_intermediate_size),
                )?));
            } else {
                experts.push(None);
            }
        }
        // Shared experts are a single fused MLP with a scaled-up intermediate.
        let shared_experts = if let Some(n_shared_experts) = n_shared_experts {
            let intermediate_size = cfg.moe_intermediate_size * n_shared_experts;
            Some(Mlp::new(
                mapper.set_device(layer_idx, vb.pp("shared_experts"), loading_isq),
                cfg.hidden_size,
                intermediate_size,
                &cfg.quantization_config,
                cfg.hidden_act,
                comm,
            )?)
        } else {
            None
        };
        let gate = MoeGate::new(
            cfg,
            mapper.set_device(layer_idx, vb.pp("gate"), false),
            n_routed_experts,
        )?;
        Ok(Self {
            experts,
            shared_experts,
            gate,
            all_reduce: SumAllReduce::new(comm),
            experts_end_idx,
            experts_start_idx,
            world_size: comm.world_size(),
        })
    }

    /// Dispatch flattened tokens (`xs`: [n, h]) to this rank's experts and
    /// accumulate their weighted outputs; partial sums are all-reduced across
    /// ranks so every rank ends with the full result.
    fn moe_infer(&self, xs: &Tensor, topk_ids: &Tensor, topk_weight: &Tensor) -> Result<Tensor> {
        let mut y = xs.zeros_like()?;
        // Tokens-per-expert histogram lets us skip experts with no work.
        let counts = topk_ids
            .flatten_all()?
            .bincount(self.experts.len() as u32)?;
        for (i, count) in counts
            .iter()
            .enumerate()
            .take(self.experts_end_idx)
            .skip(self.experts_start_idx)
        {
            if *count == 0 {
                continue;
            }
            // Row indices (tokens) and top-k slot indices that chose expert i.
            let idx_top = topk_ids.eq(i as f64)?.nonzero()?.t()?;
            let idx = &idx_top.i(0)?.contiguous()?;
            let top = &idx_top.i(1)?.contiguous()?;

            let expert = self.experts[i]
                .as_ref()
                .context("Expert is not present for this rank.")?;

            // y[idx] += expert(xs[idx]) * topk_weight[idx, top]
            y = y.index_add(
                idx,
                &expert.forward(&xs.index_select(idx, 0)?)?.broadcast_mul(
                    &topk_weight
                        .index_select(idx, 0)?
                        .gather(&top.unsqueeze(1)?, 1)?
                        .squeeze(1)?
                        .unsqueeze(D::Minus1)?
                        .to_dtype(xs.dtype())?,
                )?,
                0,
            )?;
        }

        if self.world_size > 1 {
            y = self.all_reduce.sum_all_reduce(&y)?;
        }

        Ok(y)
    }

    /// MoE forward: routed-expert mixture plus (optionally) the shared
    /// experts, which always process every token.
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let identity = xs.clone();
        let orig_shape = xs.shape();
        let (topk_idx, topk_weight) = self.gate.forward(xs)?;
        let xs = xs.reshape(((), xs.dim(D::Minus1)?))?;

        let mut y = self
            .moe_infer(&xs, &topk_idx, &topk_weight)?
            .reshape(orig_shape)?;
        if let Some(ref shared_experts) = self.shared_experts {
            y = (y + shared_experts.forward(&identity)?)?;
        }
        Ok(y)
    }
}
646
/// Feed-forward variant of a layer: MoE (routed experts) or a dense MLP.
enum MoeOrMlp {
    Moe(Moe),
    Mlp(Mlp),
}
651
652impl MoeOrMlp {
653 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
654 match self {
655 Self::Mlp(mlp) => mlp.forward(xs),
656 Self::Moe(moe) => moe.forward(xs),
657 }
658 }
659}
660
/// One pre-norm transformer block: attention plus a dense or MoE FFN.
struct DecoderLayer {
    input_layernorm: RmsNorm,
    post_attention_layernorm: RmsNorm,
    attn: Attention,
    moe_or_mlp: MoeOrMlp,
}
667
impl DecoderLayer {
    /// Build one layer; the FFN is MoE or a dense MLP depending on the layer
    /// index and the MoE config.
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<DeepSeekV2RotaryEmbedding>,
        cfg: &DeepSeekV2Config,
        vb: ShardedVarBuilder,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        paged_attn: Option<PagedAttention>,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let attn = Attention::new(
            rotary_emb,
            cfg,
            vb.pp("self_attn"),
            mapper,
            layer_idx,
            loading_isq,
            paged_attn,
            comm,
        )?;
        let input_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let post_attention_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        // MoE applies only past the first `first_k_dense_replace` layers and
        // at the configured frequency; everything else gets a dense MLP.
        let moe_or_mlp = if cfg.n_routed_experts.is_some()
            && layer_idx >= cfg.first_k_dense_replace
            && layer_idx % cfg.moe_layer_freq == 0
        {
            MoeOrMlp::Moe(Moe::new(
                cfg,
                vb.pp("mlp"),
                mapper,
                layer_idx,
                loading_isq,
                cfg.n_shared_experts,
                cfg.n_routed_experts.unwrap(),
                comm,
            )?)
        } else {
            MoeOrMlp::Mlp(Mlp::new(
                mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq),
                cfg.hidden_size,
                cfg.intermediate_size,
                &cfg.quantization_config,
                cfg.hidden_act,
                comm,
            )?)
        };

        Ok(Self {
            input_layernorm,
            post_attention_layernorm,
            attn,
            moe_or_mlp,
        })
    }

    /// Pre-norm residual block: `x + attn(norm(x))`, then `x + ffn(norm(x))`.
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = self.input_layernorm.forward(xs)?;
        let xs = self.attn.forward(
            &xs,
            attention_mask,
            seqlen_offsets,
            kv_cache,
            metadata,
            flash_params,
        )?;
        let xs = (xs + residual)?;
        let residual = &xs;
        let xs = self
            .moe_or_mlp
            .forward(&xs.apply(&self.post_attention_layernorm)?)?;
        residual + xs
    }
}
760
/// DeepSeek-V2 causal language model.
pub struct DeepSeekV2 {
    lm_head: Arc<dyn QuantMethod>,
    embed_tokens: Embedding,
    norm: RmsNorm,
    layers: Vec<DecoderLayer>,
    cache: EitherCache,
    // Device that holds the final norm / lm_head (outputs are moved here).
    device: Device,
    max_seq_len: usize,
    cfg: ModelConfigMetadata,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
}
772
773impl DeepSeekV2 {
774 pub fn new(
775 cfg: &DeepSeekV2Config,
776 vb: ShardedVarBuilder,
777 _is_gptx: bool,
778 normal_loading_metadata: NormalLoadingMetadata,
779 attention_mechanism: AttentionImplementation,
780 ) -> Result<Self> {
781 let vb_m = vb.pp("model");
782
783 let mapper = normal_loading_metadata.mapper;
784
785 let embed_tokens = embedding(
786 cfg.vocab_size,
787 cfg.hidden_size,
788 mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
789 )?;
790 let lm_head = if !cfg.tie_word_embeddings {
791 ReplicatedLayer::new(
792 cfg.hidden_size,
793 cfg.vocab_size,
794 &None,
795 false,
796 mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
797 )?
798 } else {
799 ReplicatedLayer::from_linear(candle_nn::Linear::new(
800 mapper.cast_nm_device(
801 embed_tokens.embeddings(),
802 normal_loading_metadata.loading_isq,
803 )?,
804 None,
805 ))?
806 };
807 let norm = RmsNorm::new(
808 cfg.hidden_size,
809 cfg.rms_norm_eps,
810 mapper.set_nm_device(vb_m.pp("norm"), false),
811 )?;
812
813 let mut ropes = HashMap::new();
814 let rope_cfg = DeepSeekV2RopeConfig {
815 rope_scaling: cfg.rope_scaling.clone(),
816 max_position_embeddings: cfg.max_position_embeddings,
817 rope_theta: cfg.rope_theta,
818 qk_rope_head_dim: cfg.qk_rope_head_dim,
819 };
820 for i in 0..cfg.num_hidden_layers {
821 let device = mapper
822 .device_for(i, false)
823 .unwrap_or(&normal_loading_metadata.real_device);
824 ropes.insert(
825 device.location(),
826 Arc::new(DeepSeekV2RotaryEmbedding::new(
827 &rope_cfg,
828 vb.dtype(),
829 device,
830 )?),
831 );
832 }
833
834 let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
835 let vb_l = vb_m.pp("layers");
836 for layer_idx in NiceProgressBar::<_, 'b'>(
837 0..cfg.num_hidden_layers,
838 "Loading repeating layers",
839 &normal_loading_metadata.multi_progress,
840 ) {
841 let device = mapper
842 .device_for(layer_idx, false)
843 .unwrap_or(&normal_loading_metadata.real_device);
844 let rotary_emb = ropes
845 .get(&device.location())
846 .expect("No RoPE for device location!")
847 .clone();
848 let paged_attn = match &attention_mechanism {
849 AttentionImplementation::Eager => None,
850 AttentionImplementation::PagedAttention => Some(
851 PagedAttention::new(cfg.v_head_dim, device, None)
852 .expect("Failed to create PagedAttention"),
853 ),
854 };
855 let comm = mapper.get_comm_for(layer_idx)?;
856 let layer = DecoderLayer::new(
857 rotary_emb.clone(),
858 cfg,
859 vb_l.pp(layer_idx),
860 &*mapper,
861 layer_idx,
862 normal_loading_metadata.loading_isq,
863 paged_attn,
864 &comm,
865 )?;
866 layers.push(layer)
867 }
868
869 Ok(Self {
870 lm_head,
871 embed_tokens,
872 norm,
873 layers,
874 cache: EitherCache::Normal(NormalCache::new(
875 cfg.num_hidden_layers,
876 cfg.max_position_embeddings,
877 )),
878 device: normal_loading_metadata.real_device.clone(),
879 max_seq_len: cfg.max_position_embeddings,
880 cfg: ModelConfigMetadata {
881 max_seq_len: cfg.max_position_embeddings,
882 num_layers: cfg.num_hidden_layers,
883 hidden_size: cfg.hidden_size,
884 num_kv_heads: (cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size())
885 .max(1),
886 num_attn_heads: (cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size())
887 .max(1),
888 sliding_window: None,
889 k_head_dim: cfg.q_head_dim(),
890 v_head_dim: if matches!(
891 attention_mechanism,
892 AttentionImplementation::PagedAttention
893 ) {
894 cfg.q_head_dim()
895 } else {
896 cfg.v_head_dim
897 },
898 },
899 mapper,
900 })
901 }
902
903 #[allow(clippy::too_many_arguments)]
904 pub fn forward(
905 &self,
906 input_ids: &Tensor,
907 seqlen_offsets: &[usize],
908 context_lens: Vec<(usize, usize)>,
909 metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
910 flash_params: &FlashParams,
911 ) -> Result<Tensor> {
912 let mut xs = self.embed_tokens.forward(input_ids)?;
913 let cache = &mut self.cache.normal().0;
914 let attention_mask = CausalMasker.make_causal_mask_matrix(
915 input_ids,
916 metadata
917 .as_ref()
918 .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
919 .unwrap_or(cache as &dyn PastKvLenCache),
920 xs.dtype(),
921 self.cfg.num_attn_heads,
922 )?;
923 let attention_mask = attention_mask.filter(|_| {
925 metadata
926 .as_ref()
927 .map(|(_, meta)| meta.is_first_prompt_chunk)
928 .unwrap_or(true)
929 });
930 for (i, layer) in self.layers.iter().enumerate() {
931 xs = self.mapper.map(xs, i)?;
932 xs = layer.forward(
933 &xs,
934 attention_mask
935 .as_ref()
936 .map(|m| m.to_device(xs.device()).unwrap())
937 .as_ref(),
938 seqlen_offsets,
939 &mut cache[i],
940 metadata
941 .as_ref()
942 .map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
943 flash_params,
944 )?;
945 }
946 let xs = xs.to_device(&self.device)?;
947 let xs = xs.apply(&self.norm)?;
948 extract_logits(&self.lm_head.forward_autocast(&xs)?, context_lens)
949 }
950}
951
impl IsqModel for DeepSeekV2 {
    /// Collect every quantizable linear layer (with its layer index, `None`
    /// for the lm_head) for in-situ quantization.
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((&mut self.lm_head, None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            match &mut layer.attn.q {
                QProj::Plain(q) => {
                    tensors.push((q, Some(i)));
                }
                QProj::Lora { a, norm: _, b } => {
                    tensors.push((a, Some(i)));
                    tensors.push((b, Some(i)));
                }
            }
            tensors.push((&mut layer.attn.kv_a_proj_with_mqa, Some(i)));
            tensors.push((&mut layer.attn.kv_b_proj, Some(i)));
            tensors.push((&mut layer.attn.o_proj, Some(i)));
            match &mut layer.moe_or_mlp {
                MoeOrMlp::Mlp(mlp) => {
                    tensors.push((&mut mlp.gate, Some(i)));
                    tensors.push((&mut mlp.up, Some(i)));
                    tensors.push((&mut mlp.down, Some(i)));
                }
                MoeOrMlp::Moe(moe) => {
                    // Only experts materialized on this rank are Some(..).
                    for mlp in moe.experts.iter_mut().filter_map(|e| e.as_mut()) {
                        tensors.push((&mut mlp.gate, Some(i)));
                        tensors.push((&mut mlp.up, Some(i)));
                        tensors.push((&mut mlp.down, Some(i)));
                    }
                    if let Some(mlp) = &mut moe.shared_experts {
                        tensors.push((&mut mlp.gate, Some(i)));
                        tensors.push((&mut mlp.up, Some(i)));
                        tensors.push((&mut mlp.down, Some(i)));
                    }
                }
            }
        }
        (tensors, &*self.mapper)
    }

    /// Like `get_layers`, but restricted to MLP/expert weights — the
    /// attention projections are left unquantized in this mode.
    fn get_layers_moe_experts_only(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((&mut self.lm_head, None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            match &mut layer.moe_or_mlp {
                MoeOrMlp::Mlp(mlp) => {
                    tensors.push((&mut mlp.gate, Some(i)));
                    tensors.push((&mut mlp.up, Some(i)));
                    tensors.push((&mut mlp.down, Some(i)));
                }
                MoeOrMlp::Moe(moe) => {
                    for mlp in moe.experts.iter_mut().filter_map(|e| e.as_mut()) {
                        tensors.push((&mut mlp.gate, Some(i)));
                        tensors.push((&mut mlp.up, Some(i)));
                        tensors.push((&mut mlp.down, Some(i)));
                    }
                    if let Some(mlp) = &mut moe.shared_experts {
                        tensors.push((&mut mlp.gate, Some(i)));
                        tensors.push((&mut mlp.up, Some(i)));
                        tensors.push((&mut mlp.down, Some(i)));
                    }
                }
            }
        }
        (tensors, &*self.mapper)
    }

    /// Tensors that are never quantized (embeddings, norms, MoE gate
    /// weights), keyed by their original safetensors names.
    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        let uvb_m = uvb.pp("model");
        uvb_m.pp("embed_tokens").add(&self.embed_tokens);
        uvb_m.pp("norm").add(&self.norm);

        for (layer_idx, layer) in self.layers.iter().enumerate() {
            let uvb_l = uvb_m.pp("layers").pp(layer_idx);
            uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
            uvb_l
                .pp("post_attention_layernorm")
                .add(&layer.post_attention_layernorm);

            uvb_l
                .pp("self_attn")
                .pp("kv_a_layernorm")
                .add(&layer.attn.kv_a_layernorm);

            match &layer.moe_or_mlp {
                MoeOrMlp::Moe(moe) => {
                    uvb_l
                        .pp("mlp")
                        .pp("gate")
                        .add_tensor("weight", moe.gate.weight.clone());
                }
                MoeOrMlp::Mlp(_) => (),
            }

            match &layer.attn.q {
                QProj::Plain(_) => (),
                QProj::Lora { a: _, norm, b: _ } => {
                    uvb_l.pp("self_attn").pp("q_a_layernorm").add(norm);
                }
            }
        }

        uvb.to_safetensors()
    }

    /// Residual tensors for experts-only ISQ: additionally includes all
    /// attention projections, since those stay unquantized in that mode.
    fn residual_tensors_moe_experts_only(&self) -> Option<Vec<(String, Tensor)>> {
        let uvb = UnVarBuilder::new();

        let uvb_m = uvb.pp("model");
        uvb_m.pp("embed_tokens").add(&self.embed_tokens);
        uvb_m.pp("norm").add(&self.norm);

        for (layer_idx, layer) in self.layers.iter().enumerate() {
            let uvb_l = uvb_m.pp("layers").pp(layer_idx);
            uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
            uvb_l
                .pp("post_attention_layernorm")
                .add(&layer.post_attention_layernorm);

            uvb_l
                .pp("self_attn")
                .pp("kv_a_layernorm")
                .add(&layer.attn.kv_a_layernorm);

            match &layer.moe_or_mlp {
                MoeOrMlp::Moe(moe) => {
                    uvb_l
                        .pp("mlp")
                        .pp("gate")
                        .add_tensor("weight", moe.gate.weight.clone());
                }
                MoeOrMlp::Mlp(_) => (),
            }

            match &layer.attn.q {
                QProj::Plain(q) => {
                    uvb_l.pp("self_attn").pp("q_proj").add(q);
                }
                QProj::Lora { a, norm, b } => {
                    uvb_l.pp("self_attn").pp("q_a_proj").add(a);
                    uvb_l.pp("self_attn").pp("q_a_layernorm").add(norm);
                    uvb_l.pp("self_attn").pp("q_b_proj").add(b);
                }
            }
            uvb_l
                .pp("self_attn")
                .pp("kv_a_proj_with_mqa")
                .add(&layer.attn.kv_a_proj_with_mqa);
            uvb_l
                .pp("self_attn")
                .pp("kv_b_proj")
                .add(&layer.attn.kv_b_proj);
            uvb_l.pp("self_attn").pp("o_proj").add(&layer.attn.o_proj);
        }

        Some(uvb.to_safetensors())
    }
}
1122
impl NormalModel for DeepSeekV2 {
    /// Delegates to the inherent `DeepSeekV2::forward` (position ids unused).
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.forward(
            input_ids,
            seqlen_offsets,
            context_lens,
            metadata,
            flash_params,
        )
    }
    /// X-LoRA is not supported for this model.
    fn xlora_forward(
        &self,
        _input_ids: &Tensor,
        _input_ids_full: &Tensor,
        _seqlen_offsets: &[usize],
        _seqlen_offsets_full: &[usize],
        _no_kv_cache: bool,
        _non_granular_state: &Option<crate::xlora_models::NonGranularState>,
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _flash_params: &FlashParams,
        _flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        unimplemented!()
    }
    fn cache(&self) -> &EitherCache {
        &self.cache
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        &mut self.cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn is_xlora(&self) -> bool {
        false
    }
    fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
}
1175
// Relies entirely on the trait's default method implementations.
impl AnyMoeBaseModelMixin for DeepSeekV2 {}