1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use std::{collections::HashMap, sync::Arc};
4
5use candle_core::{Context, DType, Device, IndexOp, Result, Tensor, D};
6use candle_nn::{Embedding, Module};
7use mistralrs_quant::{
8 distributed::DistributedOperation, ColumnParallelLayer, QuantMethod, QuantizedConfig,
9 ReplicatedLayer, RowParallelLayer, ShardedVarBuilder, SumAllReduce,
10};
11use serde::Deserialize;
12
13use crate::{
14 amoe::AnyMoeBaseModelMixin,
15 attention::SdpaParams,
16 device_map::DeviceMapper,
17 layers::{
18 embedding, Activation, CausalMasker, DeepSeekV2RopeConfig, DeepSeekV2RopeScaling,
19 DeepSeekV2RotaryEmbedding, Mlp, RmsNorm, Sdpa,
20 },
21 layers_masker::{masked_fill, PastKvLenCache},
22 ops::{BincountOp, NonZeroOp, SplitOp, TopKLastDimOp, TopKOutput},
23 paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
24 pipeline::{
25 extract_logits,
26 text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
27 EitherCache, IsqModel, KvCache, NormalCache, NormalLoadingMetadata, NormalModel,
28 },
29 serde_default_fn,
30 utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
31};
32
// Serde default-value functions for config keys that older DeepSeek
// checkpoints' `config.json` files may omit.
serde_default_fn!(f64, routed_scaling_factor, 1.0);
serde_default_fn!(TopkMethod, topk_method, TopkMethod::Greedy);
serde_default_fn!(usize, moe_layer_freq, 1);
serde_default_fn!(usize, first_k_dense_replace, 0);
serde_default_fn!(ScoringFunc, scoring_func, ScoringFunc::Softmax);
serde_default_fn!(Activation, hidden_act, Activation::Silu);
serde_default_fn!(bool, tie_word_embeddings, false);
serde_default_fn!(bool, use_flash_attn_default, false);
41
/// Expert-selection strategy used by the MoE gate; the serde renames match
/// the `topk_method` string values found in HF DeepSeek `config.json` files.
#[derive(Deserialize, Clone, Debug)]
enum TopkMethod {
    /// No-auxiliary-loss top-k using a learned per-expert score-correction bias.
    #[serde(rename = "noaux_tc")]
    NoAuxTc,
    /// Plain per-token top-k over all routed experts.
    #[serde(rename = "greedy")]
    Greedy,
    /// Restrict selection to the best `topk_group` expert groups first.
    #[serde(rename = "group_limited_greedy")]
    GroupLimitedGreedy,
}
51
/// Function applied to the gate logits to produce routing scores.
#[derive(Deserialize, Clone, Debug)]
enum ScoringFunc {
    #[serde(rename = "softmax")]
    Softmax,
    /// Sigmoid scores are renormalized over the selected top-k in `MoeGate::forward`.
    #[serde(rename = "sigmoid")]
    Sigmoid,
}
59
/// Deserialized DeepSeek V3 `config.json`.
///
/// Field names follow the Hugging Face checkpoint format; the serde defaults
/// above cover keys that older configs omit.
#[derive(Deserialize, Clone, Debug)]
pub struct DeepSeekV3Config {
    pub(crate) vocab_size: usize,
    pub(crate) hidden_size: usize,
    /// Intermediate size of the dense (non-MoE) MLP layers.
    pub(crate) intermediate_size: usize,
    /// Per-expert intermediate size of the MoE layers.
    pub(crate) moe_intermediate_size: usize,
    pub(crate) num_hidden_layers: usize,
    pub(crate) num_attention_heads: usize,
    /// Number of always-active shared experts (`None` disables them).
    pub(crate) n_shared_experts: Option<usize>,
    /// Total number of routed experts (`None` means no MoE layers at all).
    pub(crate) n_routed_experts: Option<usize>,
    /// Multiplier applied to the routed-expert gate weights.
    #[serde(default = "routed_scaling_factor")]
    pub(crate) routed_scaling_factor: f64,
    #[serde(default = "topk_method")]
    topk_method: TopkMethod,
    /// Experts selected per token (the top-k of the gate).
    pub(crate) num_experts_per_tok: Option<usize>,
    /// A layer uses MoE only when `layer_idx % moe_layer_freq == 0`.
    #[serde(default = "moe_layer_freq")]
    pub(crate) moe_layer_freq: usize,
    /// The first `first_k_dense_replace` layers use a dense MLP instead of MoE.
    #[serde(default = "first_k_dense_replace")]
    pub(crate) first_k_dense_replace: usize,
    #[serde(default = "scoring_func")]
    scoring_func: ScoringFunc,
    #[serde(default = "hidden_act")]
    pub(crate) hidden_act: Activation,
    pub(crate) max_position_embeddings: usize,
    pub(crate) rms_norm_eps: f64,
    #[serde(default = "tie_word_embeddings")]
    pub(crate) tie_word_embeddings: bool,
    pub(crate) rope_theta: f32,
    pub(crate) rope_scaling: Option<DeepSeekV2RopeScaling>,
    pub(crate) attention_bias: bool,
    /// When set, the query projection is factored through this low rank (MLA).
    pub(crate) q_lora_rank: Option<usize>,
    /// RoPE (positional) portion of each attention head dimension.
    pub(crate) qk_rope_head_dim: usize,
    /// Rank of the compressed KV latent (MLA).
    pub(crate) kv_lora_rank: usize,
    pub(crate) v_head_dim: usize,
    /// Non-RoPE ("nope") portion of each attention head dimension.
    pub(crate) qk_nope_head_dim: usize,
    #[serde(default = "use_flash_attn_default")]
    pub(crate) use_flash_attn: bool,
    pub(crate) quantization_config: Option<QuantizedConfig>,
    /// Number of expert groups for group-limited / noaux routing.
    pub(crate) n_group: usize,
    /// Number of groups a token may route into.
    pub(crate) topk_group: usize,
}
101
102impl DeepSeekV3Config {
103 pub(crate) fn q_head_dim(&self) -> usize {
104 self.qk_rope_head_dim + self.qk_nope_head_dim
105 }
106
107 fn softmax_scale(&self) -> f32 {
108 let mut softmax_scale = 1.0 / (self.q_head_dim() as f32).sqrt();
109 if let Some(DeepSeekV2RopeScaling::Yarn {
110 mscale_all_dim,
111 factor,
112 ..
113 }) = self.rope_scaling
114 {
115 let mscale = DeepSeekV2RotaryEmbedding::yarn_get_mscale(factor, mscale_all_dim);
116 softmax_scale = softmax_scale * mscale * mscale;
117 }
118 softmax_scale
119 }
120}
121
/// Query projection: either a single linear layer or a low-rank
/// decomposition (`q_a_proj` -> layernorm -> `q_b_proj`) when
/// `q_lora_rank` is configured.
enum QProj {
    Plain(Arc<dyn QuantMethod>),
    Lora {
        a: Arc<dyn QuantMethod>,
        norm: RmsNorm,
        b: Arc<dyn QuantMethod>,
    },
}
130
131impl QProj {
132 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
133 match self {
134 Self::Lora { a, norm, b } => {
135 b.forward_autocast(&norm.forward(&a.forward_autocast(xs)?)?)
136 }
137 Self::Plain(lin) => lin.forward_autocast(xs),
138 }
139 }
140}
141
/// Multi-head latent attention (MLA): K/V are produced from a low-rank
/// compressed latent plus a single shared RoPE key, as in DeepSeek V2/V3.
struct Attention {
    q: QProj,
    /// Projects hidden states to the compressed KV latent + shared RoPE key.
    kv_a_proj_with_mqa: Arc<dyn QuantMethod>,
    kv_a_layernorm: RmsNorm,
    /// Expands the KV latent into per-head no-RoPE keys and values.
    kv_b_proj: Arc<dyn QuantMethod>,
    o_proj: Arc<dyn QuantMethod>,
    rotary_emb: Arc<DeepSeekV2RotaryEmbedding>,
    cfg: DeepSeekV3Config,
    /// qk_nope_head_dim + qk_rope_head_dim.
    q_head_dim: usize,
    paged_attn: Option<PagedAttention>,
    sdpa_params: SdpaParams,
    /// Heads owned by this tensor-parallel rank (total heads / world size).
    num_attention_heads: usize,
}
155
impl Attention {
    /// Construct one MLA layer. The query path is plain or low-rank depending
    /// on `q_lora_rank`; projections are column/row-parallel across `comm`
    /// for tensor parallelism, and `mapper` places weights per `layer_idx`.
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<DeepSeekV2RotaryEmbedding>,
        cfg: &DeepSeekV3Config,
        vb: ShardedVarBuilder,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        paged_attn: Option<PagedAttention>,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let q_head_dim = cfg.q_head_dim();
        let q = match cfg.q_lora_rank {
            Some(lora_rank) => {
                // Low-rank query path: hidden -> lora_rank -> RmsNorm -> heads.
                let a = ReplicatedLayer::new(
                    cfg.hidden_size,
                    lora_rank,
                    &cfg.quantization_config,
                    cfg.attention_bias,
                    mapper.set_device(layer_idx, vb.pp("q_a_proj"), loading_isq),
                )?;
                let norm = RmsNorm::new(
                    lora_rank,
                    cfg.rms_norm_eps,
                    mapper.set_device(layer_idx, vb.pp("q_a_layernorm"), false),
                )?;
                let b = ColumnParallelLayer::new(
                    lora_rank,
                    cfg.num_attention_heads * q_head_dim,
                    &cfg.quantization_config,
                    false,
                    comm,
                    mapper.set_device(layer_idx, vb.pp("q_b_proj"), loading_isq),
                )?;
                QProj::Lora { a, norm, b }
            }
            None => QProj::Plain(ColumnParallelLayer::new(
                cfg.hidden_size,
                cfg.num_attention_heads * q_head_dim,
                &cfg.quantization_config,
                false,
                comm,
                mapper.set_device(layer_idx, vb.pp("q_proj"), loading_isq),
            )?),
        };

