#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::{
    collections::{HashMap, HashSet},
    iter::FromIterator,
    sync::Arc,
};

use candle_core::{Context, DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{Embedding, Module};
use mistralrs_quant::{
    ColumnParallelLayer, NonZeroOp, QuantMethod, QuantizedConfig, ReplicatedLayer,
    RowParallelLayer, ShardedVarBuilder, SumAllReduce,
};
use serde::Deserialize;

use crate::{
    amoe::AnyMoeBaseModelMixin,
    attention::SdpaParams,
    device_map::DeviceMapper,
    layers::{
        embedding, Activation, CausalMasker, DeepSeekV2RopeConfig, DeepSeekV2RopeScaling,
        DeepSeekV2RotaryEmbedding, Mlp, RmsNorm, Sdpa,
    },
    layers_masker::{masked_fill, PastKvLenCache},
    ops::{SplitOp, TopKLastDimOp, TopKOutput},
    paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
    pipeline::{
        extract_logits,
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel, KvCache, NormalCache, NormalLoadingMetadata, NormalModel,
    },
    serde_default_fn,
    utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
};

serde_default_fn!(f64, routed_scaling_factor, 1.0);
serde_default_fn!(TopkMethod, topk_method, TopkMethod::Greedy);
serde_default_fn!(usize, moe_layer_freq, 1);
serde_default_fn!(usize, first_k_dense_replace, 0);
serde_default_fn!(ScoringFunc, scoring_func, ScoringFunc::Softmax);
serde_default_fn!(Activation, hidden_act, Activation::Silu);
serde_default_fn!(bool, tie_word_embeddings, false);
serde_default_fn!(bool, use_flash_attn_default, false);

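/// Router top-k selection strategy. `noaux_tc` (used by DeepSeek-V3) corrects
/// the router scores with a learned per-expert bias before selection; the
/// greedy variants select directly from the scores.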
#[derive(Deserialize, Clone, Debug)]
enum TopkMethod {
    #[serde(rename = "noaux_tc")]
    NoAuxTc,
    #[serde(rename = "greedy")]
    Greedy,
    #[serde(rename = "group_limited_greedy")]
    GroupLimitedGreedy,
}

#[derive(Deserialize, Clone, Debug)]
enum ScoringFunc {
    #[serde(rename = "softmax")]
    Softmax,
    #[serde(rename = "sigmoid")]
    Sigmoid,
}

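/// Model configuration, deserialized from `config.json`. Field names follow
/// the Hugging Face DeepSeek-V3 configuration; `q_lora_rank` and
/// `kv_lora_rank` are the ranks of the low-rank query/KV projections used by
/// multi-head latent attention (MLA).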
#[derive(Deserialize, Clone, Debug)]
pub struct DeepSeekV3Config {
    pub(crate) vocab_size: usize,
    pub(crate) hidden_size: usize,
    pub(crate) intermediate_size: usize,
    pub(crate) moe_intermediate_size: usize,
    pub(crate) num_hidden_layers: usize,
    pub(crate) num_attention_heads: usize,
    pub(crate) n_shared_experts: Option<usize>,
    pub(crate) n_routed_experts: Option<usize>,
    #[serde(default = "routed_scaling_factor")]
    pub(crate) routed_scaling_factor: f64,
    #[serde(default = "topk_method")]
    topk_method: TopkMethod,
    pub(crate) num_experts_per_tok: Option<usize>,
    #[serde(default = "moe_layer_freq")]
    pub(crate) moe_layer_freq: usize,
    #[serde(default = "first_k_dense_replace")]
    pub(crate) first_k_dense_replace: usize,
    #[serde(default = "scoring_func")]
    scoring_func: ScoringFunc,
    #[serde(default = "hidden_act")]
    pub(crate) hidden_act: Activation,
    pub(crate) max_position_embeddings: usize,
    pub(crate) rms_norm_eps: f64,
    #[serde(default = "tie_word_embeddings")]
    pub(crate) tie_word_embeddings: bool,
    pub(crate) rope_theta: f32,
    pub(crate) rope_scaling: Option<DeepSeekV2RopeScaling>,
    pub(crate) attention_bias: bool,
    pub(crate) q_lora_rank: Option<usize>,
    pub(crate) qk_rope_head_dim: usize,
    pub(crate) kv_lora_rank: usize,
    pub(crate) v_head_dim: usize,
    pub(crate) qk_nope_head_dim: usize,
    #[serde(default = "use_flash_attn_default")]
    pub(crate) use_flash_attn: bool,
    pub(crate) quantization_config: Option<QuantizedConfig>,
    pub(crate) n_group: usize,
    pub(crate) topk_group: usize,
}

impl DeepSeekV3Config {
    pub(crate) fn q_head_dim(&self) -> usize {
        self.qk_rope_head_dim + self.qk_nope_head_dim
    }

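    /// Attention softmax scale: `1 / sqrt(q_head_dim)`, additionally
    /// multiplied by `mscale^2` when YaRN rope scaling is configured,
    /// following the DeepSeek reference implementation.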
    fn softmax_scale(&self) -> f32 {
        let mut softmax_scale = 1.0 / (self.q_head_dim() as f32).sqrt();
        if let Some(DeepSeekV2RopeScaling::Yarn {
            mscale_all_dim,
            factor,
            ..
        }) = self.rope_scaling
        {
            let mscale = DeepSeekV2RotaryEmbedding::yarn_get_mscale(factor, mscale_all_dim);
            softmax_scale *= mscale * mscale;
        }
        softmax_scale
    }
}

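/// Query projection: either a plain linear layer, or a low-rank factorization
/// (`q_a_proj` -> RMSNorm -> `q_b_proj`) when `q_lora_rank` is set.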
enum QProj {
    Plain(Arc<dyn QuantMethod>),
    Lora {
        a: Arc<dyn QuantMethod>,
        norm: RmsNorm,
        b: Arc<dyn QuantMethod>,
    },
}

impl QProj {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        match self {
            Self::Lora { a, norm, b } => {
                b.forward_autocast(&norm.forward(&a.forward_autocast(xs)?)?)
            }
            Self::Plain(lin) => lin.forward_autocast(xs),
        }
    }
}

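/// Multi-head latent attention (MLA). The KV input is compressed into a
/// `kv_lora_rank`-dimensional latent (plus one shared RoPE key) by
/// `kv_a_proj_with_mqa`, then expanded per head by `kv_b_proj`. Query and key
/// heads are split into a RoPE part (`qk_rope_head_dim`) and a non-RoPE part
/// (`qk_nope_head_dim`).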
struct Attention {
    q: QProj,
    kv_a_proj_with_mqa: Arc<dyn QuantMethod>,
    kv_a_layernorm: RmsNorm,
    kv_b_proj: Arc<dyn QuantMethod>,
    o_proj: Arc<dyn QuantMethod>,
    rotary_emb: Arc<DeepSeekV2RotaryEmbedding>,
    cfg: DeepSeekV3Config,
    q_head_dim: usize,
    paged_attn: Option<PagedAttention>,
    sdpa_params: SdpaParams,
    num_attention_heads: usize,
}

impl Attention {
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<DeepSeekV2RotaryEmbedding>,
        cfg: &DeepSeekV3Config,
        vb: ShardedVarBuilder,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        paged_attn: Option<PagedAttention>,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let q_head_dim = cfg.q_head_dim();
        let q = match cfg.q_lora_rank {
            Some(lora_rank) => {
                let a = ReplicatedLayer::new(
                    cfg.hidden_size,
                    lora_rank,
                    &cfg.quantization_config,
                    cfg.attention_bias,
                    mapper.set_device(layer_idx, vb.pp("q_a_proj"), loading_isq),
                )?;
                let norm = RmsNorm::new(
                    lora_rank,
                    cfg.rms_norm_eps,
                    mapper.set_device(layer_idx, vb.pp("q_a_layernorm"), false),
                )?;
                let b = ColumnParallelLayer::new(
                    lora_rank,
                    cfg.num_attention_heads * q_head_dim,
                    &cfg.quantization_config,
                    false,
                    comm,
                    mapper.set_device(layer_idx, vb.pp("q_b_proj"), loading_isq),
                )?;
                QProj::Lora { a, norm, b }
            }
            None => QProj::Plain(ColumnParallelLayer::new(
                cfg.hidden_size,
                cfg.num_attention_heads * q_head_dim,
                &cfg.quantization_config,
                false,
                comm,
                mapper.set_device(layer_idx, vb.pp("q_proj"), loading_isq),
            )?),
        };

        let kv_a_proj_with_mqa = ReplicatedLayer::new(
            cfg.hidden_size,
            cfg.kv_lora_rank + cfg.qk_rope_head_dim,
            &cfg.quantization_config,
            cfg.attention_bias,
            mapper.set_device(layer_idx, vb.pp("kv_a_proj_with_mqa"), loading_isq),
        )?;
        let kv_a_layernorm = RmsNorm::new(
            cfg.kv_lora_rank,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("kv_a_layernorm"), false),
        )?;
        let kv_b_proj = ColumnParallelLayer::new(
            cfg.kv_lora_rank,
            cfg.num_attention_heads * (q_head_dim - cfg.qk_rope_head_dim + cfg.v_head_dim),
            &cfg.quantization_config,
            false,
            comm,
            mapper.set_device(layer_idx, vb.pp("kv_b_proj"), loading_isq),
        )?;

        let o_proj = RowParallelLayer::new(
            cfg.num_attention_heads * cfg.v_head_dim,
            cfg.hidden_size,
            &cfg.quantization_config,
            cfg.attention_bias,
            comm,
            mapper.set_device(layer_idx, vb.pp("o_proj"), loading_isq),
        )?;

        Ok(Self {
            q,
            kv_a_proj_with_mqa,
            kv_a_layernorm,
            kv_b_proj,
            o_proj,
            rotary_emb,
            cfg: cfg.clone(),
            q_head_dim,
            paged_attn,
            num_attention_heads: cfg.num_attention_heads / comm.world_size(),
            sdpa_params: SdpaParams {
                n_kv_groups: 1,
                use_flash_attn: cfg.use_flash_attn,
                softcap: None,
                softmax_scale: cfg.softmax_scale(),
                sliding_window: None,
            },
        })
    }

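    /// MLA forward: split query/key heads into RoPE and non-RoPE halves, apply
    /// rotary embeddings to the RoPE halves only, reassemble the heads, and
    /// run SDPA or paged attention.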
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (bs, seq_len, _) = xs.dims3()?;

        let mut q = self.q.forward(xs)?;
        q = q
            .reshape((bs, seq_len, self.num_attention_heads, self.q_head_dim))?
            .transpose(1, 2)?;
        // Split each query head into its non-RoPE and RoPE parts.
        let q_split = q.split(
            &[self.cfg.qk_nope_head_dim, self.cfg.qk_rope_head_dim],
            D::Minus1,
        )?;
        let q_nope = q_split[0].clone();
        let mut q_pe = q_split[1].clone();

        // Compress the KV input into the latent plus a single shared RoPE key.
        let mut compressed_kv = self.kv_a_proj_with_mqa.forward_autocast(xs)?;
        let ckv_split = compressed_kv.split(
            &[self.cfg.kv_lora_rank, self.cfg.qk_rope_head_dim],
            D::Minus1,
        )?;
        compressed_kv = ckv_split[0].clone();
        let mut k_pe = ckv_split[1].clone();
        k_pe = k_pe
            .reshape((bs, seq_len, 1, self.cfg.qk_rope_head_dim))?
            .transpose(1, 2)?;
        // Expand the normalized latent into per-head non-RoPE keys and values.
        let mut kv = self
            .kv_b_proj
            .forward_autocast(&self.kv_a_layernorm.forward(&compressed_kv)?)?;
        kv = kv
            .reshape((
                bs,
                seq_len,
                self.num_attention_heads,
                self.cfg.qk_nope_head_dim + self.cfg.v_head_dim,
            ))?
            .transpose(1, 2)?;

        let kv_split = kv.split(&[self.cfg.qk_nope_head_dim, self.cfg.v_head_dim], D::Minus1)?;
        let k_nope = kv_split[0].clone();
        let mut v = kv_split[1].clone();

        // Rotary embeddings only touch the RoPE halves.
        (q_pe, k_pe) = self.rotary_emb.forward(&q_pe, &k_pe, seqlen_offsets)?;

        let q = Tensor::cat(&[&q_nope, &q_pe], D::Minus1)?.contiguous()?;
        // The RoPE key is shared across heads (MQA-style), so broadcast it.
        let mut k = Tensor::cat(
            &[&k_nope, &k_pe.repeat((1, self.num_attention_heads, 1, 1))?],
            D::Minus1,
        )?
        .contiguous()?;

        let mut attn_out = match &self.paged_attn {
            Some(paged_attn) => match metadata {
                Some(((key_cache, value_cache), input_metadata)) => {
                    // Pad V up to q_head_dim so K and V share one paged cache
                    // layout; the padding is sliced off again below.
                    let v = v
                        .pad_with_zeros(D::Minus1, 0, self.q_head_dim - self.cfg.v_head_dim)?
                        .contiguous()?;
                    paged_attn
                        .forward(
                            &q,
                            &k,
                            &v,
                            attention_mask,
                            Some(key_cache),
                            Some(value_cache),
                            input_metadata,
                            &self.sdpa_params,
                            Some(flash_params),
                        )?
                        .narrow(D::Minus1, 0, self.cfg.v_head_dim)?
                }
                None => {
                    // No cache metadata was passed: run paged attention
                    // cache-less with dummy metadata.
                    let input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
                    assert!(attention_mask.is_some());
                    let v = v
                        .pad_with_zeros(D::Minus1, 0, self.q_head_dim - self.cfg.v_head_dim)?
                        .contiguous()?;
                    paged_attn
                        .forward(
                            &q,
                            &k,
                            &v,
                            attention_mask,
                            None,
                            None,
                            &input_metadata,
                            &self.sdpa_params,
                            Some(flash_params),
                        )?
                        .narrow(D::Minus1, 0, self.cfg.v_head_dim)?
                }
            },
            None => {
                (k, v) = kv_cache.append(&k, &v)?;

                Sdpa.run_attention(
                    &q,
                    &k,
                    &v,
                    attention_mask,
                    Some(flash_params),
                    &self.sdpa_params,
                )?
            }
        };

        attn_out = if attention_mask.is_some() {
            attn_out.transpose(1, 2)?.reshape((bs, seq_len, ()))?
        } else {
            attn_out.reshape((bs, seq_len, ()))?
        };

        self.o_proj.forward_autocast(&attn_out)
    }
}

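/// A single routed expert: a gated MLP (`down(act(gate(x)) * up(x))`) with
/// the activation given by `hidden_act`.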
struct Expert {
    gate: Arc<dyn QuantMethod>,
    up: Arc<dyn QuantMethod>,
    down: Arc<dyn QuantMethod>,
    act: Activation,
}

impl Expert {
    fn new(
        cfg: &DeepSeekV3Config,
        vb: ShardedVarBuilder,
        hidden_size: Option<usize>,
        intermediate_size: Option<usize>,
    ) -> Result<Self> {
        let hidden_size = hidden_size.unwrap_or(cfg.hidden_size);
        let intermediate_size = intermediate_size.unwrap_or(cfg.intermediate_size);

        Ok(Self {
            gate: ReplicatedLayer::new(
                hidden_size,
                intermediate_size,
                &cfg.quantization_config,
                false,
                vb.pp("gate_proj"),
            )?,
            up: ReplicatedLayer::new(
                hidden_size,
                intermediate_size,
                &cfg.quantization_config,
                false,
                vb.pp("up_proj"),
            )?,
            down: ReplicatedLayer::new(
                intermediate_size,
                hidden_size,
                &cfg.quantization_config,
                false,
                vb.pp("down_proj"),
            )?,
            act: cfg.hidden_act,
        })
    }

    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.gate.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let lhs = self.gate.forward(&xs)?;
        let rhs = self.up.forward(&xs)?;
        let mut res = self.down.forward(&candle_nn::ops::mul_and_act(
            &lhs,
            &rhs,
            self.act.try_into()?,
        )?)?;
        if self.gate.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

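/// MoE router ("gate"). Projects hidden states to per-expert logits, maps
/// them to scores via `scoring_func`, and picks `num_experts_per_tok` experts
/// per token according to `topk_method`.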
struct MoeGate {
    weight: Tensor,
    cfg: DeepSeekV3Config,
    top_k: usize,
    n_routed_experts: usize,
    e_score_correction_bias: Option<Tensor>,
}

impl MoeGate {
    fn new(cfg: &DeepSeekV3Config, vb: ShardedVarBuilder, n_routed_experts: usize) -> Result<Self> {
        let weight = vb.get((n_routed_experts, cfg.hidden_size), "weight")?;
        let e_score_correction_bias = if matches!(cfg.topk_method, TopkMethod::NoAuxTc) {
            Some(vb.get_with_hints_dtype(
                n_routed_experts,
                "e_score_correction_bias",
                Default::default(),
                DType::F32,
            )?)
        } else {
            None
        };
        Ok(Self {
            weight,
            cfg: cfg.clone(),
            top_k: cfg.num_experts_per_tok.unwrap(),
            n_routed_experts,
            e_score_correction_bias,
        })
    }

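    /// Routes a `(bs, seq_len, hidden)` batch, returning `(topk_idx, topk_weight)`
    /// with one entry per token and selected expert. Group-limited methods
    /// first restrict candidates to the best `topk_group` of `n_group` expert
    /// groups.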
    fn forward(&self, xs: &Tensor) -> Result<(Tensor, Tensor)> {
        let (bs, seq_len, h) = xs.dims3()?;
        let xs = xs.reshape(((), h))?;
        let logits = xs
            .to_dtype(DType::F32)?
            .broadcast_matmul(&self.weight.t()?.to_dtype(DType::F32)?)?;
        let scores = match self.cfg.scoring_func {
            ScoringFunc::Softmax => candle_nn::ops::softmax_last_dim(&logits)?,
            ScoringFunc::Sigmoid => candle_nn::ops::sigmoid(&logits)?,
        };

        let (mut topk_weight, topk_idx) = match self.cfg.topk_method {
            TopkMethod::Greedy => {
                let TopKOutput { values, indices } = scores.topk_unsorted(self.top_k)?;
                (values, indices)
            }
            TopkMethod::NoAuxTc => {
                let Some(e_score_correction_bias) = &self.e_score_correction_bias else {
                    candle_core::bail!("Expected e_score_correction_bias")
                };
                // Selection uses bias-corrected scores, but the returned
                // weights are gathered from the uncorrected scores.
                let scores_for_choice = scores
                    .reshape((bs * seq_len, ()))?
                    .broadcast_add(&e_score_correction_bias.unsqueeze(0)?)?;
                // Score each group by the sum of its two best experts.
                let group_scores = scores_for_choice
                    .reshape((bs * seq_len, self.cfg.n_group, ()))?
                    .topk(2)?
                    .values
                    .sum(D::Minus1)?;
                let group_idx = group_scores.topk(self.cfg.topk_group)?.indices;
                let mut group_mask = group_scores.zeros_like()?;
                group_mask = group_mask.scatter_add(
                    &group_idx,
                    &group_idx.ones_like()?.to_dtype(group_mask.dtype())?,
                    1,
                )?;
                // Zero out experts outside the selected groups.
                let score_mask = group_mask
                    .unsqueeze(D::Minus1)?
                    .expand((
                        bs * seq_len,
                        self.cfg.n_group,
                        self.n_routed_experts / self.cfg.n_group,
                    ))?
                    .reshape((bs * seq_len, ()))?;
                let tmp_scores = scores_for_choice.broadcast_mul(&score_mask)?;
                let topk_idx = tmp_scores.topk(self.top_k)?.indices;
                (scores.gather(&topk_idx, 1)?, topk_idx)
            }
            TopkMethod::GroupLimitedGreedy => {
                // Score each group by its single best expert.
                let group_scores = scores
                    .reshape((bs * seq_len, self.cfg.n_group, ()))?
                    .max(D::Minus1)?;
                let group_idx = group_scores.topk_unsorted(self.cfg.topk_group)?.indices;
                let mut group_mask = group_scores.zeros_like()?;
                group_mask = group_mask.scatter_add(
                    &group_idx,
                    &group_idx.ones_like()?.to_dtype(group_mask.dtype())?,
                    1,
                )?;
                let score_mask = group_mask
                    .unsqueeze(D::Minus1)?
                    .expand((
                        bs * seq_len,
                        self.cfg.n_group,
                        self.n_routed_experts / self.cfg.n_group,
                    ))?
                    .reshape((bs * seq_len, ()))?;
                // Keep scores only for experts in the selected groups.
                let tmp_scores = masked_fill(&scores, &(1. - &score_mask.ne(0.)?)?, 0.)?;
                let TopKOutput { values, indices } = tmp_scores.topk_unsorted(self.top_k)?;
                (values, indices)
            }
        };

        // Sigmoid scores are not normalized, so renormalize the top-k weights.
        if matches!(self.cfg.scoring_func, ScoringFunc::Sigmoid) {
            let denominator = (topk_weight.sum_keepdim(D::Minus1)? + 1e-20)?;
            topk_weight = topk_weight.broadcast_div(&denominator)?;
        }

        topk_weight = (topk_weight * self.cfg.routed_scaling_factor)?;

        Ok((topk_idx, topk_weight))
    }
}

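/// MoE feed-forward with expert parallelism: each rank holds only its slice
/// of the routed experts (`experts_start_idx..experts_end_idx`); the shared
/// experts (if any) run on every rank, and partial routed-expert outputs are
/// combined with an all-reduce.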
struct Moe {
    experts: Vec<Option<Expert>>,
    shared_experts: Option<Mlp>,
    gate: MoeGate,
    all_reduce: SumAllReduce,
    experts_start_idx: usize,
    experts_end_idx: usize,
    world_size: usize,
}

impl Moe {
    #[allow(clippy::too_many_arguments)]
    fn new(
        cfg: &DeepSeekV3Config,
        vb: ShardedVarBuilder,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        n_shared_experts: Option<usize>,
        n_routed_experts: usize,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let mut experts = Vec::with_capacity(n_routed_experts);
        let n_local_experts = n_routed_experts / comm.world_size();
        let experts_start_idx = comm.rank() * n_local_experts;
        let experts_end_idx = experts_start_idx + n_local_experts;
        for i in 0..n_routed_experts {
            if i >= experts_start_idx && i < experts_end_idx {
                let vb_e = vb.pp("experts").pp(i);
                experts.push(Some(Expert::new(
                    cfg,
                    mapper.set_device(layer_idx, vb_e, loading_isq),
                    None,
                    Some(cfg.moe_intermediate_size),
                )?));
            } else {
                experts.push(None);
            }
        }
        let shared_experts = if let Some(n_shared_experts) = n_shared_experts {
            let intermediate_size = cfg.moe_intermediate_size * n_shared_experts;
            Some(Mlp::new(
                mapper.set_device(layer_idx, vb.pp("shared_experts"), loading_isq),
                cfg.hidden_size,
                intermediate_size,
                &cfg.quantization_config,
                cfg.hidden_act,
                comm,
            )?)
        } else {
            None
        };
        let gate = MoeGate::new(
            cfg,
            mapper.set_device(layer_idx, vb.pp("gate"), false),
            n_routed_experts,
        )?;
        Ok(Self {
            experts,
            shared_experts,
            gate,
            all_reduce: SumAllReduce::new(comm),
            experts_end_idx,
            experts_start_idx,
            world_size: comm.world_size(),
        })
    }

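    /// Dispatch tokens to the locally held experts: for each expert selected
    /// by at least one token, gather its tokens with `index_select`, run the
    /// expert, scale by the routing weights, and scatter back via `index_add`.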
    fn moe_infer(&self, xs: &Tensor, topk_ids: &Tensor, topk_weight: &Tensor) -> Result<Tensor> {
        let mut y = xs.zeros_like()?;
        let topk_weight = if topk_weight.dtype() != xs.dtype() {
            topk_weight.to_dtype(xs.dtype())?
        } else {
            topk_weight.to_owned()
        };
        // Skip experts that no token routed to in this batch.
        let unique_ids: HashSet<u32> =
            HashSet::from_iter(topk_ids.to_device(&Device::Cpu)?.flatten_all()?.to_vec1()?);
        for i in self.experts_start_idx..self.experts_end_idx {
            if !unique_ids.contains(&(i as u32)) {
                continue;
            }
            // Row 0: token indices; row 1: which top-k slot chose expert i.
            let idx_top = topk_ids.eq(i as f64)?.nonzero()?.t()?;
            let idx = &idx_top.i(0)?.contiguous()?;
            let top = &idx_top.i(1)?.contiguous()?;

            let expert = self.experts[i]
                .as_ref()
                .context("Expert is not present for this rank.")?;

            y = y.index_add(
                idx,
                &expert.forward(&xs.index_select(idx, 0)?)?.broadcast_mul(
                    &topk_weight
                        .index_select(idx, 0)?
                        .gather(&top.unsqueeze(1)?, 1)?
                        .squeeze(1)?
                        .unsqueeze(D::Minus1)?,
                )?,
                0,
            )?;
        }

        // Each rank computed only its local experts; sum the partial results.
        if self.world_size > 1 {
            y = self.all_reduce.sum_all_reduce(&y)?;
        }

        Ok(y)
    }

    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let identity = xs.clone();
        let orig_shape = xs.shape();
        let (topk_idx, topk_weight) = self.gate.forward(xs)?;
        let xs = xs.reshape(((), xs.dim(D::Minus1)?))?;

        let mut y = self
            .moe_infer(&xs, &topk_idx, &topk_weight)?
            .reshape(orig_shape)?;
        if let Some(ref shared_experts) = self.shared_experts {
            y = (y + shared_experts.forward(&identity)?)?;
        }
        Ok(y)
    }
}

enum MoeOrMlp {
    Moe(Moe),
    Mlp(Mlp),
}

impl MoeOrMlp {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        match self {
            Self::Mlp(mlp) => mlp.forward(xs),
            Self::Moe(moe) => moe.forward(xs),
        }
    }
}

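/// Pre-norm decoder block. Layers below `first_k_dense_replace` (or whose
/// index is not divisible by `moe_layer_freq`) use a dense MLP; the rest use
/// the MoE feed-forward.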
struct DecoderLayer {
    input_layernorm: RmsNorm,
    post_attention_layernorm: RmsNorm,
    attn: Attention,
    moe_or_mlp: MoeOrMlp,
}

impl DecoderLayer {
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<DeepSeekV2RotaryEmbedding>,
        cfg: &DeepSeekV3Config,
        vb: ShardedVarBuilder,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        paged_attn: Option<PagedAttention>,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let attn = Attention::new(
            rotary_emb,
            cfg,
            vb.pp("self_attn"),
            mapper,
            layer_idx,
            loading_isq,
            paged_attn,
            comm,
        )?;
        let input_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let post_attention_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        let moe_or_mlp = if cfg.n_routed_experts.is_some()
            && layer_idx >= cfg.first_k_dense_replace
            && layer_idx % cfg.moe_layer_freq == 0
        {
            MoeOrMlp::Moe(Moe::new(
                cfg,
                vb.pp("mlp"),
                mapper,
                layer_idx,
                loading_isq,
                cfg.n_shared_experts,
                cfg.n_routed_experts.unwrap(),
                comm,
            )?)
        } else {
            MoeOrMlp::Mlp(Mlp::new(
                mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq),
                cfg.hidden_size,
                cfg.intermediate_size,
                &cfg.quantization_config,
                cfg.hidden_act,
                comm,
            )?)
        };

        Ok(Self {
            input_layernorm,
            post_attention_layernorm,
            attn,
            moe_or_mlp,
        })
    }

    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = self.input_layernorm.forward(xs)?;
        let xs = self.attn.forward(
            &xs,
            attention_mask,
            seqlen_offsets,
            kv_cache,
            metadata,
            flash_params,
        )?;
        let xs = (xs + residual)?;
        let residual = &xs;
        let xs = self
            .moe_or_mlp
            .forward(&xs.apply(&self.post_attention_layernorm)?)?;
        residual + xs
    }
}

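/// DeepSeek-V3 causal LM: token embeddings, a device-mapped stack of decoder
/// layers, a final RMSNorm, and the LM head (tied to the embeddings when
/// `tie_word_embeddings` is set).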
pub struct DeepSeekV3 {
    lm_head: Arc<dyn QuantMethod>,
    embed_tokens: Embedding,
    norm: RmsNorm,
    layers: Vec<DecoderLayer>,
    cache: EitherCache,
    device: Device,
    max_seq_len: usize,
    cfg: ModelConfigMetadata,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
}

impl DeepSeekV3 {
    pub fn new(
        cfg: &DeepSeekV3Config,
        vb: ShardedVarBuilder,
        _is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        let vb_m = vb.pp("model");

        let mapper = normal_loading_metadata.mapper;

        let embed_tokens = embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
            &cfg.quantization_config,
        )?;
        let lm_head = if !cfg.tie_word_embeddings {
            ReplicatedLayer::new(
                cfg.hidden_size,
                cfg.vocab_size,
                &None,
                false,
                mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
            )?
        } else {
            ReplicatedLayer::from_linear(candle_nn::Linear::new(
                mapper.cast_nm_device(
                    embed_tokens.embeddings(),
                    normal_loading_metadata.loading_isq,
                )?,
                None,
            ))?
        };
        let norm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb_m.pp("norm"), false),
        )?;

        let mut ropes = HashMap::new();
        let rope_cfg = DeepSeekV2RopeConfig {
            rope_scaling: cfg.rope_scaling.clone(),
            max_position_embeddings: cfg.max_position_embeddings,
            rope_theta: cfg.rope_theta,
            qk_rope_head_dim: cfg.qk_rope_head_dim,
        };
        for i in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(i, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(DeepSeekV2RotaryEmbedding::new(
                    &rope_cfg,
                    vb.dtype(),
                    device,
                )?),
            );
        }

        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        let vb_l = vb_m.pp("layers");
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        ) {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            let paged_attn = match &attention_mechanism {
                AttentionImplementation::Eager => None,
                AttentionImplementation::PagedAttention => Some(
                    PagedAttention::new(cfg.v_head_dim, device, None)
                        .expect("Failed to create PagedAttention"),
                ),
            };
            let comm = mapper.get_comm_for(layer_idx)?;
            let layer = DecoderLayer::new(
                rotary_emb.clone(),
                cfg,
                vb_l.pp(layer_idx),
                &*mapper,
                layer_idx,
                normal_loading_metadata.loading_isq,
                paged_attn,
                &comm,
            )?;
            layers.push(layer);
        }

        Ok(Self {
            lm_head,
            embed_tokens,
            norm,
            layers,
            cache: EitherCache::Normal(NormalCache::new(
                cfg.num_hidden_layers,
                cfg.max_position_embeddings,
            )),
            device: normal_loading_metadata.real_device.clone(),
            max_seq_len: cfg.max_position_embeddings,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_kv_heads: (cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size())
                    .max(1),
                num_attn_heads: (cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size())
                    .max(1),
                sliding_window: None,
                k_head_dim: cfg.q_head_dim(),
                v_head_dim: if matches!(
                    attention_mechanism,
                    AttentionImplementation::PagedAttention
                ) {
                    cfg.q_head_dim()
                } else {
                    cfg.v_head_dim
                },
            },
            mapper,
        })
    }

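    /// Embed tokens, build the causal mask (skipped for all but the first
    /// prompt chunk under paged attention), run each decoder layer on its
    /// mapped device, then apply the final norm and LM head.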
    #[allow(clippy::too_many_arguments)]
    pub fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut xs = self.embed_tokens.forward(input_ids)?;
        let cache = &mut self.cache.normal().0;
        let attention_mask = CausalMasker.make_causal_mask_matrix(
            input_ids,
            metadata
                .as_ref()
                .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
                .unwrap_or(cache as &dyn PastKvLenCache),
            xs.dtype(),
            self.cfg.num_attn_heads,
        )?;
        let attention_mask = attention_mask.filter(|_| {
            metadata
                .as_ref()
                .map(|(_, meta)| meta.is_first_prompt_chunk)
                .unwrap_or(true)
        });
        for (i, layer) in self.layers.iter().enumerate() {
            xs = self.mapper.map(xs, i)?;
            xs = layer.forward(
                &xs,
                attention_mask
                    .as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                &mut cache[i],
                metadata
                    .as_ref()
                    .map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
                flash_params,
            )?;
        }
        let xs = xs.to_device(&self.device)?;
        let xs = xs.apply(&self.norm)?;
        extract_logits(&self.lm_head.forward_autocast(&xs)?, context_lens)
    }
}

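// ISQ support: expose the quantizable linear layers (attention projections
// and expert MLPs), plus the residual tensors (embeddings, norms, gate
// weights) that are serialized unquantized.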
impl IsqModel for DeepSeekV3 {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((&mut self.lm_head, None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            match &mut layer.attn.q {
                QProj::Plain(q) => {
                    tensors.push((q, Some(i)));
                }
                QProj::Lora { a, norm: _, b } => {
                    tensors.push((a, Some(i)));
                    tensors.push((b, Some(i)));
                }
            }
            tensors.push((&mut layer.attn.kv_a_proj_with_mqa, Some(i)));
            tensors.push((&mut layer.attn.kv_b_proj, Some(i)));
            tensors.push((&mut layer.attn.o_proj, Some(i)));
            match &mut layer.moe_or_mlp {
                MoeOrMlp::Mlp(mlp) => {
                    tensors.push((&mut mlp.gate, Some(i)));
                    tensors.push((&mut mlp.up, Some(i)));
                    tensors.push((&mut mlp.down, Some(i)));
                }
                MoeOrMlp::Moe(moe) => {
                    for mlp in moe.experts.iter_mut().filter_map(|e| e.as_mut()) {
                        tensors.push((&mut mlp.gate, Some(i)));
                        tensors.push((&mut mlp.up, Some(i)));
                        tensors.push((&mut mlp.down, Some(i)));
                    }
                    if let Some(mlp) = &mut moe.shared_experts {
                        tensors.push((&mut mlp.gate, Some(i)));
                        tensors.push((&mut mlp.up, Some(i)));
                        tensors.push((&mut mlp.down, Some(i)));
                    }
                }
            }
        }
        (tensors, &*self.mapper)
    }

    fn get_layers_moe_experts_only(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((&mut self.lm_head, None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            match &mut layer.moe_or_mlp {
                MoeOrMlp::Mlp(mlp) => {
                    tensors.push((&mut mlp.gate, Some(i)));
                    tensors.push((&mut mlp.up, Some(i)));
                    tensors.push((&mut mlp.down, Some(i)));
                }
                MoeOrMlp::Moe(moe) => {
                    for mlp in moe.experts.iter_mut().filter_map(|e| e.as_mut()) {
                        tensors.push((&mut mlp.gate, Some(i)));
                        tensors.push((&mut mlp.up, Some(i)));
                        tensors.push((&mut mlp.down, Some(i)));
                    }
                    if let Some(mlp) = &mut moe.shared_experts {
                        tensors.push((&mut mlp.gate, Some(i)));
                        tensors.push((&mut mlp.up, Some(i)));
                        tensors.push((&mut mlp.down, Some(i)));
                    }
                }
            }
        }
        (tensors, &*self.mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        let uvb_m = uvb.pp("model");
        uvb_m.pp("embed_tokens").add(&self.embed_tokens);
        uvb_m.pp("norm").add(&self.norm);

        for (layer_idx, layer) in self.layers.iter().enumerate() {
            let uvb_l = uvb_m.pp("layers").pp(layer_idx);
            uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
            uvb_l
                .pp("post_attention_layernorm")
                .add(&layer.post_attention_layernorm);

            uvb_l
                .pp("self_attn")
                .pp("kv_a_layernorm")
                .add(&layer.attn.kv_a_layernorm);

            match &layer.moe_or_mlp {
                MoeOrMlp::Moe(moe) => {
                    uvb_l
                        .pp("mlp")
                        .pp("gate")
                        .add_tensor("weight", moe.gate.weight.clone());
                }
                MoeOrMlp::Mlp(_) => (),
            }

            match &layer.attn.q {
                QProj::Plain(_) => (),
                QProj::Lora { a: _, norm, b: _ } => {
                    uvb_l.pp("self_attn").pp("q_a_layernorm").add(norm);
                }
            }
        }

        uvb.to_safetensors()
    }

    fn residual_tensors_moe_experts_only(&self) -> Option<Vec<(String, Tensor)>> {
        let uvb = UnVarBuilder::new();

        let uvb_m = uvb.pp("model");
        uvb_m.pp("embed_tokens").add(&self.embed_tokens);
        uvb_m.pp("norm").add(&self.norm);

        for (layer_idx, layer) in self.layers.iter().enumerate() {
            let uvb_l = uvb_m.pp("layers").pp(layer_idx);
            uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
            uvb_l
                .pp("post_attention_layernorm")
                .add(&layer.post_attention_layernorm);

            uvb_l
                .pp("self_attn")
                .pp("kv_a_layernorm")
                .add(&layer.attn.kv_a_layernorm);

            match &layer.moe_or_mlp {
                MoeOrMlp::Moe(moe) => {
                    uvb_l
                        .pp("mlp")
                        .pp("gate")
                        .add_tensor("weight", moe.gate.weight.clone());
                }
                MoeOrMlp::Mlp(_) => (),
            }

            match &layer.attn.q {
                QProj::Plain(q) => {
                    uvb_l.pp("self_attn").pp("q_proj").add(q);
                }
                QProj::Lora { a, norm, b } => {
                    uvb_l.pp("self_attn").pp("q_a_proj").add(a);
                    uvb_l.pp("self_attn").pp("q_a_layernorm").add(norm);
                    uvb_l.pp("self_attn").pp("q_b_proj").add(b);
                }
            }
            uvb_l
                .pp("self_attn")
                .pp("kv_a_proj_with_mqa")
                .add(&layer.attn.kv_a_proj_with_mqa);
            uvb_l
                .pp("self_attn")
                .pp("kv_b_proj")
                .add(&layer.attn.kv_b_proj);
            uvb_l.pp("self_attn").pp("o_proj").add(&layer.attn.o_proj);
        }

        Some(uvb.to_safetensors())
    }
}

impl NormalModel for DeepSeekV3 {
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.forward(
            input_ids,
            seqlen_offsets,
            context_lens,
            metadata,
            flash_params,
        )
    }
    fn xlora_forward(
        &self,
        _input_ids: &Tensor,
        _input_ids_full: &Tensor,
        _seqlen_offsets: &[usize],
        _seqlen_offsets_full: &[usize],
        _no_kv_cache: bool,
        _non_granular_state: &Option<crate::xlora_models::NonGranularState>,
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _flash_params: &FlashParams,
        _flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        unimplemented!()
    }
    fn cache(&self) -> &EitherCache {
        &self.cache
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        &mut self.cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn is_xlora(&self) -> bool {
        false
    }
    fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
}

impl AnyMoeBaseModelMixin for DeepSeekV3 {}