#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
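//! DeepSeek-V2 text model.
//!
//! Implements Multi-head Latent Attention (MLA) with decoupled RoPE and the
//! DeepSeekMoE feed-forward stack (shared + routed experts), with support for
//! tensor parallelism, paged attention, and in-situ quantization (ISQ).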

use std::{collections::HashMap, sync::Arc};

use candle_core::{Context, DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{Embedding, Module};
use mistralrs_quant::{
    ColumnParallelLayer, NonZeroOp, QuantMethod, QuantizedConfig, ReplicatedLayer,
    RowParallelLayer, ShardedVarBuilder, SumAllReduce,
};
use serde::Deserialize;

use crate::{
    amoe::AnyMoeBaseModelMixin,
    attention::SdpaParams,
    device_map::DeviceMapper,
    layers::{
        embedding, Activation, CausalMasker, DeepSeekV2RopeConfig, DeepSeekV2RopeScaling,
        DeepSeekV2RotaryEmbedding, Mlp, RmsNorm, Sdpa,
    },
    layers_masker::{masked_fill, PastKvLenCache},
    ops::{SplitOp, TopKLastDimOp, TopKOutput},
    paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
    pipeline::{
        extract_logits,
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel, KvCache, NormalCache, NormalLoadingMetadata, NormalModel,
    },
    serde_default_fn,
    utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
};
use std::collections::HashSet;
use std::iter::FromIterator;

serde_default_fn!(f64, routed_scaling_factor, 1.0);
serde_default_fn!(TopkMethod, topk_method, TopkMethod::Greedy);
serde_default_fn!(usize, moe_layer_freq, 1);
serde_default_fn!(usize, first_k_dense_replace, 0);
serde_default_fn!(bool, norm_topk_prob, false);
serde_default_fn!(ScoringFunc, scoring_func, ScoringFunc::Softmax);
serde_default_fn!(Activation, hidden_act, Activation::Silu);
serde_default_fn!(bool, tie_word_embeddings, false);
serde_default_fn!(bool, use_flash_attn_default, false);

#[derive(Deserialize, Clone, Debug)]
enum TopkMethod {
    #[serde(rename = "greedy")]
    Greedy,
    #[serde(rename = "group_limited_greedy")]
    GroupLimitedGreedy,
}

#[derive(Deserialize, Clone, Debug)]
enum ScoringFunc {
    #[serde(rename = "softmax")]
    Softmax,
}

#[derive(Deserialize, Clone, Debug)]
pub struct DeepSeekV2Config {
    pub(crate) vocab_size: usize,
    pub(crate) hidden_size: usize,
    pub(crate) intermediate_size: usize,
    pub(crate) moe_intermediate_size: usize,
    pub(crate) num_hidden_layers: usize,
    pub(crate) num_attention_heads: usize,
    pub(crate) n_shared_experts: Option<usize>,
    pub(crate) n_routed_experts: Option<usize>,
    #[serde(default = "routed_scaling_factor")]
    pub(crate) routed_scaling_factor: f64,
    #[serde(default = "topk_method")]
    topk_method: TopkMethod,
    pub(crate) num_experts_per_tok: Option<usize>,
    #[serde(default = "moe_layer_freq")]
    pub(crate) moe_layer_freq: usize,
    #[serde(default = "first_k_dense_replace")]
    pub(crate) first_k_dense_replace: usize,
    #[serde(default = "norm_topk_prob")]
    pub(crate) norm_topk_prob: bool,
    #[serde(default = "scoring_func")]
    scoring_func: ScoringFunc,
    #[serde(default = "hidden_act")]
    pub(crate) hidden_act: Activation,
    pub(crate) max_position_embeddings: usize,
    pub(crate) rms_norm_eps: f64,
    #[serde(default = "tie_word_embeddings")]
    pub(crate) tie_word_embeddings: bool,
    pub(crate) rope_theta: f32,
    pub(crate) rope_scaling: Option<DeepSeekV2RopeScaling>,
    pub(crate) attention_bias: bool,
    pub(crate) q_lora_rank: Option<usize>,
    pub(crate) qk_rope_head_dim: usize,
    pub(crate) kv_lora_rank: usize,
    pub(crate) v_head_dim: usize,
    pub(crate) qk_nope_head_dim: usize,
    #[serde(default = "use_flash_attn_default")]
    pub(crate) use_flash_attn: bool,
    pub(crate) quantization_config: Option<QuantizedConfig>,
    pub(crate) n_group: usize,
    pub(crate) topk_group: usize,
}

impl DeepSeekV2Config {
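    /// Per-head query/key dimension: the content ("nope") part plus the
    /// decoupled RoPE part. For reference, the DeepSeek-V2 checkpoints ship
    /// qk_nope_head_dim = 128 and qk_rope_head_dim = 64, giving
    /// q_head_dim = 192.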
    pub(crate) fn q_head_dim(&self) -> usize {
        self.qk_rope_head_dim + self.qk_nope_head_dim
    }

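    /// Attention softmax scale. The base scale is 1 / sqrt(q_head_dim); with
    /// YaRN rope scaling it is additionally multiplied by mscale^2, where the
    /// reference YaRN mscale is 1.0 + 0.1 * mscale_all_dim * ln(factor).
    /// Worked example (assumed values): factor = 40 and mscale_all_dim = 1.0
    /// give mscale = 1 + 0.1 * ln(40) = 1.369, so the scale grows by
    /// mscale^2 = 1.874x.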
    fn softmax_scale(&self) -> f32 {
        let mut softmax_scale = 1.0 / (self.q_head_dim() as f32).sqrt();
        if let Some(DeepSeekV2RopeScaling::Yarn {
            mscale_all_dim,
            factor,
            ..
        }) = self.rope_scaling
        {
            let mscale = DeepSeekV2RotaryEmbedding::yarn_get_mscale(factor, mscale_all_dim);
            softmax_scale *= mscale * mscale;
        }
        softmax_scale
    }
}

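// Query projection: either a single dense projection, or a low-rank
// factorization (q_a_proj -> RmsNorm -> q_b_proj) when `q_lora_rank` is set,
// as in the larger DeepSeek-V2 checkpoints.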
enum QProj {
    Plain(Arc<dyn QuantMethod>),
    Lora {
        a: Arc<dyn QuantMethod>,
        norm: RmsNorm,
        b: Arc<dyn QuantMethod>,
    },
}

impl QProj {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        match self {
            Self::Lora { a, norm, b } => {
                b.forward_autocast(&norm.forward(&a.forward_autocast(xs)?)?)
            }
            Self::Plain(lin) => lin.forward_autocast(xs),
        }
    }
}

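// Multi-head Latent Attention (MLA). Keys and values are compressed through a
// low-rank bottleneck: `kv_a_proj_with_mqa` maps the hidden state down to
// kv_lora_rank + qk_rope_head_dim (the RoPE part is shared across heads,
// MQA-style), and `kv_b_proj` expands the normalized latent back into per-head
// keys and values.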
struct Attention {
    q: QProj,
    kv_a_proj_with_mqa: Arc<dyn QuantMethod>,
    kv_a_layernorm: RmsNorm,
    kv_b_proj: Arc<dyn QuantMethod>,
    o_proj: Arc<dyn QuantMethod>,
    rotary_emb: Arc<DeepSeekV2RotaryEmbedding>,
    cfg: DeepSeekV2Config,
    q_head_dim: usize,
    paged_attn: Option<PagedAttention>,
    sdpa_params: SdpaParams,
    num_attention_heads: usize,
}

impl Attention {
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<DeepSeekV2RotaryEmbedding>,
        cfg: &DeepSeekV2Config,
        vb: ShardedVarBuilder,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        paged_attn: Option<PagedAttention>,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let q_head_dim = cfg.q_head_dim();
        let q = match cfg.q_lora_rank {
            Some(lora_rank) => {
                let a = ReplicatedLayer::new(
                    cfg.hidden_size,
                    lora_rank,
                    &cfg.quantization_config,
                    cfg.attention_bias,
                    mapper.set_device(layer_idx, vb.pp("q_a_proj"), loading_isq),
                )?;
                let norm = RmsNorm::new(
                    lora_rank,
                    cfg.rms_norm_eps,
                    mapper.set_device(layer_idx, vb.pp("q_a_layernorm"), false),
                )?;
                let b = ColumnParallelLayer::new(
                    lora_rank,
                    cfg.num_attention_heads * q_head_dim,
                    &cfg.quantization_config,
                    false,
                    comm,
                    mapper.set_device(layer_idx, vb.pp("q_b_proj"), loading_isq),
                )?;
                QProj::Lora { a, norm, b }
            }
            None => QProj::Plain(ColumnParallelLayer::new(
                cfg.hidden_size,
                cfg.num_attention_heads * q_head_dim,
                &cfg.quantization_config,
                false,
                comm,
                mapper.set_device(layer_idx, vb.pp("q_proj"), loading_isq),
            )?),
        };

        let kv_a_proj_with_mqa = ReplicatedLayer::new(
            cfg.hidden_size,
            cfg.kv_lora_rank + cfg.qk_rope_head_dim,
            &cfg.quantization_config,
            cfg.attention_bias,
            mapper.set_device(layer_idx, vb.pp("kv_a_proj_with_mqa"), loading_isq),
        )?;
        let kv_a_layernorm = RmsNorm::new(
            cfg.kv_lora_rank,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("kv_a_layernorm"), false),
        )?;
        let kv_b_proj = ColumnParallelLayer::new(
            cfg.kv_lora_rank,
            cfg.num_attention_heads * (q_head_dim - cfg.qk_rope_head_dim + cfg.v_head_dim),
            &cfg.quantization_config,
            false,
            comm,
            mapper.set_device(layer_idx, vb.pp("kv_b_proj"), loading_isq),
        )?;

        let o_proj = RowParallelLayer::new(
            cfg.num_attention_heads * cfg.v_head_dim,
            cfg.hidden_size,
            &cfg.quantization_config,
            cfg.attention_bias,
            comm,
            mapper.set_device(layer_idx, vb.pp("o_proj"), loading_isq),
        )?;

        Ok(Self {
            q,
            kv_a_proj_with_mqa,
            kv_a_layernorm,
            kv_b_proj,
            o_proj,
            rotary_emb,
            cfg: cfg.clone(),
            q_head_dim,
            paged_attn,
            num_attention_heads: cfg.num_attention_heads / comm.world_size(),
            sdpa_params: SdpaParams {
                n_kv_groups: 1,
                use_flash_attn: cfg.use_flash_attn,
                softcap: None,
                softmax_scale: cfg.softmax_scale(),
                sliding_window: None,
            },
        })
    }

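    // Shape walkthrough (b = batch, s = seq_len, h = heads on this rank):
    //   q: (b, s, h * q_head_dim) -> (b, h, s, q_head_dim), split into q_nope
    //      (qk_nope_head_dim) and q_pe (qk_rope_head_dim).
    //   kv_a_proj_with_mqa(x): (b, s, kv_lora_rank + qk_rope_head_dim); the
    //      latent part is normalized and expanded by kv_b_proj into k_nope and
    //      v, while the rope part becomes a single k_pe head shared by all
    //      query heads.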
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (bs, seq_len, _) = xs.dims3()?;

        let mut q = self.q.forward(xs)?;
        q = q
            .reshape((bs, seq_len, self.num_attention_heads, self.q_head_dim))?
            .transpose(1, 2)?;
        let q_split = q.split(
            &[self.cfg.qk_nope_head_dim, self.cfg.qk_rope_head_dim],
            D::Minus1,
        )?;
        let q_nope = q_split[0].clone();
        let mut q_pe = q_split[1].clone();

        let mut compressed_kv = self.kv_a_proj_with_mqa.forward_autocast(xs)?;
        let ckv_split = compressed_kv.split(
            &[self.cfg.kv_lora_rank, self.cfg.qk_rope_head_dim],
            D::Minus1,
        )?;
        compressed_kv = ckv_split[0].clone();
        let mut k_pe = ckv_split[1].clone();
        k_pe = k_pe
            .reshape((bs, seq_len, 1, self.cfg.qk_rope_head_dim))?
            .transpose(1, 2)?;
        let mut kv = self
            .kv_b_proj
            .forward_autocast(&self.kv_a_layernorm.forward(&compressed_kv)?)?;
        kv = kv
            .reshape((
                bs,
                seq_len,
                self.num_attention_heads,
                self.cfg.qk_nope_head_dim + self.cfg.v_head_dim,
            ))?
            .transpose(1, 2)?;

        let kv_split = kv.split(&[self.cfg.qk_nope_head_dim, self.cfg.v_head_dim], D::Minus1)?;
        let k_nope = kv_split[0].clone();
        let mut v = kv_split[1].clone();

        (q_pe, k_pe) = self.rotary_emb.forward(&q_pe, &k_pe, seqlen_offsets)?;

        let q = Tensor::cat(&[&q_nope, &q_pe], D::Minus1)?.contiguous()?;
        let mut k = Tensor::cat(
            &[&k_nope, &k_pe.repeat((1, self.num_attention_heads, 1, 1))?],
            D::Minus1,
        )?
        .contiguous()?;

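        // The paged KV cache stores K and V at the same head dim, so V
        // (v_head_dim) is zero-padded up to q_head_dim before the kernel and
        // the output is narrowed back to v_head_dim afterwards.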
        let mut attn_out = match &self.paged_attn {
            Some(paged_attn) => match metadata {
                Some(((key_cache, value_cache), input_metadata)) => {
                    let v = v
                        .pad_with_zeros(D::Minus1, 0, self.q_head_dim - self.cfg.v_head_dim)?
                        .contiguous()?;
                    paged_attn
                        .forward(
                            &q,
                            &k,
                            &v,
                            attention_mask,
                            Some(key_cache),
                            Some(value_cache),
                            input_metadata,
                            &self.sdpa_params,
                            Some(flash_params),
                        )?
                        .narrow(D::Minus1, 0, self.cfg.v_head_dim)?
                }
                None => {
                    // No metadata means no real KV cache to populate (e.g. when
                    // generating an imatrix), so use dummy metadata instead.
                    let input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
                    // Sanity check.
                    assert!(attention_mask.is_some());
                    let v = v
                        .pad_with_zeros(D::Minus1, 0, self.q_head_dim - self.cfg.v_head_dim)?
                        .contiguous()?;
                    paged_attn
                        .forward(
                            &q,
                            &k,
                            &v,
                            attention_mask,
                            None,
                            None,
                            &input_metadata,
                            &self.sdpa_params,
                            Some(flash_params),
                        )?
                        .narrow(D::Minus1, 0, self.cfg.v_head_dim)?
                }
            },
            None => {
                (k, v) = kv_cache.append(&k, &v)?;

                Sdpa.run_attention(
                    &q,
                    &k,
                    &v,
                    attention_mask,
                    Some(flash_params),
                    &self.sdpa_params,
                )?
            }
        };

        attn_out = if attention_mask.is_some() {
            attn_out.transpose(1, 2)?.reshape((bs, seq_len, ()))?
        } else {
            attn_out.reshape((bs, seq_len, ()))?
        };

        self.o_proj.forward_autocast(&attn_out)
    }
}

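// A single routed expert: a SwiGLU-style MLP over `moe_intermediate_size`,
// with the activation and elementwise product fused via
// `candle_nn::ops::mul_and_act`.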
struct Expert {
    gate: Arc<dyn QuantMethod>,
    up: Arc<dyn QuantMethod>,
    down: Arc<dyn QuantMethod>,
    act: Activation,
}

impl Expert {
    fn new(
        cfg: &DeepSeekV2Config,
        vb: ShardedVarBuilder,
        hidden_size: Option<usize>,
        intermediate_size: Option<usize>,
    ) -> Result<Self> {
        let hidden_size = hidden_size.unwrap_or(cfg.hidden_size);
        let intermediate_size = intermediate_size.unwrap_or(cfg.intermediate_size);

        Ok(Self {
            gate: ReplicatedLayer::new(
                hidden_size,
                intermediate_size,
                &cfg.quantization_config,
                false,
                vb.pp("gate_proj"),
            )?,
            up: ReplicatedLayer::new(
                hidden_size,
                intermediate_size,
                &cfg.quantization_config,
                false,
                vb.pp("up_proj"),
            )?,
            down: ReplicatedLayer::new(
                intermediate_size,
                hidden_size,
                &cfg.quantization_config,
                false,
                vb.pp("down_proj"),
            )?,
            act: cfg.hidden_act,
        })
    }

    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.gate.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let lhs = self.gate.forward(&xs)?;
        let rhs = self.up.forward(&xs)?;
        let mut res = self.down.forward(&candle_nn::ops::mul_and_act(
            &lhs,
            &rhs,
            self.act.try_into()?,
        )?)?;
        if self.gate.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

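// MoE router ("gate"): scores all routed experts with a linear map + softmax,
// then picks `top_k` experts per token, either globally (greedy) or only
// within the best `topk_group` of `n_group` expert groups (group-limited
// greedy; DeepSeek-V2 uses this to limit how many devices each token touches).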
struct MoeGate {
    weight: Tensor,
    cfg: DeepSeekV2Config,
    top_k: usize,
    n_routed_experts: usize,
}

impl MoeGate {
    fn new(cfg: &DeepSeekV2Config, vb: ShardedVarBuilder, n_routed_experts: usize) -> Result<Self> {
        let weight = vb.get((n_routed_experts, cfg.hidden_size), "weight")?;
        Ok(Self {
            weight,
            cfg: cfg.clone(),
            top_k: cfg.num_experts_per_tok.unwrap(),
            n_routed_experts,
        })
    }

    fn forward(&self, xs: &Tensor) -> Result<(Tensor, Tensor)> {
        let (bs, seq_len, h) = xs.dims3()?;
        // Compute gating scores.
        let xs = xs.reshape(((), h))?;
        let logits = xs
            .to_dtype(DType::F32)?
            .broadcast_matmul(&self.weight.t()?.to_dtype(DType::F32)?)?;
        let scores = match self.cfg.scoring_func {
            ScoringFunc::Softmax => candle_nn::ops::softmax_last_dim(&logits)?,
        };

        // Select the top-k experts.
        let (mut topk_weight, topk_idx) = match self.cfg.topk_method {
            TopkMethod::Greedy => {
                let TopKOutput { values, indices } = scores.topk_unsorted(self.top_k)?;
                (values, indices)
            }
            TopkMethod::GroupLimitedGreedy => {
                // (n, n_group)
                let group_scores = scores
                    .reshape((bs * seq_len, self.cfg.n_group, ()))?
                    .max(D::Minus1)?;
                // (n, topk_group)
                let group_idx = group_scores.topk_unsorted(self.cfg.topk_group)?.indices;
                // (n, n_group)
                let mut group_mask = group_scores.zeros_like()?;
                group_mask = group_mask.scatter_add(
                    &group_idx,
                    &group_idx.ones_like()?.to_dtype(group_mask.dtype())?,
                    1,
                )?;
                // (n, e)
                let score_mask = group_mask
                    .unsqueeze(D::Minus1)?
                    .expand((
                        bs * seq_len,
                        self.cfg.n_group,
                        self.n_routed_experts / self.cfg.n_group,
                    ))?
                    .reshape((bs * seq_len, ()))?;
                // (n, e): zero out the scores of experts outside the selected
                // groups, then take the top-k of what remains.
                let tmp_scores = masked_fill(&scores, &(1. - &score_mask.ne(0.)?)?, 0.)?;
                let TopKOutput { values, indices } = tmp_scores.topk_unsorted(self.top_k)?;
                (values, indices)
            }
        };

        if self.top_k > 1 && self.cfg.norm_topk_prob {
            let denominator = (topk_weight.sum_keepdim(D::Minus1)? + 1e-20)?;
            topk_weight = (topk_weight / denominator)?;
        } else {
            topk_weight = (topk_weight * self.cfg.routed_scaling_factor)?;
        }
        Ok((topk_idx, topk_weight))
    }
}

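// MoE feed-forward block. Routed experts are partitioned contiguously across
// tensor-parallel ranks: each rank instantiates only the experts in
// [experts_start_idx, experts_end_idx) and the partial per-rank outputs are
// combined with a sum all-reduce.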
struct Moe {
    experts: Vec<Option<Expert>>,
    shared_experts: Option<Mlp>,
    gate: MoeGate,
    all_reduce: SumAllReduce,
    experts_start_idx: usize,
    experts_end_idx: usize,
    world_size: usize,
}

impl Moe {
    #[allow(clippy::too_many_arguments)]
    fn new(
        cfg: &DeepSeekV2Config,
        vb: ShardedVarBuilder,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        n_shared_experts: Option<usize>,
        n_routed_experts: usize,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let mut experts = Vec::with_capacity(n_routed_experts);
        let n_local_experts = n_routed_experts / comm.world_size();
        let experts_start_idx = comm.rank() * n_local_experts;
        let experts_end_idx = experts_start_idx + n_local_experts;
        for i in 0..n_routed_experts {
            if i >= experts_start_idx && i < experts_end_idx {
                let vb_e = vb.pp("experts").pp(i);
                experts.push(Some(Expert::new(
                    cfg,
                    mapper.set_device(layer_idx, vb_e, loading_isq),
                    None,
                    Some(cfg.moe_intermediate_size),
                )?));
            } else {
                experts.push(None);
            }
        }
        let shared_experts = if let Some(n_shared_experts) = n_shared_experts {
            let intermediate_size = cfg.moe_intermediate_size * n_shared_experts;
            Some(Mlp::new(
                mapper.set_device(layer_idx, vb.pp("shared_experts"), loading_isq),
                cfg.hidden_size,
                intermediate_size,
                &cfg.quantization_config,
                cfg.hidden_act,
                comm,
            )?)
        } else {
            None
        };
        let gate = MoeGate::new(
            cfg,
            mapper.set_device(layer_idx, vb.pp("gate"), false),
            n_routed_experts,
        )?;
        Ok(Self {
            experts,
            shared_experts,
            gate,
            all_reduce: SumAllReduce::new(comm),
            experts_end_idx,
            experts_start_idx,
            world_size: comm.world_size(),
        })
    }

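    // Expert dispatch: for each locally owned expert that any token selected,
    // `nonzero` recovers the (token, slot) coordinates from `topk_ids`, the
    // expert runs on just those tokens, the outputs are scaled by the matching
    // routing weights, and `index_add` scatters them back into the output.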
    fn moe_infer(&self, xs: &Tensor, topk_ids: &Tensor, topk_weight: &Tensor) -> Result<Tensor> {
        let mut y = xs.zeros_like()?;
        let topk_weight = if topk_weight.dtype() != xs.dtype() {
            topk_weight.to_dtype(xs.dtype())?
        } else {
            topk_weight.to_owned()
        };
        let unique_ids: HashSet<u32> =
            HashSet::from_iter(topk_ids.to_device(&Device::Cpu)?.flatten_all()?.to_vec1()?);
        for i in self.experts_start_idx..self.experts_end_idx {
            if !unique_ids.contains(&(i as u32)) {
                continue;
            }
            let idx_top = topk_ids.eq(i as f64)?.nonzero()?.t()?;
            let idx = &idx_top.i(0)?.contiguous()?;
            let top = &idx_top.i(1)?.contiguous()?;

            let expert = self.experts[i]
                .as_ref()
                .context("Expert is not present for this rank.")?;

            y = y.index_add(
                idx,
                &expert.forward(&xs.index_select(idx, 0)?)?.broadcast_mul(
                    &topk_weight
                        .index_select(idx, 0)?
                        .gather(&top.unsqueeze(1)?, 1)?
                        .squeeze(1)?
                        .unsqueeze(D::Minus1)?,
                )?,
                0,
            )?;
        }

        if self.world_size > 1 {
            y = self.all_reduce.sum_all_reduce(&y)?;
        }

        Ok(y)
    }

    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let identity = xs.clone();
        let orig_shape = xs.shape();
        let (topk_idx, topk_weight) = self.gate.forward(xs)?;
        let xs = xs.reshape(((), xs.dim(D::Minus1)?))?;

        let mut y = self
            .moe_infer(&xs, &topk_idx, &topk_weight)?
            .reshape(orig_shape)?;
        if let Some(ref shared_experts) = self.shared_experts {
            y = (y + shared_experts.forward(&identity)?)?;
        }
        Ok(y)
    }
}

enum MoeOrMlp {
    Moe(Moe),
    Mlp(Mlp),
}

impl MoeOrMlp {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        match self {
            Self::Mlp(mlp) => mlp.forward(xs),
            Self::Moe(moe) => moe.forward(xs),
        }
    }
}

struct DecoderLayer {
    input_layernorm: RmsNorm,
    post_attention_layernorm: RmsNorm,
    attn: Attention,
    moe_or_mlp: MoeOrMlp,
}

impl DecoderLayer {
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<DeepSeekV2RotaryEmbedding>,
        cfg: &DeepSeekV2Config,
        vb: ShardedVarBuilder,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        paged_attn: Option<PagedAttention>,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let attn = Attention::new(
            rotary_emb,
            cfg,
            vb.pp("self_attn"),
            mapper,
            layer_idx,
            loading_isq,
            paged_attn,
            comm,
        )?;
        let input_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let post_attention_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
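        // The first `first_k_dense_replace` layers keep a dense MLP; after
        // that, every `moe_layer_freq`-th layer is MoE. (For reference, the
        // DeepSeek-V2 configs ship first_k_dense_replace = 1 and
        // moe_layer_freq = 1, i.e. all layers but the first are MoE.)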
        let moe_or_mlp = if cfg.n_routed_experts.is_some()
            && layer_idx >= cfg.first_k_dense_replace
            && layer_idx % cfg.moe_layer_freq == 0
        {
            MoeOrMlp::Moe(Moe::new(
                cfg,
                vb.pp("mlp"),
                mapper,
                layer_idx,
                loading_isq,
                cfg.n_shared_experts,
                cfg.n_routed_experts.unwrap(),
                comm,
            )?)
        } else {
            MoeOrMlp::Mlp(Mlp::new(
                mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq),
                cfg.hidden_size,
                cfg.intermediate_size,
                &cfg.quantization_config,
                cfg.hidden_act,
                comm,
            )?)
        };

        Ok(Self {
            input_layernorm,
            post_attention_layernorm,
            attn,
            moe_or_mlp,
        })
    }

    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = self.input_layernorm.forward(xs)?;
        let xs = self.attn.forward(
            &xs,
            attention_mask,
            seqlen_offsets,
            kv_cache,
            metadata,
            flash_params,
        )?;
        let xs = (xs + residual)?;
        let residual = &xs;
        let xs = self
            .moe_or_mlp
            .forward(&xs.apply(&self.post_attention_layernorm)?)?;
        residual + xs
    }
}

pub struct DeepSeekV2 {
    lm_head: Arc<dyn QuantMethod>,
    embed_tokens: Embedding,
    norm: RmsNorm,
    layers: Vec<DecoderLayer>,
    cache: EitherCache,
    device: Device,
    max_seq_len: usize,
    cfg: ModelConfigMetadata,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
}

impl DeepSeekV2 {
    pub fn new(
        cfg: &DeepSeekV2Config,
        vb: ShardedVarBuilder,
        _is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        let vb_m = vb.pp("model");

        let mapper = normal_loading_metadata.mapper;

        let embed_tokens = embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
            &cfg.quantization_config,
        )?;
        let lm_head = if !cfg.tie_word_embeddings {
            ReplicatedLayer::new(
                cfg.hidden_size,
                cfg.vocab_size,
                &None,
                false,
                mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
            )?
        } else {
            ReplicatedLayer::from_linear(candle_nn::Linear::new(
                mapper.cast_nm_device(
                    embed_tokens.embeddings(),
                    normal_loading_metadata.loading_isq,
                )?,
                None,
            ))?
        };
        let norm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb_m.pp("norm"), false),
        )?;

        // Build one rotary embedding per device location so layers mapped to
        // different devices never share tensors across devices.
        let mut ropes = HashMap::new();
        let rope_cfg = DeepSeekV2RopeConfig {
            rope_scaling: cfg.rope_scaling.clone(),
            max_position_embeddings: cfg.max_position_embeddings,
            rope_theta: cfg.rope_theta,
            qk_rope_head_dim: cfg.qk_rope_head_dim,
        };
        for i in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(i, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(DeepSeekV2RotaryEmbedding::new(
                    &rope_cfg,
                    vb.dtype(),
                    device,
                )?),
            );
        }

        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        let vb_l = vb_m.pp("layers");
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        ) {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            let paged_attn = match &attention_mechanism {
                AttentionImplementation::Eager => None,
                AttentionImplementation::PagedAttention => Some(
                    PagedAttention::new(cfg.v_head_dim, device, None)
                        .expect("Failed to create PagedAttention"),
                ),
            };
            let comm = mapper.get_comm_for(layer_idx)?;
            let layer = DecoderLayer::new(
                rotary_emb.clone(),
                cfg,
                vb_l.pp(layer_idx),
                &*mapper,
                layer_idx,
                normal_loading_metadata.loading_isq,
                paged_attn,
                &comm,
            )?;
            layers.push(layer)
        }

        Ok(Self {
            lm_head,
            embed_tokens,
            norm,
            layers,
            cache: EitherCache::Normal(NormalCache::new(
                cfg.num_hidden_layers,
                cfg.max_position_embeddings,
            )),
            device: normal_loading_metadata.real_device.clone(),
            max_seq_len: cfg.max_position_embeddings,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_kv_heads: (cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size())
                    .max(1),
                num_attn_heads: (cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size())
                    .max(1),
                sliding_window: None,
                k_head_dim: cfg.q_head_dim(),
                // With paged attention, V is zero-padded to q_head_dim in the KV cache.
                v_head_dim: if matches!(
                    attention_mechanism,
                    AttentionImplementation::PagedAttention
                ) {
                    cfg.q_head_dim()
                } else {
                    cfg.v_head_dim
                },
            },
            mapper,
        })
    }

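    // Full forward pass: embed tokens, run each decoder layer on its mapped
    // device with its per-layer KV cache, then apply the final RmsNorm and LM
    // head. With paged attention, the causal mask is only materialized for the
    // first prompt chunk; later chunks and decode steps attend to the whole
    // cache without a mask.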
    #[allow(clippy::too_many_arguments)]
    pub fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut xs = self.embed_tokens.forward(input_ids)?;
        let cache = &mut self.cache.normal().0;
        let attention_mask = CausalMasker.make_causal_mask_matrix(
            input_ids,
            metadata
                .as_ref()
                .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
                .unwrap_or(cache as &dyn PastKvLenCache),
            xs.dtype(),
            self.cfg.num_attn_heads,
        )?;
        // PagedAttention prompt chunking: only the first prompt chunk needs a mask.
        let attention_mask = attention_mask.filter(|_| {
            metadata
                .as_ref()
                .map(|(_, meta)| meta.is_first_prompt_chunk)
                .unwrap_or(true)
        });
        for (i, layer) in self.layers.iter().enumerate() {
            xs = self.mapper.map(xs, i)?;
            xs = layer.forward(
                &xs,
                attention_mask
                    .as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                &mut cache[i],
                metadata
                    .as_ref()
                    .map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
                flash_params,
            )?;
        }
        let xs = xs.to_device(&self.device)?;
        let xs = xs.apply(&self.norm)?;
        extract_logits(&self.lm_head.forward_autocast(&xs)?, context_lens)
    }
}

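// ISQ (in-situ quantization) support: expose every quantizable linear,
// tagged with its layer index so quantization can be mapped per device.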
impl IsqModel for DeepSeekV2 {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((&mut self.lm_head, None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            match &mut layer.attn.q {
                QProj::Plain(q) => {
                    tensors.push((q, Some(i)));
                }
                QProj::Lora { a, norm: _, b } => {
                    tensors.push((a, Some(i)));
                    tensors.push((b, Some(i)));
                }
            }
            tensors.push((&mut layer.attn.kv_a_proj_with_mqa, Some(i)));
            tensors.push((&mut layer.attn.kv_b_proj, Some(i)));
            tensors.push((&mut layer.attn.o_proj, Some(i)));
            match &mut layer.moe_or_mlp {
                MoeOrMlp::Mlp(mlp) => {
                    tensors.push((&mut mlp.gate, Some(i)));
                    tensors.push((&mut mlp.up, Some(i)));
                    tensors.push((&mut mlp.down, Some(i)));
                }
                MoeOrMlp::Moe(moe) => {
                    for mlp in moe.experts.iter_mut().filter_map(|e| e.as_mut()) {
                        tensors.push((&mut mlp.gate, Some(i)));
                        tensors.push((&mut mlp.up, Some(i)));
                        tensors.push((&mut mlp.down, Some(i)));
                    }
                    if let Some(mlp) = &mut moe.shared_experts {
                        tensors.push((&mut mlp.gate, Some(i)));
                        tensors.push((&mut mlp.up, Some(i)));
                        tensors.push((&mut mlp.down, Some(i)));
                    }
                }
            }
        }
        (tensors, &*self.mapper)
    }

    fn get_layers_moe_experts_only(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((&mut self.lm_head, None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            match &mut layer.moe_or_mlp {
                MoeOrMlp::Mlp(mlp) => {
                    tensors.push((&mut mlp.gate, Some(i)));
                    tensors.push((&mut mlp.up, Some(i)));
                    tensors.push((&mut mlp.down, Some(i)));
                }
                MoeOrMlp::Moe(moe) => {
                    for mlp in moe.experts.iter_mut().filter_map(|e| e.as_mut()) {
                        tensors.push((&mut mlp.gate, Some(i)));
                        tensors.push((&mut mlp.up, Some(i)));
                        tensors.push((&mut mlp.down, Some(i)));
                    }
                    if let Some(mlp) = &mut moe.shared_experts {
                        tensors.push((&mut mlp.gate, Some(i)));
                        tensors.push((&mut mlp.up, Some(i)));
                        tensors.push((&mut mlp.down, Some(i)));
                    }
                }
            }
        }
        (tensors, &*self.mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        let uvb_m = uvb.pp("model");
        uvb_m.pp("embed_tokens").add(&self.embed_tokens);
        uvb_m.pp("norm").add(&self.norm);

        for (layer_idx, layer) in self.layers.iter().enumerate() {
            let uvb_l = uvb_m.pp("layers").pp(layer_idx);
            uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
            uvb_l
                .pp("post_attention_layernorm")
                .add(&layer.post_attention_layernorm);

            uvb_l
                .pp("self_attn")
                .pp("kv_a_layernorm")
                .add(&layer.attn.kv_a_layernorm);

            match &layer.moe_or_mlp {
                MoeOrMlp::Moe(moe) => {
                    uvb_l
                        .pp("mlp")
                        .pp("gate")
                        .add_tensor("weight", moe.gate.weight.clone());
                }
                MoeOrMlp::Mlp(_) => (),
            }

            match &layer.attn.q {
                QProj::Plain(_) => (),
                QProj::Lora { a: _, norm, b: _ } => {
                    uvb_l.pp("self_attn").pp("q_a_layernorm").add(norm);
                }
            }
        }

        uvb.to_safetensors()
    }

    fn residual_tensors_moe_experts_only(&self) -> Option<Vec<(String, Tensor)>> {
        let uvb = UnVarBuilder::new();

        let uvb_m = uvb.pp("model");
        uvb_m.pp("embed_tokens").add(&self.embed_tokens);
        uvb_m.pp("norm").add(&self.norm);

        for (layer_idx, layer) in self.layers.iter().enumerate() {
            let uvb_l = uvb_m.pp("layers").pp(layer_idx);
            uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
            uvb_l
                .pp("post_attention_layernorm")
                .add(&layer.post_attention_layernorm);

            uvb_l
                .pp("self_attn")
                .pp("kv_a_layernorm")
                .add(&layer.attn.kv_a_layernorm);

            match &layer.moe_or_mlp {
                MoeOrMlp::Moe(moe) => {
                    uvb_l
                        .pp("mlp")
                        .pp("gate")
                        .add_tensor("weight", moe.gate.weight.clone());
                }
                MoeOrMlp::Mlp(_) => (),
            }

            match &layer.attn.q {
                QProj::Plain(q) => {
                    uvb_l.pp("self_attn").pp("q_proj").add(q);
                }
                QProj::Lora { a, norm, b } => {
                    uvb_l.pp("self_attn").pp("q_a_proj").add(a);
                    uvb_l.pp("self_attn").pp("q_a_layernorm").add(norm);
                    uvb_l.pp("self_attn").pp("q_b_proj").add(b);
                }
            }
            uvb_l
                .pp("self_attn")
                .pp("kv_a_proj_with_mqa")
                .add(&layer.attn.kv_a_proj_with_mqa);
            uvb_l
                .pp("self_attn")
                .pp("kv_b_proj")
                .add(&layer.attn.kv_b_proj);
            uvb_l.pp("self_attn").pp("o_proj").add(&layer.attn.o_proj);
        }

        Some(uvb.to_safetensors())
    }
}

impl NormalModel for DeepSeekV2 {
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.forward(
            input_ids,
            seqlen_offsets,
            context_lens,
            metadata,
            flash_params,
        )
    }
    fn xlora_forward(
        &self,
        _input_ids: &Tensor,
        _input_ids_full: &Tensor,
        _seqlen_offsets: &[usize],
        _seqlen_offsets_full: &[usize],
        _no_kv_cache: bool,
        _non_granular_state: &Option<crate::xlora_models::NonGranularState>,
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _flash_params: &FlashParams,
        _flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        unimplemented!()
    }
    fn cache(&self) -> &EitherCache {
        &self.cache
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        &mut self.cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn is_xlora(&self) -> bool {
        false
    }
    fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
}

impl AnyMoeBaseModelMixin for DeepSeekV2 {}