        // Produces [compressed KV latent | shared single-head RoPE key].
        let kv_a_proj_with_mqa = ReplicatedLayer::new(
            cfg.hidden_size,
            cfg.kv_lora_rank + cfg.qk_rope_head_dim,
            &cfg.quantization_config,
            cfg.attention_bias,
            mapper.set_device(layer_idx, vb.pp("kv_a_proj_with_mqa"), loading_isq),
        )?;
        let kv_a_layernorm = RmsNorm::new(
            cfg.kv_lora_rank,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("kv_a_layernorm"), false),
        )?;
        // Expands the latent into per-head [no-RoPE K | V]; note
        // q_head_dim - qk_rope_head_dim == qk_nope_head_dim.
        let kv_b_proj = ColumnParallelLayer::new(
            cfg.kv_lora_rank,
            cfg.num_attention_heads * (q_head_dim - cfg.qk_rope_head_dim + cfg.v_head_dim),
            &cfg.quantization_config,
            false,
            comm,
            mapper.set_device(layer_idx, vb.pp("kv_b_proj"), loading_isq),
        )?;

        let o_proj = RowParallelLayer::new(
            cfg.num_attention_heads * cfg.v_head_dim,
            cfg.hidden_size,
            &cfg.quantization_config,
            cfg.attention_bias,
            comm,
            mapper.set_device(layer_idx, vb.pp("o_proj"), loading_isq),
        )?;

        Ok(Self {
            q,
            kv_a_proj_with_mqa,
            kv_a_layernorm,
            kv_b_proj,
            o_proj,
            rotary_emb,
            cfg: cfg.clone(),
            q_head_dim,
            paged_attn,
            // Heads are sharded across tensor-parallel ranks.
            num_attention_heads: cfg.num_attention_heads / comm.world_size(),
            sdpa_params: SdpaParams {
                n_kv_groups: 1,
                use_flash_attn: cfg.use_flash_attn,
                softcap: None,
                softmax_scale: cfg.softmax_scale(),
                sliding_window: None,
            },
        })
    }

    /// MLA forward pass over `xs` of shape (bs, seq_len, hidden).
    ///
    /// Q and K are split into no-RoPE and RoPE parts; rotary embeddings apply
    /// only to the RoPE parts, and the single shared RoPE key is broadcast to
    /// every head. For paged attention, V is zero-padded up to `q_head_dim`
    /// so K and V share a cache layout; the pad is narrowed away afterwards.
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (bs, seq_len, _) = xs.dims3()?;

        // Q: (bs, heads, seq, q_head_dim), split into [nope | rope] parts.
        let mut q = self.q.forward(xs)?;
        q = q
            .reshape((bs, seq_len, self.num_attention_heads, self.q_head_dim))?
            .transpose(1, 2)?;
        let q_split = q.split(
            &[self.cfg.qk_nope_head_dim, self.cfg.qk_rope_head_dim],
            D::Minus1,
        )?;
        let q_nope = q_split[0].clone();
        let mut q_pe = q_split[1].clone();

        // Compressed KV latent plus the shared single-head RoPE key.
        let mut compressed_kv = self.kv_a_proj_with_mqa.forward_autocast(xs)?;
        let ckv_split = compressed_kv.split(
            &[self.cfg.kv_lora_rank, self.cfg.qk_rope_head_dim],
            D::Minus1,
        )?;
        compressed_kv = ckv_split[0].clone();
        let mut k_pe = ckv_split[1].clone();
        k_pe = k_pe
            .reshape((bs, seq_len, 1, self.cfg.qk_rope_head_dim))?
            .transpose(1, 2)?;
        // Expand the normalized latent into per-head [no-RoPE K | V].
        let mut kv = self
            .kv_b_proj
            .forward_autocast(&self.kv_a_layernorm.forward(&compressed_kv)?)?;
        kv = kv
            .reshape((
                bs,
                seq_len,
                self.num_attention_heads,
                self.cfg.qk_nope_head_dim + self.cfg.v_head_dim,
            ))?
            .transpose(1, 2)?;

        let kv_split = kv.split(&[self.cfg.qk_nope_head_dim, self.cfg.v_head_dim], D::Minus1)?;
        let k_nope = kv_split[0].clone();
        let mut v = kv_split[1].clone();

        // Rotary embeddings apply only to the RoPE halves.
        (q_pe, k_pe) = self.rotary_emb.forward(&q_pe, &k_pe, seqlen_offsets)?;

        let q = Tensor::cat(&[&q_nope, &q_pe], D::Minus1)?.contiguous()?;
        // Broadcast the shared RoPE key across all local heads.
        let mut k = Tensor::cat(
            &[&k_nope, &k_pe.repeat((1, self.num_attention_heads, 1, 1))?],
            D::Minus1,
        )?
        .contiguous()?;

        let mut attn_out = match &self.paged_attn {
            Some(paged_attn) => match metadata {
                Some(((key_cache, value_cache), input_metadata)) => {
                    // Pad V to q_head_dim so K/V cache blocks share a layout;
                    // the pad is removed by the narrow() below.
                    let v = v
                        .pad_with_zeros(D::Minus1, 0, self.q_head_dim - self.cfg.v_head_dim)?
                        .contiguous()?;
                    paged_attn
                        .forward(
                            &q,
                            &k,
                            &v,
                            attention_mask,
                            Some(key_cache),
                            Some(value_cache),
                            input_metadata,
                            &self.sdpa_params,
                            Some(flash_params),
                        )?
                        .narrow(D::Minus1, 0, self.cfg.v_head_dim)?
                }
                None => {
                    // No per-request cache metadata: run with a dummy
                    // metadata object and no cache; this path requires a
                    // causal mask to be present.
                    let input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
                    assert!(attention_mask.is_some());
                    let v = v
                        .pad_with_zeros(D::Minus1, 0, self.q_head_dim - self.cfg.v_head_dim)?
                        .contiguous()?;
                    paged_attn
                        .forward(
                            &q,
                            &k,
                            &v,
                            attention_mask,
                            None,
                            None,
                            &input_metadata,
                            &self.sdpa_params,
                            Some(flash_params),
                        )?
                        .narrow(D::Minus1, 0, self.cfg.v_head_dim)?
                }
            },
            None => {
                // Eager path: grow the per-layer KV cache, then plain SDPA.
                (k, v) = kv_cache.append(&k, &v)?;

                Sdpa.run_attention(
                    &q,
                    &k,
                    &v,
                    attention_mask,
                    Some(flash_params),
                    &self.sdpa_params,
                )?
            }
        };

        // NOTE(review): the masked path transposes (bs, heads, seq, v) ->
        // (bs, seq, heads, v) before flattening; the unmasked path assumes
        // the backend already returned a token-major layout — confirm against
        // Sdpa/PagedAttention output shapes.
        attn_out = if attention_mask.is_some() {
            attn_out.transpose(1, 2)?.reshape((bs, seq_len, ()))?
        } else {
            attn_out.reshape((bs, seq_len, ()))?
        };

        self.o_proj.forward_autocast(&attn_out)
    }
}
378
/// One routed MoE expert: a gated MLP (gate/up/down) with replicated
/// (non-sharded) weights, since experts are partitioned across ranks instead.
struct Expert {
    gate: Arc<dyn QuantMethod>,
    up: Arc<dyn QuantMethod>,
    down: Arc<dyn QuantMethod>,
    act: Activation,
}
385
impl Expert {
    /// Build one expert. `hidden_size`/`intermediate_size` fall back to the
    /// dense-MLP config values when `None`; MoE layers pass
    /// `Some(cfg.moe_intermediate_size)`.
    fn new(
        cfg: &DeepSeekV3Config,
        vb: ShardedVarBuilder,
        hidden_size: Option<usize>,
        intermediate_size: Option<usize>,
    ) -> Result<Self> {
        let hidden_size = hidden_size.unwrap_or(cfg.hidden_size);
        let intermediate_size = intermediate_size.unwrap_or(cfg.intermediate_size);

        Ok(Self {
            gate: ReplicatedLayer::new(
                hidden_size,
                intermediate_size,
                &cfg.quantization_config,
                false,
                vb.pp("gate_proj"),
            )?,
            up: ReplicatedLayer::new(
                hidden_size,
                intermediate_size,
                &cfg.quantization_config,
                false,
                vb.pp("up_proj"),
            )?,
            down: ReplicatedLayer::new(
                intermediate_size,
                hidden_size,
                &cfg.quantization_config,
                false,
                vb.pp("down_proj"),
            )?,
            act: cfg.hidden_act,
        })
    }

    /// Gated-MLP forward: down(mul_and_act(gate(x), up(x))), round-tripping
    /// through the quantized layers' activation dtype when one is required.
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.gate.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let lhs = self.gate.forward(&xs)?;
        let rhs = self.up.forward(&xs)?;
        // mul_and_act fuses the activation with the elementwise product.
        let mut res = self.down.forward(&candle_nn::ops::mul_and_act(
            &lhs,
            &rhs,
            self.act.try_into()?,
        )?)?;
        if self.gate.quantized_act_type().is_some() {
            // Restore the caller's dtype after quantized computation.
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}
441
/// MoE router: a linear gate over the hidden state plus top-k selection.
struct MoeGate {
    /// Gate weight of shape (n_routed_experts, hidden_size).
    weight: Tensor,
    cfg: DeepSeekV3Config,
    /// Experts selected per token (`num_experts_per_tok`).
    top_k: usize,
    n_routed_experts: usize,
    /// Learned per-expert selection bias; present only for `noaux_tc`.
    e_score_correction_bias: Option<Tensor>,
}
449
450impl MoeGate {
451 fn new(cfg: &DeepSeekV3Config, vb: ShardedVarBuilder, n_routed_experts: usize) -> Result<Self> {
452 let weight = vb.get((n_routed_experts, cfg.hidden_size), "weight")?;
453 let e_score_correction_bias = if matches!(cfg.topk_method, TopkMethod::NoAuxTc) {
454 Some(vb.get_with_hints_dtype(
455 n_routed_experts,
456 "e_score_correction_bias",
457 Default::default(),
458 DType::F32,
459 )?)
460 } else {
461 None
462 };
463 Ok(Self {
464 weight,
465 cfg: cfg.clone(),
466 top_k: cfg.num_experts_per_tok.unwrap(),
467 n_routed_experts,
468 e_score_correction_bias,
469 })
470 }
471
472 fn forward(&self, xs: &Tensor) -> Result<(Tensor, Tensor)> {
474 let (bs, seq_len, h) = xs.dims3()?;
475 let xs = xs.reshape(((), h))?;
477 let logits = xs
478 .to_dtype(DType::F32)?
479 .broadcast_matmul(&self.weight.t()?.to_dtype(DType::F32)?)?;
480 let scores = match self.cfg.scoring_func {
481 ScoringFunc::Softmax => candle_nn::ops::softmax_last_dim(&logits)?,
482 ScoringFunc::Sigmoid => candle_nn::ops::sigmoid(&logits)?,
483 };
484
485 let (mut topk_weight, topk_idx) = match self.cfg.topk_method {
487 TopkMethod::Greedy => {
488 let TopKOutput { values, indices } = scores.topk_unsorted(self.top_k)?;
489 (values, indices)
490 }
491 TopkMethod::NoAuxTc => {
492 let Some(e_score_correction_bias) = &self.e_score_correction_bias else {
493 candle_core::bail!("Expected e_score_correction_bias")
494 };
495 let scores_for_choice = scores
496 .reshape((bs * seq_len, ()))?
497 .broadcast_add(&e_score_correction_bias.unsqueeze(0)?)?;
498 let group_scores = scores_for_choice
500 .reshape((bs * seq_len, self.cfg.n_group, ()))?
501 .topk(2)?
502 .values
503 .sum(D::Minus1)?;
504 let group_idx = group_scores.topk(self.cfg.topk_group)?.indices;
506 let mut group_mask = group_scores.zeros_like()?;
508 group_mask = group_mask.scatter_add(
510 &group_idx,
511 &group_idx.ones_like()?.to_dtype(group_mask.dtype())?,
512 1,
513 )?;
514 let score_mask = group_mask
516 .unsqueeze(D::Minus1)?
517 .expand((
518 bs * seq_len,
519 self.cfg.n_group,
520 self.n_routed_experts / self.cfg.n_group,
521 ))?
522 .reshape((bs * seq_len, ()))?;
523 let tmp_scores = scores_for_choice.broadcast_mul(&score_mask)?;
526 let topk_idx = tmp_scores.topk(self.top_k)?.indices;
527 (scores.gather(&topk_idx, 1)?, topk_idx)
528 }
529 TopkMethod::GroupLimitedGreedy => {
530 let group_scores = scores
532 .reshape((bs * seq_len, self.cfg.n_group, ()))?
533 .max(D::Minus1)?;
534 let group_idx = scores.topk_unsorted(self.cfg.topk_group)?.indices;
536 let mut group_mask = group_scores.zeros_like()?;
538 group_mask = group_mask.scatter_add(
540 &group_idx,
541 &group_idx.ones_like()?.to_dtype(group_mask.dtype())?,
542 1,
543 )?;
544 let score_mask = group_mask
546 .unsqueeze(D::Minus1)?
547 .expand((
548 bs * seq_len,
549 self.cfg.n_group,
550 self.n_routed_experts / self.cfg.n_group,
551 ))?
552 .reshape((bs, seq_len, ()))?;
553 let tmp_scores = masked_fill(&score_mask, &(1. - &score_mask.ne(0.)?)?, 0.)?;
556 let TopKOutput { values, indices } = tmp_scores.topk_unsorted(self.top_k)?;
557 (values, indices)
558 }
559 };
560
561 if matches!(self.cfg.scoring_func, ScoringFunc::Sigmoid) {
562 let denmoninator = (topk_weight.sum_keepdim(D::Minus1)? + 1e-20)?;
563 topk_weight = topk_weight.broadcast_div(&denmoninator)?;
564 }
565
566 topk_weight = (topk_weight * self.cfg.routed_scaling_factor)?;
568
569 Ok((topk_idx, topk_weight))
570 }
571}
572
/// MoE feed-forward block with expert parallelism: this rank owns experts
/// `[experts_start_idx, experts_end_idx)`; other entries are `None` and the
/// partial outputs are summed across ranks with an all-reduce.
struct Moe {
    experts: Vec<Option<Expert>>,
    /// Always-active shared experts, fused into a single MLP.
    shared_experts: Option<Mlp>,
    gate: MoeGate,
    all_reduce: SumAllReduce,
    experts_start_idx: usize,
    experts_end_idx: usize,
    world_size: usize,
}
582
impl Moe {
    /// Build the MoE block: this rank's slice of routed experts, the optional
    /// fused shared experts, and the gate (gate weights are never ISQ'd).
    #[allow(clippy::too_many_arguments)]
    fn new(
        cfg: &DeepSeekV3Config,
        vb: ShardedVarBuilder,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        n_shared_experts: Option<usize>,
        n_routed_experts: usize,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let mut experts = Vec::with_capacity(n_routed_experts);
        // Contiguous expert partitioning across ranks.
        let n_local_experts = n_routed_experts / comm.world_size();
        let experts_start_idx = comm.rank() * n_local_experts;
        let experts_end_idx = experts_start_idx + n_local_experts;
        for i in 0..n_routed_experts {
            if i >= experts_start_idx && i < experts_end_idx {
                let vb_e = vb.pp("experts").pp(i);
                experts.push(Some(Expert::new(
                    cfg,
                    mapper.set_device(layer_idx, vb_e, loading_isq),
                    None,
                    Some(cfg.moe_intermediate_size),
                )?));
            } else {
                // Owned by another rank; placeholder keeps indices aligned.
                experts.push(None);
            }
        }
        let shared_experts = if let Some(n_shared_experts) = n_shared_experts {
            // All shared experts are fused into one wider MLP.
            let intermediate_size = cfg.moe_intermediate_size * n_shared_experts;
            Some(Mlp::new(
                mapper.set_device(layer_idx, vb.pp("shared_experts"), loading_isq),
                cfg.hidden_size,
                intermediate_size,
                &cfg.quantization_config,
                cfg.hidden_act,
                comm,
            )?)
        } else {
            None
        };
        let gate = MoeGate::new(
            cfg,
            mapper.set_device(layer_idx, vb.pp("gate"), false),
            n_routed_experts,
        )?;
        Ok(Self {
            experts,
            shared_experts,
            gate,
            all_reduce: SumAllReduce::new(comm),
            experts_end_idx,
            experts_start_idx,
            world_size: comm.world_size(),
        })
    }

    /// Dispatch flattened tokens (`xs`: (tokens, hidden)) to this rank's
    /// experts and accumulate their weighted outputs.
    fn moe_infer(&self, xs: &Tensor, topk_ids: &Tensor, topk_weight: &Tensor) -> Result<Tensor> {
        let mut y = xs.zeros_like()?;
        // Per-expert assignment counts, so idle experts can be skipped.
        let counts = topk_ids
            .flatten_all()?
            .bincount(self.experts.len() as u32)?;
        // Only iterate over the experts this rank actually owns.
        for (i, count) in counts
            .iter()
            .enumerate()
            .take(self.experts_end_idx)
            .skip(self.experts_start_idx)
        {
            if *count == 0 {
                continue;
            }
            // Positions where expert i was selected:
            // idx = token row in `xs`, top = which top-k slot chose it.
            let idx_top = topk_ids.eq(i as f64)?.nonzero()?.t()?;
            let idx = &idx_top.i(0)?.contiguous()?;
            let top = &idx_top.i(1)?.contiguous()?;

            let expert = self.experts[i]
                .as_ref()
                .context("Expert is not present for this rank.")?;

            // y[idx] += weight * expert(xs[idx])
            y = y.index_add(
                idx,
                &expert.forward(&xs.index_select(idx, 0)?)?.broadcast_mul(
                    &topk_weight
                        .index_select(idx, 0)?
                        .gather(&top.unsqueeze(1)?, 1)?
                        .squeeze(1)?
                        .unsqueeze(D::Minus1)?
                        .to_dtype(xs.dtype())?,
                )?,
                0,
            )?;
        }

        // Each rank holds a partial sum over its local experts.
        if self.world_size > 1 {
            y = self.all_reduce.sum_all_reduce(&y)?;
        }

        Ok(y)
    }

    /// Full MoE forward: gate -> routed experts -> add shared experts.
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let identity = xs.clone();
        let orig_shape = xs.shape();
        let (topk_idx, topk_weight) = self.gate.forward(xs)?;
        // Flatten to (tokens, hidden) for expert dispatch.
        let xs = xs.reshape(((), xs.dim(D::Minus1)?))?;

        let mut y = self
            .moe_infer(&xs, &topk_idx, &topk_weight)?
            .reshape(orig_shape)?;
        if let Some(ref shared_experts) = self.shared_experts {
            y = (y + shared_experts.forward(&identity)?)?;
        }
        Ok(y)
    }
}
699
/// Feed-forward variant of a decoder layer: MoE past `first_k_dense_replace`
/// (at `moe_layer_freq` intervals), dense MLP otherwise.
enum MoeOrMlp {
    Moe(Moe),
    Mlp(Mlp),
}
704
705impl MoeOrMlp {
706 fn forward(&self, xs: &Tensor) -> Result<Tensor> {
707 match self {
708 Self::Mlp(mlp) => mlp.forward(xs),
709 Self::Moe(moe) => moe.forward(xs),
710 }
711 }
712}
713
/// One transformer block: pre-norm MLA attention followed by a pre-norm
/// feed-forward (dense MLP or MoE), each with a residual connection.
struct DecoderLayer {
    input_layernorm: RmsNorm,
    post_attention_layernorm: RmsNorm,
    attn: Attention,
    moe_or_mlp: MoeOrMlp,
}
720
impl DecoderLayer {
    /// Build one decoder block for `layer_idx`, choosing MoE vs dense MLP
    /// from the config's MoE schedule.
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<DeepSeekV2RotaryEmbedding>,
        cfg: &DeepSeekV3Config,
        vb: ShardedVarBuilder,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        paged_attn: Option<PagedAttention>,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let attn = Attention::new(
            rotary_emb,
            cfg,
            vb.pp("self_attn"),
            mapper,
            layer_idx,
            loading_isq,
            paged_attn,
            comm,
        )?;
        let input_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let post_attention_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        // MoE only when routed experts exist, past the initial dense layers,
        // and at the configured layer frequency.
        let moe_or_mlp = if cfg.n_routed_experts.is_some()
            && layer_idx >= cfg.first_k_dense_replace
            && layer_idx % cfg.moe_layer_freq == 0
        {
            MoeOrMlp::Moe(Moe::new(
                cfg,
                vb.pp("mlp"),
                mapper,
                layer_idx,
                loading_isq,
                cfg.n_shared_experts,
                cfg.n_routed_experts.unwrap(),
                comm,
            )?)
        } else {
            MoeOrMlp::Mlp(Mlp::new(
                mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq),
                cfg.hidden_size,
                cfg.intermediate_size,
                &cfg.quantization_config,
                cfg.hidden_act,
                comm,
            )?)
        };

        Ok(Self {
            input_layernorm,
            post_attention_layernorm,
            attn,
            moe_or_mlp,
        })
    }

    /// Pre-norm residual block: x + attn(norm(x)), then x + ffn(norm(x)).
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = self.input_layernorm.forward(xs)?;
        let xs = self.attn.forward(
            &xs,
            attention_mask,
            seqlen_offsets,
            kv_cache,
            metadata,
            flash_params,
        )?;
        let xs = (xs + residual)?;
        let residual = &xs;
        let xs = self
            .moe_or_mlp
            .forward(&xs.apply(&self.post_attention_layernorm)?)?;
        residual + xs
    }
}
813
/// DeepSeek V3 causal language model: token embedding, decoder stack,
/// final RmsNorm, and LM head (optionally tied to the embeddings).
pub struct DeepSeekV3 {
    lm_head: Arc<dyn QuantMethod>,
    embed_tokens: Embedding,
    norm: RmsNorm,
    layers: Vec<DecoderLayer>,
    cache: EitherCache,
    /// Device logits are gathered on (the "real" device from loading metadata).
    device: Device,
    max_seq_len: usize,
    cfg: ModelConfigMetadata,
    /// Maps per-layer weights/activations across devices.
    mapper: Box<dyn DeviceMapper + Send + Sync>,
}
825
impl DeepSeekV3 {
    /// Load the full model from `vb`, placing each layer on the device chosen
    /// by `normal_loading_metadata.mapper`.
    pub fn new(
        cfg: &DeepSeekV3Config,
        vb: ShardedVarBuilder,
        _is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        let vb_m = vb.pp("model");

        let mapper = normal_loading_metadata.mapper;

        let embed_tokens = embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
        )?;
        // lm_head either has its own weights or reuses the embedding matrix
        // when `tie_word_embeddings` is set.
        let lm_head = if !cfg.tie_word_embeddings {
            ReplicatedLayer::new(
                cfg.hidden_size,
                cfg.vocab_size,
                &None,
                false,
                mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
            )?
        } else {
            ReplicatedLayer::from_linear(candle_nn::Linear::new(
                mapper.cast_nm_device(
                    embed_tokens.embeddings(),
                    normal_loading_metadata.loading_isq,
                )?,
                None,
            ))?
        };
        let norm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb_m.pp("norm"), false),
        )?;

        // One rotary-embedding instance per distinct device location, shared
        // by all layers mapped to that device.
        let mut ropes = HashMap::new();
        let rope_cfg = DeepSeekV2RopeConfig {
            rope_scaling: cfg.rope_scaling.clone(),
            max_position_embeddings: cfg.max_position_embeddings,
            rope_theta: cfg.rope_theta,
            qk_rope_head_dim: cfg.qk_rope_head_dim,
        };
        for i in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(i, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(DeepSeekV2RotaryEmbedding::new(
                    &rope_cfg,
                    vb.dtype(),
                    device,
                )?),
            );
        }

        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        let vb_l = vb_m.pp("layers");
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        ) {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            let paged_attn = match &attention_mechanism {
                AttentionImplementation::Eager => None,
                AttentionImplementation::PagedAttention => Some(
                    PagedAttention::new(cfg.v_head_dim, device, None)
                        .expect("Failed to create PagedAttention"),
                ),
            };
            let comm = mapper.get_comm_for(layer_idx)?;
            let layer = DecoderLayer::new(
                rotary_emb.clone(),
                cfg,
                vb_l.pp(layer_idx),
                &*mapper,
                layer_idx,
                normal_loading_metadata.loading_isq,
                paged_attn,
                &comm,
            )?;
            layers.push(layer)
        }

        Ok(Self {
            lm_head,
            embed_tokens,
            norm,
            layers,
            cache: EitherCache::Normal(NormalCache::new(
                cfg.num_hidden_layers,
                cfg.max_position_embeddings,
            )),
            device: normal_loading_metadata.real_device.clone(),
            max_seq_len: cfg.max_position_embeddings,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                // Heads are sharded across tensor-parallel ranks (>= 1).
                num_kv_heads: (cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size())
                    .max(1),
                num_attn_heads: (cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size())
                    .max(1),
                sliding_window: None,
                k_head_dim: cfg.q_head_dim(),
                // Paged attention stores V zero-padded to q_head_dim
                // (see Attention::forward), so the cache must be sized for it.
                v_head_dim: if matches!(
                    attention_mechanism,
                    AttentionImplementation::PagedAttention
                ) {
                    cfg.q_head_dim()
                } else {
                    cfg.v_head_dim
                },
            },
            mapper,
        })
    }

    /// Run the decoder stack over `input_ids` and return logits for the
    /// positions selected by `context_lens`.
    #[allow(clippy::too_many_arguments)]
    pub fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut xs = self.embed_tokens.forward(input_ids)?;
        let cache = &mut self.cache.normal().0;
        // Build the causal mask sized from the paged-attention offsets when
        // metadata is present, otherwise from the normal KV cache length.
        let attention_mask = CausalMasker.make_causal_mask_matrix(
            input_ids,
            metadata
                .as_ref()
                .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
                .unwrap_or(cache as &dyn PastKvLenCache),
            xs.dtype(),
            self.cfg.num_attn_heads,
        )?;
        // With paged attention, only the first prompt chunk needs a mask.
        let attention_mask = attention_mask.filter(|_| {
            metadata
                .as_ref()
                .map(|(_, meta)| meta.is_first_prompt_chunk)
                .unwrap_or(true)
        });
        for (i, layer) in self.layers.iter().enumerate() {
            // Move activations to layer i's mapped device.
            xs = self.mapper.map(xs, i)?;
            xs = layer.forward(
                &xs,
                attention_mask
                    .as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                &mut cache[i],
                metadata
                    .as_ref()
                    .map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
                flash_params,
            )?;
        }
        // Gather on the primary device for the final norm and LM head.
        let xs = xs.to_device(&self.device)?;
        let xs = xs.apply(&self.norm)?;
        extract_logits(&self.lm_head.forward_autocast(&xs)?, context_lens)
    }
}
1004
impl IsqModel for DeepSeekV3 {
    /// All quantizable linear layers, paired with their layer index (None for
    /// the lm_head), plus the device mapper for placement during ISQ.
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((&mut self.lm_head, None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            match &mut layer.attn.q {
                QProj::Plain(q) => {
                    tensors.push((q, Some(i)));
                }
                QProj::Lora { a, norm: _, b } => {
                    tensors.push((a, Some(i)));
                    tensors.push((b, Some(i)));
                }
            }
            tensors.push((&mut layer.attn.kv_a_proj_with_mqa, Some(i)));
            tensors.push((&mut layer.attn.kv_b_proj, Some(i)));
            tensors.push((&mut layer.attn.o_proj, Some(i)));
            match &mut layer.moe_or_mlp {
                MoeOrMlp::Mlp(mlp) => {
                    tensors.push((&mut mlp.gate, Some(i)));
                    tensors.push((&mut mlp.up, Some(i)));
                    tensors.push((&mut mlp.down, Some(i)));
                }
                MoeOrMlp::Moe(moe) => {
                    // Only experts owned by this rank are present.
                    for mlp in moe.experts.iter_mut().filter_map(|e| e.as_mut()) {
                        tensors.push((&mut mlp.gate, Some(i)));
                        tensors.push((&mut mlp.up, Some(i)));
                        tensors.push((&mut mlp.down, Some(i)));
                    }
                    if let Some(mlp) = &mut moe.shared_experts {
                        tensors.push((&mut mlp.gate, Some(i)));
                        tensors.push((&mut mlp.up, Some(i)));
                        tensors.push((&mut mlp.down, Some(i)));
                    }
                }
            }
        }
        (tensors, &*self.mapper)
    }

    /// As `get_layers`, but restricted to the MoE/MLP projections (attention
    /// projections are excluded) plus the lm_head.
    fn get_layers_moe_experts_only(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((&mut self.lm_head, None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            match &mut layer.moe_or_mlp {
                MoeOrMlp::Mlp(mlp) => {
                    tensors.push((&mut mlp.gate, Some(i)));
                    tensors.push((&mut mlp.up, Some(i)));
                    tensors.push((&mut mlp.down, Some(i)));
                }
                MoeOrMlp::Moe(moe) => {
                    for mlp in moe.experts.iter_mut().filter_map(|e| e.as_mut()) {
                        tensors.push((&mut mlp.gate, Some(i)));
                        tensors.push((&mut mlp.up, Some(i)));
                        tensors.push((&mut mlp.down, Some(i)));
                    }
                    if let Some(mlp) = &mut moe.shared_experts {
                        tensors.push((&mut mlp.gate, Some(i)));
                        tensors.push((&mut mlp.up, Some(i)));
                        tensors.push((&mut mlp.down, Some(i)));
                    }
                }
            }
        }
        (tensors, &*self.mapper)
    }

    /// Non-quantized tensors (embeddings, norms, gate weights) that must be
    /// written back verbatim when serializing an ISQ model.
    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        let uvb_m = uvb.pp("model");
        uvb_m.pp("embed_tokens").add(&self.embed_tokens);
        uvb_m.pp("norm").add(&self.norm);

        for (layer_idx, layer) in self.layers.iter().enumerate() {
            let uvb_l = uvb_m.pp("layers").pp(layer_idx);
            uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
            uvb_l
                .pp("post_attention_layernorm")
                .add(&layer.post_attention_layernorm);

            uvb_l
                .pp("self_attn")
                .pp("kv_a_layernorm")
                .add(&layer.attn.kv_a_layernorm);

            match &layer.moe_or_mlp {
                MoeOrMlp::Moe(moe) => {
                    uvb_l
                        .pp("mlp")
                        .pp("gate")
                        .add_tensor("weight", moe.gate.weight.clone());
                }
                MoeOrMlp::Mlp(_) => (),
            }

            match &layer.attn.q {
                QProj::Plain(_) => (),
                QProj::Lora { a: _, norm, b: _ } => {
                    uvb_l.pp("self_attn").pp("q_a_layernorm").add(norm);
                }
            }
        }

        uvb.to_safetensors()
    }

    /// Residual tensors for experts-only ISQ: additionally includes the
    /// (unquantized) attention projections, since only MoE weights are ISQ'd.
    fn residual_tensors_moe_experts_only(&self) -> Option<Vec<(String, Tensor)>> {
        let uvb = UnVarBuilder::new();

        let uvb_m = uvb.pp("model");
        uvb_m.pp("embed_tokens").add(&self.embed_tokens);
        uvb_m.pp("norm").add(&self.norm);

        for (layer_idx, layer) in self.layers.iter().enumerate() {
            let uvb_l = uvb_m.pp("layers").pp(layer_idx);
            uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
            uvb_l
                .pp("post_attention_layernorm")
                .add(&layer.post_attention_layernorm);

            uvb_l
                .pp("self_attn")
                .pp("kv_a_layernorm")
                .add(&layer.attn.kv_a_layernorm);

            match &layer.moe_or_mlp {
                MoeOrMlp::Moe(moe) => {
                    uvb_l
                        .pp("mlp")
                        .pp("gate")
                        .add_tensor("weight", moe.gate.weight.clone());
                }
                MoeOrMlp::Mlp(_) => (),
            }

            match &layer.attn.q {
                QProj::Plain(q) => {
                    uvb_l.pp("self_attn").pp("q_proj").add(q);
                }
                QProj::Lora { a, norm, b } => {
                    uvb_l.pp("self_attn").pp("q_a_proj").add(a);
                    uvb_l.pp("self_attn").pp("q_a_layernorm").add(norm);
                    uvb_l.pp("self_attn").pp("q_b_proj").add(b);
                }
            }
            uvb_l
                .pp("self_attn")
                .pp("kv_a_proj_with_mqa")
                .add(&layer.attn.kv_a_proj_with_mqa);
            uvb_l
                .pp("self_attn")
                .pp("kv_b_proj")
                .add(&layer.attn.kv_b_proj);
            uvb_l.pp("self_attn").pp("o_proj").add(&layer.attn.o_proj);
        }

        Some(uvb.to_safetensors())
    }
}
1175
impl NormalModel for DeepSeekV3 {
    /// Delegates to the inherent `DeepSeekV3::forward`; `_position_ids` is
    /// unused because positions come from `seqlen_offsets`.
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.forward(
            input_ids,
            seqlen_offsets,
            context_lens,
            metadata,
            flash_params,
        )
    }
    /// X-LoRA is not supported for this model.
    fn xlora_forward(
        &self,
        _input_ids: &Tensor,
        _input_ids_full: &Tensor,
        _seqlen_offsets: &[usize],
        _seqlen_offsets_full: &[usize],
        _no_kv_cache: bool,
        _non_granular_state: &Option<crate::xlora_models::NonGranularState>,
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _flash_params: &FlashParams,
        _flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        unimplemented!()
    }
    fn cache(&self) -> &EitherCache {
        &self.cache
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        &mut self.cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn is_xlora(&self) -> bool {
        false
    }
    fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
}
1228
// No AnyMoE-specific behavior; the trait's default implementations apply.
impl AnyMoeBaseModelMixin for DeepSeekV3 {}