mistralrs_core/models/
deepseek3.rs

1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use std::{collections::HashMap, sync::Arc};
4
5use candle_core::{Context, DType, Device, IndexOp, Result, Tensor, D};
6use candle_nn::{Embedding, Module};
7use mistralrs_quant::{
8    ColumnParallelLayer, NonZeroOp, QuantMethod, QuantizedConfig, ReplicatedLayer,
9    RowParallelLayer, ShardedVarBuilder, SumAllReduce,
10};
11use serde::Deserialize;
12
13use crate::{
14    amoe::AnyMoeBaseModelMixin,
15    attention::SdpaParams,
16    device_map::DeviceMapper,
17    layers::{
18        embedding, Activation, CausalMasker, DeepSeekV2RopeConfig, DeepSeekV2RopeScaling,
19        DeepSeekV2RotaryEmbedding, Mlp, RmsNorm, Sdpa,
20    },
21    layers_masker::{masked_fill, PastKvLenCache},
22    ops::{SplitOp, TopKLastDimOp, TopKOutput},
23    paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
24    pipeline::{
25        extract_logits,
26        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
27        EitherCache, IsqModel, KvCache, NormalCache, NormalLoadingMetadata, NormalModel,
28    },
29    serde_default_fn,
30    utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
31};
32use std::collections::HashSet;
33use std::iter::FromIterator;
// Serde default-value helpers: each call generates a free function returning the
// default used by the `#[serde(default = "...")]` attributes on `DeepSeekV3Config`.
serde_default_fn!(f64, routed_scaling_factor, 1.0);
serde_default_fn!(TopkMethod, topk_method, TopkMethod::Greedy);
serde_default_fn!(usize, moe_layer_freq, 1);
serde_default_fn!(usize, first_k_dense_replace, 0);
serde_default_fn!(ScoringFunc, scoring_func, ScoringFunc::Softmax);
serde_default_fn!(Activation, hidden_act, Activation::Silu);
serde_default_fn!(bool, tie_word_embeddings, false);
serde_default_fn!(bool, use_flash_attn_default, false);
42
/// Strategy used by the MoE gate to pick the top-k experts per token.
/// Serde names match the `topk_method` field of the HF `config.json`.
#[derive(Deserialize, Clone, Debug)]
enum TopkMethod {
    /// Auxiliary-loss-free balancing: selection uses scores biased by
    /// `e_score_correction_bias`, restricted to the best expert groups.
    #[serde(rename = "noaux_tc")]
    NoAuxTc,
    /// Plain top-k over all routed experts.
    #[serde(rename = "greedy")]
    Greedy,
    /// Top-k restricted to the `topk_group` best expert groups.
    #[serde(rename = "group_limited_greedy")]
    GroupLimitedGreedy,
}
52
/// Activation applied to the router logits before expert selection.
#[derive(Deserialize, Clone, Debug)]
enum ScoringFunc {
    #[serde(rename = "softmax")]
    Softmax,
    /// Sigmoid scores do not sum to 1; they are re-normalized over the
    /// selected top-k in `MoeGate::forward`.
    #[serde(rename = "sigmoid")]
    Sigmoid,
}
60
/// Model configuration mirroring the HF `config.json` for DeepSeek V3.
#[derive(Deserialize, Clone, Debug)]
pub struct DeepSeekV3Config {
    pub(crate) vocab_size: usize,
    pub(crate) hidden_size: usize,
    /// Intermediate size of the dense MLP layers.
    pub(crate) intermediate_size: usize,
    /// Intermediate size of each routed expert MLP.
    pub(crate) moe_intermediate_size: usize,
    pub(crate) num_hidden_layers: usize,
    pub(crate) num_attention_heads: usize,
    /// Number of always-active shared experts (fused into one widened MLP).
    pub(crate) n_shared_experts: Option<usize>,
    /// Total number of routed experts; `None` disables MoE entirely.
    pub(crate) n_routed_experts: Option<usize>,
    /// Multiplier applied to the routed experts' gate weights.
    #[serde(default = "routed_scaling_factor")]
    pub(crate) routed_scaling_factor: f64,
    #[serde(default = "topk_method")]
    topk_method: TopkMethod,
    /// Experts activated per token (the `k` of top-k routing).
    pub(crate) num_experts_per_tok: Option<usize>,
    /// A layer uses MoE only when `layer_idx % moe_layer_freq == 0`.
    #[serde(default = "moe_layer_freq")]
    pub(crate) moe_layer_freq: usize,
    /// The first `first_k_dense_replace` layers use a dense MLP instead of MoE.
    #[serde(default = "first_k_dense_replace")]
    pub(crate) first_k_dense_replace: usize,
    #[serde(default = "scoring_func")]
    scoring_func: ScoringFunc,
    #[serde(default = "hidden_act")]
    pub(crate) hidden_act: Activation,
    pub(crate) max_position_embeddings: usize,
    pub(crate) rms_norm_eps: f64,
    #[serde(default = "tie_word_embeddings")]
    pub(crate) tie_word_embeddings: bool,
    pub(crate) rope_theta: f32,
    pub(crate) rope_scaling: Option<DeepSeekV2RopeScaling>,
    pub(crate) attention_bias: bool,
    /// When set, the query projection is low-rank factorized (a / norm / b).
    pub(crate) q_lora_rank: Option<usize>,
    /// RoPE (positional) part of each query/key head.
    pub(crate) qk_rope_head_dim: usize,
    /// Rank of the compressed KV latent.
    pub(crate) kv_lora_rank: usize,
    pub(crate) v_head_dim: usize,
    /// Non-positional part of each query/key head.
    pub(crate) qk_nope_head_dim: usize,
    #[serde(default = "use_flash_attn_default")]
    pub(crate) use_flash_attn: bool,
    pub(crate) quantization_config: Option<QuantizedConfig>,
    /// Routed experts are organized into `n_group` groups for group-limited routing.
    pub(crate) n_group: usize,
    /// Number of groups each token may draw its experts from.
    pub(crate) topk_group: usize,
}
102
103impl DeepSeekV3Config {
104    pub(crate) fn q_head_dim(&self) -> usize {
105        self.qk_rope_head_dim + self.qk_nope_head_dim
106    }
107
108    fn softmax_scale(&self) -> f32 {
109        let mut softmax_scale = 1.0 / (self.q_head_dim() as f32).sqrt();
110        if let Some(DeepSeekV2RopeScaling::Yarn {
111            mscale_all_dim,
112            factor,
113            ..
114        }) = self.rope_scaling
115        {
116            let mscale = DeepSeekV2RotaryEmbedding::yarn_get_mscale(factor, mscale_all_dim);
117            softmax_scale = softmax_scale * mscale * mscale;
118        }
119        softmax_scale
120    }
121}
122
/// Query projection: either a single linear layer, or a low-rank (LoRA-style)
/// factorization `b(norm(a(x)))` used when `q_lora_rank` is set in the config.
enum QProj {
    Plain(Arc<dyn QuantMethod>),
    Lora {
        // Down-projection to the LoRA rank.
        a: Arc<dyn QuantMethod>,
        // RmsNorm applied between the two projections.
        norm: RmsNorm,
        // Up-projection from the LoRA rank to heads * q_head_dim.
        b: Arc<dyn QuantMethod>,
    },
}
131
132impl QProj {
133    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
134        match self {
135            Self::Lora { a, norm, b } => {
136                b.forward_autocast(&norm.forward(&a.forward_autocast(xs)?)?)
137            }
138            Self::Plain(lin) => lin.forward_autocast(xs),
139        }
140    }
141}
142
/// Multi-head latent attention (MLA): queries (optionally low-rank factorized),
/// a compressed KV projection with a single shared RoPE key, and an output projection.
struct Attention {
    q: QProj,
    // Projects hidden states to the compressed KV latent plus the shared RoPE key.
    kv_a_proj_with_mqa: Arc<dyn QuantMethod>,
    kv_a_layernorm: RmsNorm,
    // Expands the normalized latent to per-head (k_nope, v) pairs.
    kv_b_proj: Arc<dyn QuantMethod>,
    o_proj: Arc<dyn QuantMethod>,
    rotary_emb: Arc<DeepSeekV2RotaryEmbedding>,
    cfg: DeepSeekV3Config,
    // qk_rope_head_dim + qk_nope_head_dim (see `DeepSeekV3Config::q_head_dim`).
    q_head_dim: usize,
    paged_attn: Option<PagedAttention>,
    sdpa_params: SdpaParams,
    // Heads handled by this rank: total heads / tensor-parallel world size.
    num_attention_heads: usize,
}
156
impl Attention {
    /// Build the MLA block for one layer.
    ///
    /// Loads `q_proj` (or the `q_a_proj`/`q_a_layernorm`/`q_b_proj` low-rank trio when
    /// `q_lora_rank` is set), the compressed KV projections, and `o_proj`, placing each
    /// on the device chosen by `mapper` for `layer_idx`.
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<DeepSeekV2RotaryEmbedding>,
        cfg: &DeepSeekV3Config,
        vb: ShardedVarBuilder,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        paged_attn: Option<PagedAttention>,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let q_head_dim = cfg.q_head_dim();
        let q = match cfg.q_lora_rank {
            Some(lora_rank) => {
                // Low-rank query path: down-project, normalize, then up-project.
                let a = ReplicatedLayer::new(
                    cfg.hidden_size,
                    lora_rank,
                    &cfg.quantization_config,
                    cfg.attention_bias,
                    mapper.set_device(layer_idx, vb.pp("q_a_proj"), loading_isq),
                )?;
                let norm = RmsNorm::new(
                    lora_rank,
                    cfg.rms_norm_eps,
                    mapper.set_device(layer_idx, vb.pp("q_a_layernorm"), false),
                )?;
                let b = ColumnParallelLayer::new(
                    lora_rank,
                    cfg.num_attention_heads * q_head_dim,
                    &cfg.quantization_config,
                    false,
                    comm,
                    mapper.set_device(layer_idx, vb.pp("q_b_proj"), loading_isq),
                )?;
                QProj::Lora { a, norm, b }
            }
            None => QProj::Plain(ColumnParallelLayer::new(
                cfg.hidden_size,
                cfg.num_attention_heads * q_head_dim,
                &cfg.quantization_config,
                false,
                comm,
                mapper.set_device(layer_idx, vb.pp("q_proj"), loading_isq),
            )?),
        };

        // Projects hidden states to the KV latent (kv_lora_rank) plus the shared
        // RoPE key (qk_rope_head_dim); replicated rather than sharded.
        let kv_a_proj_with_mqa = ReplicatedLayer::new(
            cfg.hidden_size,
            cfg.kv_lora_rank + cfg.qk_rope_head_dim,
            &cfg.quantization_config,
            cfg.attention_bias,
            mapper.set_device(layer_idx, vb.pp("kv_a_proj_with_mqa"), loading_isq),
        )?;
        let kv_a_layernorm = RmsNorm::new(
            cfg.kv_lora_rank,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("kv_a_layernorm"), false),
        )?;
        // Expands the latent to per-head (k_nope, v); sharded over heads.
        let kv_b_proj = ColumnParallelLayer::new(
            cfg.kv_lora_rank,
            cfg.num_attention_heads * (q_head_dim - cfg.qk_rope_head_dim + cfg.v_head_dim),
            &cfg.quantization_config,
            false,
            comm,
            mapper.set_device(layer_idx, vb.pp("kv_b_proj"), loading_isq),
        )?;

        let o_proj = RowParallelLayer::new(
            cfg.num_attention_heads * cfg.v_head_dim,
            cfg.hidden_size,
            &cfg.quantization_config,
            cfg.attention_bias,
            comm,
            mapper.set_device(layer_idx, vb.pp("o_proj"), loading_isq),
        )?;

        Ok(Self {
            q,
            kv_a_proj_with_mqa,
            kv_a_layernorm,
            kv_b_proj,
            o_proj,
            rotary_emb,
            cfg: cfg.clone(),
            q_head_dim,
            paged_attn,
            // Heads are split across tensor-parallel ranks.
            num_attention_heads: cfg.num_attention_heads / comm.world_size(),
            sdpa_params: SdpaParams {
                // After concatenation below, every head carries a full K, so no
                // KV-head grouping is needed.
                n_kv_groups: 1,
                use_flash_attn: cfg.use_flash_attn,
                softcap: None,
                softmax_scale: cfg.softmax_scale(),
                sliding_window: None,
            },
        })
    }

    /// MLA forward: project queries and the compressed KV, apply RoPE to the
    /// positional slices only, run SDPA or paged attention, and project the output.
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (bs, seq_len, _) = xs.dims3()?;

        // Queries: (bs, heads, seq, q_head_dim), split into non-positional ("nope")
        // and RoPE ("pe") slices.
        let mut q = self.q.forward(xs)?;
        q = q
            .reshape((bs, seq_len, self.num_attention_heads, self.q_head_dim))?
            .transpose(1, 2)?;
        let q_split = q.split(
            &[self.cfg.qk_nope_head_dim, self.cfg.qk_rope_head_dim],
            D::Minus1,
        )?;
        let q_nope = q_split[0].clone();
        let mut q_pe = q_split[1].clone();

        // Compressed KV: latent (kv_lora_rank) plus a single shared RoPE key.
        let mut compressed_kv = self.kv_a_proj_with_mqa.forward_autocast(xs)?;
        let ckv_split = compressed_kv.split(
            &[self.cfg.kv_lora_rank, self.cfg.qk_rope_head_dim],
            D::Minus1,
        )?;
        compressed_kv = ckv_split[0].clone();
        let mut k_pe = ckv_split[1].clone();
        // Shared RoPE key shaped as a single "head": (bs, 1, seq, qk_rope_head_dim).
        k_pe = k_pe
            .reshape((bs, seq_len, 1, self.cfg.qk_rope_head_dim))?
            .transpose(1, 2)?;
        // Expand the normalized latent to per-head (k_nope, v):
        // (bs, heads, seq, qk_nope_head_dim + v_head_dim).
        let mut kv = self
            .kv_b_proj
            .forward_autocast(&self.kv_a_layernorm.forward(&compressed_kv)?)?;
        kv = kv
            .reshape((
                bs,
                seq_len,
                self.num_attention_heads,
                self.cfg.qk_nope_head_dim + self.cfg.v_head_dim,
            ))?
            .transpose(1, 2)?;

        let kv_split = kv.split(&[self.cfg.qk_nope_head_dim, self.cfg.v_head_dim], D::Minus1)?;
        let k_nope = kv_split[0].clone();
        let mut v = kv_split[1].clone();

        // Rotary embedding only touches the positional slices.
        (q_pe, k_pe) = self.rotary_emb.forward(&q_pe, &k_pe, seqlen_offsets)?;

        // Reassemble full Q/K; the shared RoPE key is replicated to every head.
        let q = Tensor::cat(&[&q_nope, &q_pe], D::Minus1)?.contiguous()?;
        let mut k = Tensor::cat(
            &[&k_nope, &k_pe.repeat((1, self.num_attention_heads, 1, 1))?],
            D::Minus1,
        )?
        .contiguous()?;

        let mut attn_out = match &self.paged_attn {
            Some(paged_attn) => match metadata {
                Some(((key_cache, value_cache), input_metadata)) => {
                    // Paged attention stores K and V with the same head dim, so V is
                    // zero-padded from v_head_dim up to q_head_dim; the padding is
                    // sliced back off (`narrow`) after the kernel.
                    let v = v
                        .pad_with_zeros(D::Minus1, 0, self.q_head_dim - self.cfg.v_head_dim)?
                        .contiguous()?;
                    paged_attn
                        .forward(
                            &q,
                            &k,
                            &v,
                            attention_mask,
                            Some(key_cache),
                            Some(value_cache),
                            input_metadata,
                            &self.sdpa_params,
                            Some(flash_params),
                        )?
                        .narrow(D::Minus1, 0, self.cfg.v_head_dim)?
                }
                None => {
                    // If we don't have metadata, we are most likely generating an imatrix so we don't want to populate that.
                    // Generating the dummy metadata with the assumption that we are not generating text (only processing prompts).
                    let input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
                    // Sanity check.
                    assert!(attention_mask.is_some());
                    // Same V padding/unpadding as the cached path above.
                    let v = v
                        .pad_with_zeros(D::Minus1, 0, self.q_head_dim - self.cfg.v_head_dim)?
                        .contiguous()?;
                    paged_attn
                        .forward(
                            &q,
                            &k,
                            &v,
                            attention_mask,
                            None,
                            None,
                            &input_metadata,
                            &self.sdpa_params,
                            Some(flash_params),
                        )?
                        .narrow(D::Minus1, 0, self.cfg.v_head_dim)?
                }
            },
            None => {
                // Eager path: append K/V to the per-layer cache, then plain SDPA.
                (k, v) = kv_cache.append(&k, &v)?;

                Sdpa.run_attention(
                    &q,
                    &k,
                    &v,
                    attention_mask,
                    Some(flash_params),
                    &self.sdpa_params,
                )?
            }
        };

        // NOTE(review): with a mask (prompt path) the output appears to be
        // (bs, heads, seq, v_dim) and needs the transpose; the maskless decode path
        // is flattened directly — confirm against the Sdpa/PagedAttention output layout.
        attn_out = if attention_mask.is_some() {
            attn_out.transpose(1, 2)?.reshape((bs, seq_len, ()))?
        } else {
            attn_out.reshape((bs, seq_len, ()))?
        };

        self.o_proj.forward_autocast(&attn_out)
    }
}
379
/// One routed expert: a gated MLP (gate / up / down) using the model activation.
struct Expert {
    gate: Arc<dyn QuantMethod>,
    up: Arc<dyn QuantMethod>,
    down: Arc<dyn QuantMethod>,
    act: Activation,
}
386
387impl Expert {
388    fn new(
389        cfg: &DeepSeekV3Config,
390        vb: ShardedVarBuilder,
391        hidden_size: Option<usize>,
392        intermediate_size: Option<usize>,
393    ) -> Result<Self> {
394        let hidden_size = hidden_size.unwrap_or(cfg.hidden_size);
395        let intermediate_size = intermediate_size.unwrap_or(cfg.intermediate_size);
396
397        Ok(Self {
398            gate: ReplicatedLayer::new(
399                hidden_size,
400                intermediate_size,
401                &cfg.quantization_config,
402                false,
403                vb.pp("gate_proj"),
404            )?,
405            up: ReplicatedLayer::new(
406                hidden_size,
407                intermediate_size,
408                &cfg.quantization_config,
409                false,
410                vb.pp("up_proj"),
411            )?,
412            down: ReplicatedLayer::new(
413                intermediate_size,
414                hidden_size,
415                &cfg.quantization_config,
416                false,
417                vb.pp("down_proj"),
418            )?,
419            act: cfg.hidden_act,
420        })
421    }
422
423    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
424        let original_dtype = xs.dtype();
425        let mut xs = xs.clone();
426        if let Some(t) = self.gate.quantized_act_type() {
427            xs = xs.to_dtype(t)?;
428        }
429        let lhs = self.gate.forward(&xs)?;
430        let rhs = self.up.forward(&xs)?;
431        let mut res = self.down.forward(&candle_nn::ops::mul_and_act(
432            &lhs,
433            &rhs,
434            self.act.try_into()?,
435        )?)?;
436        if self.gate.quantized_act_type().is_some() {
437            res = res.to_dtype(original_dtype)?;
438        }
439        Ok(res)
440    }
441}
442
/// Router that scores every routed expert per token and selects the top-k.
struct MoeGate {
    // Router weight: (n_routed_experts, hidden_size).
    weight: Tensor,
    cfg: DeepSeekV3Config,
    top_k: usize,
    n_routed_experts: usize,
    // Bias added to scores for expert *selection* only (noaux_tc method).
    e_score_correction_bias: Option<Tensor>,
}
450
451impl MoeGate {
452    fn new(cfg: &DeepSeekV3Config, vb: ShardedVarBuilder, n_routed_experts: usize) -> Result<Self> {
453        let weight = vb.get((n_routed_experts, cfg.hidden_size), "weight")?;
454        let e_score_correction_bias = if matches!(cfg.topk_method, TopkMethod::NoAuxTc) {
455            Some(vb.get_with_hints_dtype(
456                n_routed_experts,
457                "e_score_correction_bias",
458                Default::default(),
459                DType::F32,
460            )?)
461        } else {
462            None
463        };
464        Ok(Self {
465            weight,
466            cfg: cfg.clone(),
467            top_k: cfg.num_experts_per_tok.unwrap(),
468            n_routed_experts,
469            e_score_correction_bias,
470        })
471    }
472
473    /// (topk_idx, topk_weight)
474    fn forward(&self, xs: &Tensor) -> Result<(Tensor, Tensor)> {
475        let (bs, seq_len, h) = xs.dims3()?;
476        // Compute gating score
477        let xs = xs.reshape(((), h))?;
478        let logits = xs
479            .to_dtype(DType::F32)?
480            .broadcast_matmul(&self.weight.t()?.to_dtype(DType::F32)?)?;
481        let scores = match self.cfg.scoring_func {
482            ScoringFunc::Softmax => candle_nn::ops::softmax_last_dim(&logits)?,
483            ScoringFunc::Sigmoid => candle_nn::ops::sigmoid(&logits)?,
484        };
485
486        // Select top-k experts
487        let (mut topk_weight, topk_idx) = match self.cfg.topk_method {
488            TopkMethod::Greedy => {
489                let TopKOutput { values, indices } = scores.topk_unsorted(self.top_k)?;
490                (values, indices)
491            }
492            TopkMethod::NoAuxTc => {
493                let Some(e_score_correction_bias) = &self.e_score_correction_bias else {
494                    candle_core::bail!("Expected e_score_correction_bias")
495                };
496                let scores_for_choice = scores
497                    .reshape((bs * seq_len, ()))?
498                    .broadcast_add(&e_score_correction_bias.unsqueeze(0)?)?;
499                // (n, n_group)
500                let group_scores = scores_for_choice
501                    .reshape((bs * seq_len, self.cfg.n_group, ()))?
502                    .topk(2)?
503                    .values
504                    .sum(D::Minus1)?;
505                // (n, topk_group)
506                let group_idx = group_scores.topk(self.cfg.topk_group)?.indices;
507                // (n, n_group)
508                let mut group_mask = group_scores.zeros_like()?;
509                // (n, n_group)
510                group_mask = group_mask.scatter_add(
511                    &group_idx,
512                    &group_idx.ones_like()?.to_dtype(group_mask.dtype())?,
513                    1,
514                )?;
515                // (n, e)
516                let score_mask = group_mask
517                    .unsqueeze(D::Minus1)?
518                    .expand((
519                        bs * seq_len,
520                        self.cfg.n_group,
521                        self.n_routed_experts / self.cfg.n_group,
522                    ))?
523                    .reshape((bs * seq_len, ()))?;
524                // (n, e)
525                // Invert the mask
526                let tmp_scores = scores_for_choice.broadcast_mul(&score_mask)?;
527                let topk_idx = tmp_scores.topk(self.top_k)?.indices;
528                (scores.gather(&topk_idx, 1)?, topk_idx)
529            }
530            TopkMethod::GroupLimitedGreedy => {
531                // (n, n_group)
532                let group_scores = scores
533                    .reshape((bs * seq_len, self.cfg.n_group, ()))?
534                    .max(D::Minus1)?;
535                // (n, topk_group)
536                let group_idx = group_scores.topk_unsorted(self.cfg.topk_group)?.indices;
537                // (n, n_group)
538                let mut group_mask = group_scores.zeros_like()?;
539                // (n, n_group)
540                group_mask = group_mask.scatter_add(
541                    &group_idx,
542                    &group_idx.ones_like()?.to_dtype(group_mask.dtype())?,
543                    1,
544                )?;
545                // (n, e)
546                let score_mask = group_mask
547                    .unsqueeze(D::Minus1)?
548                    .expand((
549                        bs * seq_len,
550                        self.cfg.n_group,
551                        self.n_routed_experts / self.cfg.n_group,
552                    ))?
553                    .reshape((bs * seq_len, ()))?;
554                // (n, e)
555                // Invert the mask
556                let tmp_scores = masked_fill(&score_mask, &(1. - &score_mask.ne(0.)?)?, 0.)?;
557                let TopKOutput { values, indices } = tmp_scores.topk_unsorted(self.top_k)?;
558                (values, indices)
559            }
560        };
561
562        if matches!(self.cfg.scoring_func, ScoringFunc::Sigmoid) {
563            let denmoninator = (topk_weight.sum_keepdim(D::Minus1)? + 1e-20)?;
564            topk_weight = topk_weight.broadcast_div(&denmoninator)?;
565        }
566
567        // Must multiply the scaling factor
568        topk_weight = (topk_weight * self.cfg.routed_scaling_factor)?;
569
570        Ok((topk_idx, topk_weight))
571    }
572}
573
/// Mixture-of-experts feed-forward with experts sharded across ranks.
struct Moe {
    // One slot per routed expert; `None` for experts owned by other ranks.
    experts: Vec<Option<Expert>>,
    // Always-active shared experts, fused into a single widened MLP.
    shared_experts: Option<Mlp>,
    gate: MoeGate,
    all_reduce: SumAllReduce,
    // Half-open range [start, end) of expert indices owned by this rank.
    experts_start_idx: usize,
    experts_end_idx: usize,
    world_size: usize,
}
583
impl Moe {
    /// Build a MoE layer: `n_routed_experts` expert MLPs (only the contiguous slice
    /// owned by this rank is materialized), an optional shared-expert MLP, and the
    /// router gate.
    #[allow(clippy::too_many_arguments)]
    fn new(
        cfg: &DeepSeekV3Config,
        vb: ShardedVarBuilder,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        n_shared_experts: Option<usize>,
        n_routed_experts: usize,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let mut experts = Vec::with_capacity(n_routed_experts);
        // Experts are divided evenly across ranks; this rank owns
        // [experts_start_idx, experts_end_idx).
        let n_local_experts = n_routed_experts / comm.world_size();
        let experts_start_idx = comm.rank() * n_local_experts;
        let experts_end_idx = experts_start_idx + n_local_experts;
        for i in 0..n_routed_experts {
            if i >= experts_start_idx && i < experts_end_idx {
                let vb_e = vb.pp("experts").pp(i);
                experts.push(Some(Expert::new(
                    cfg,
                    mapper.set_device(layer_idx, vb_e, loading_isq),
                    None,
                    Some(cfg.moe_intermediate_size),
                )?));
            } else {
                // Owned by another rank; contributions are combined via all-reduce.
                experts.push(None);
            }
        }
        let shared_experts = if let Some(n_shared_experts) = n_shared_experts {
            // Shared experts are fused into one MLP with a widened intermediate dim.
            let intermediate_size = cfg.moe_intermediate_size * n_shared_experts;
            Some(Mlp::new(
                mapper.set_device(layer_idx, vb.pp("shared_experts"), loading_isq),
                cfg.hidden_size,
                intermediate_size,
                &cfg.quantization_config,
                cfg.hidden_act,
                comm,
            )?)
        } else {
            None
        };
        let gate = MoeGate::new(
            cfg,
            mapper.set_device(layer_idx, vb.pp("gate"), false),
            n_routed_experts,
        )?;
        Ok(Self {
            experts,
            shared_experts,
            gate,
            all_reduce: SumAllReduce::new(comm),
            experts_end_idx,
            experts_start_idx,
            world_size: comm.world_size(),
        })
    }

    /// Dispatch flattened tokens (`xs`: one row per token) to the locally-owned
    /// experts selected in `topk_ids`, accumulating their weighted outputs into `y`.
    fn moe_infer(&self, xs: &Tensor, topk_ids: &Tensor, topk_weight: &Tensor) -> Result<Tensor> {
        let mut y = xs.zeros_like()?;
        let topk_weight = if topk_weight.dtype() != xs.dtype() {
            topk_weight.to_dtype(xs.dtype())?
        } else {
            topk_weight.to_owned()
        };
        // Set of experts actually selected for at least one token, so unselected
        // experts are skipped without any tensor work.
        let unique_ids: HashSet<u32> =
            HashSet::from_iter(topk_ids.to_device(&Device::Cpu)?.flatten_all()?.to_vec1()?);
        for i in self.experts_start_idx..self.experts_end_idx {
            if !unique_ids.contains(&(i as u32)) {
                continue;
            }
            // After the transpose: row 0 = token indices routed to expert `i`,
            // row 1 = which top-k slot chose it (used to gather the weight).
            let idx_top = topk_ids.eq(i as f64)?.nonzero()?.t()?;
            let idx = &idx_top.i(0)?.contiguous()?;
            let top = &idx_top.i(1)?.contiguous()?;

            let expert = self.experts[i]
                .as_ref()
                .context("Expert is not present for this rank.")?;

            // y[idx] += expert(xs[idx]) * topk_weight[idx, top]
            y = y.index_add(
                idx,
                &expert.forward(&xs.index_select(idx, 0)?)?.broadcast_mul(
                    &topk_weight
                        .index_select(idx, 0)?
                        .gather(&top.unsqueeze(1)?, 1)?
                        .squeeze(1)?
                        .unsqueeze(D::Minus1)?,
                )?,
                0,
            )?;
        }

        if self.world_size > 1 {
            // Combine partial sums from the experts owned by other ranks.
            y = self.all_reduce.sum_all_reduce(&y)?;
        }

        Ok(y)
    }

    /// Full MoE forward: routed experts plus the optional shared-expert branch.
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let identity = xs.clone();
        let orig_shape = xs.shape();
        let (topk_idx, topk_weight) = self.gate.forward(xs)?;
        // Flatten (bs, seq, hidden) -> (bs * seq, hidden) for per-token dispatch.
        let xs = xs.reshape(((), xs.dim(D::Minus1)?))?;

        let mut y = self
            .moe_infer(&xs, &topk_idx, &topk_weight)?
            .reshape(orig_shape)?;
        if let Some(ref shared_experts) = self.shared_experts {
            y = (y + shared_experts.forward(&identity)?)?;
        }
        Ok(y)
    }
}
698
/// Feed-forward variant of a decoder layer: MoE for routed layers, dense MLP for
/// the first `first_k_dense_replace` layers (and non-MoE frequencies).
enum MoeOrMlp {
    Moe(Moe),
    Mlp(Mlp),
}
703
704impl MoeOrMlp {
705    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
706        match self {
707            Self::Mlp(mlp) => mlp.forward(xs),
708            Self::Moe(moe) => moe.forward(xs),
709        }
710    }
711}
712
/// One pre-norm transformer layer: MLA attention + (MoE or dense) feed-forward.
struct DecoderLayer {
    input_layernorm: RmsNorm,
    post_attention_layernorm: RmsNorm,
    attn: Attention,
    // Dense MLP for early/dense layers, MoE otherwise (see `DecoderLayer::new`).
    moe_or_mlp: MoeOrMlp,
}
719
impl DecoderLayer {
    /// Build one transformer layer. The feed-forward is a MoE block when routed
    /// experts are configured, `layer_idx >= first_k_dense_replace`, and the index
    /// matches `moe_layer_freq`; otherwise it is a dense MLP.
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<DeepSeekV2RotaryEmbedding>,
        cfg: &DeepSeekV3Config,
        vb: ShardedVarBuilder,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        paged_attn: Option<PagedAttention>,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let attn = Attention::new(
            rotary_emb,
            cfg,
            vb.pp("self_attn"),
            mapper,
            layer_idx,
            loading_isq,
            paged_attn,
            comm,
        )?;
        let input_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let post_attention_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        let moe_or_mlp = if cfg.n_routed_experts.is_some()
            && layer_idx >= cfg.first_k_dense_replace
            && layer_idx % cfg.moe_layer_freq == 0
        {
            MoeOrMlp::Moe(Moe::new(
                cfg,
                vb.pp("mlp"),
                mapper,
                layer_idx,
                loading_isq,
                cfg.n_shared_experts,
                cfg.n_routed_experts.unwrap(),
                comm,
            )?)
        } else {
            MoeOrMlp::Mlp(Mlp::new(
                mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq),
                cfg.hidden_size,
                cfg.intermediate_size,
                &cfg.quantization_config,
                cfg.hidden_act,
                comm,
            )?)
        };

        Ok(Self {
            input_layernorm,
            post_attention_layernorm,
            attn,
            moe_or_mlp,
        })
    }

    /// Pre-norm residual layout: `x + attn(norm(x))`, then `x + ffn(norm(x))`.
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = self.input_layernorm.forward(xs)?;
        let xs = self.attn.forward(
            &xs,
            attention_mask,
            seqlen_offsets,
            kv_cache,
            metadata,
            flash_params,
        )?;
        let xs = (xs + residual)?;
        let residual = &xs;
        let xs = self
            .moe_or_mlp
            .forward(&xs.apply(&self.post_attention_layernorm)?)?;
        residual + xs
    }
}
812
/// DeepSeek V3 causal language model: MLA attention + MoE feed-forward layers.
pub struct DeepSeekV3 {
    lm_head: Arc<dyn QuantMethod>,
    embed_tokens: Embedding,
    // Final RmsNorm applied before the LM head.
    norm: RmsNorm,
    layers: Vec<DecoderLayer>,
    cache: EitherCache,
    device: Device,
    max_seq_len: usize,
    cfg: ModelConfigMetadata,
    // Maps layers to devices for multi-GPU / CPU-offload placement.
    mapper: Box<dyn DeviceMapper + Send + Sync>,
}
824
825impl DeepSeekV3 {
826    pub fn new(
827        cfg: &DeepSeekV3Config,
828        vb: ShardedVarBuilder,
829        _is_gptx: bool,
830        normal_loading_metadata: NormalLoadingMetadata,
831        attention_mechanism: AttentionImplementation,
832    ) -> Result<Self> {
833        let vb_m = vb.pp("model");
834
835        let mapper = normal_loading_metadata.mapper;
836
837        let embed_tokens = embedding(
838            cfg.vocab_size,
839            cfg.hidden_size,
840            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
841            &cfg.quantization_config,
842        )?;
843        let lm_head = if !cfg.tie_word_embeddings {
844            ReplicatedLayer::new(
845                cfg.hidden_size,
846                cfg.vocab_size,
847                &None,
848                false,
849                mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
850            )?
851        } else {
852            ReplicatedLayer::from_linear(candle_nn::Linear::new(
853                mapper.cast_nm_device(
854                    embed_tokens.embeddings(),
855                    normal_loading_metadata.loading_isq,
856                )?,
857                None,
858            ))?
859        };
860        let norm = RmsNorm::new(
861            cfg.hidden_size,
862            cfg.rms_norm_eps,
863            mapper.set_nm_device(vb_m.pp("norm"), false),
864        )?;
865
866        let mut ropes = HashMap::new();
867        let rope_cfg = DeepSeekV2RopeConfig {
868            rope_scaling: cfg.rope_scaling.clone(),
869            max_position_embeddings: cfg.max_position_embeddings,
870            rope_theta: cfg.rope_theta,
871            qk_rope_head_dim: cfg.qk_rope_head_dim,
872        };
873        for i in 0..cfg.num_hidden_layers {
874            let device = mapper
875                .device_for(i, false)
876                .unwrap_or(&normal_loading_metadata.real_device);
877            ropes.insert(
878                device.location(),
879                Arc::new(DeepSeekV2RotaryEmbedding::new(
880                    &rope_cfg,
881                    vb.dtype(),
882                    device,
883                )?),
884            );
885        }
886
887        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
888        let vb_l = vb_m.pp("layers");
889        for layer_idx in NiceProgressBar::<_, 'b'>(
890            0..cfg.num_hidden_layers,
891            "Loading repeating layers",
892            &normal_loading_metadata.multi_progress,
893        ) {
894            let device = mapper
895                .device_for(layer_idx, false)
896                .unwrap_or(&normal_loading_metadata.real_device);
897            let rotary_emb = ropes
898                .get(&device.location())
899                .expect("No RoPE for device location!")
900                .clone();
901            let paged_attn = match &attention_mechanism {
902                AttentionImplementation::Eager => None,
903                AttentionImplementation::PagedAttention => Some(
904                    PagedAttention::new(cfg.v_head_dim, device, None)
905                        .expect("Failed to create PagedAttention"),
906                ),
907            };
908            let comm = mapper.get_comm_for(layer_idx)?;
909            let layer = DecoderLayer::new(
910                rotary_emb.clone(),
911                cfg,
912                vb_l.pp(layer_idx),
913                &*mapper,
914                layer_idx,
915                normal_loading_metadata.loading_isq,
916                paged_attn,
917                &comm,
918            )?;
919            layers.push(layer)
920        }
921
922        Ok(Self {
923            lm_head,
924            embed_tokens,
925            norm,
926            layers,
927            cache: EitherCache::Normal(NormalCache::new(
928                cfg.num_hidden_layers,
929                cfg.max_position_embeddings,
930            )),
931            device: normal_loading_metadata.real_device.clone(),
932            max_seq_len: cfg.max_position_embeddings,
933            cfg: ModelConfigMetadata {
934                max_seq_len: cfg.max_position_embeddings,
935                num_layers: cfg.num_hidden_layers,
936                hidden_size: cfg.hidden_size,
937                num_kv_heads: (cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size())
938                    .max(1),
939                num_attn_heads: (cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size())
940                    .max(1),
941                sliding_window: None,
942                k_head_dim: cfg.q_head_dim(),
943                v_head_dim: if matches!(
944                    attention_mechanism,
945                    AttentionImplementation::PagedAttention
946                ) {
947                    cfg.q_head_dim()
948                } else {
949                    cfg.v_head_dim
950                },
951            },
952            mapper,
953        })
954    }
955
956    #[allow(clippy::too_many_arguments)]
957    pub fn forward(
958        &self,
959        input_ids: &Tensor,
960        seqlen_offsets: &[usize],
961        context_lens: Vec<(usize, usize)>,
962        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
963        flash_params: &FlashParams,
964    ) -> Result<Tensor> {
965        let mut xs = self.embed_tokens.forward(input_ids)?;
966        let cache = &mut self.cache.normal().0;
967        let attention_mask = CausalMasker.make_causal_mask_matrix(
968            input_ids,
969            metadata
970                .as_ref()
971                .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
972                .unwrap_or(cache as &dyn PastKvLenCache),
973            xs.dtype(),
974            self.cfg.num_attn_heads,
975        )?;
976        // PagedAttention prompt chunking
977        let attention_mask = attention_mask.filter(|_| {
978            metadata
979                .as_ref()
980                .map(|(_, meta)| meta.is_first_prompt_chunk)
981                .unwrap_or(true)
982        });
983        for (i, layer) in self.layers.iter().enumerate() {
984            xs = self.mapper.map(xs, i)?;
985            xs = layer.forward(
986                &xs,
987                attention_mask
988                    .as_ref()
989                    .map(|m| m.to_device(xs.device()).unwrap())
990                    .as_ref(),
991                seqlen_offsets,
992                &mut cache[i],
993                metadata
994                    .as_ref()
995                    .map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
996                flash_params,
997            )?;
998        }
999        let xs = xs.to_device(&self.device)?;
1000        let xs = xs.apply(&self.norm)?;
1001        extract_logits(&self.lm_head.forward_autocast(&xs)?, context_lens)
1002    }
1003}
1004
impl IsqModel for DeepSeekV3 {
    /// Collect mutable references to every quantizable linear layer, paired
    /// with its decoder-layer index (`None` for the LM head), plus the device
    /// mapper. NOTE(review): push order appears significant for ISQ indexing
    /// downstream — do not reorder without checking consumers.
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        // The LM head lives outside the decoder stack, hence index `None`.
        tensors.push((&mut self.lm_head, None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            match &mut layer.attn.q {
                QProj::Plain(q) => {
                    tensors.push((q, Some(i)));
                }
                // Low-rank q projection: both factors are quantizable; the
                // norm between them is not a linear layer, so it is skipped.
                QProj::Lora { a, norm: _, b } => {
                    tensors.push((a, Some(i)));
                    tensors.push((b, Some(i)));
                }
            }
            tensors.push((&mut layer.attn.kv_a_proj_with_mqa, Some(i)));
            tensors.push((&mut layer.attn.kv_b_proj, Some(i)));
            tensors.push((&mut layer.attn.o_proj, Some(i)));
            match &mut layer.moe_or_mlp {
                MoeOrMlp::Mlp(mlp) => {
                    tensors.push((&mut mlp.gate, Some(i)));
                    tensors.push((&mut mlp.up, Some(i)));
                    tensors.push((&mut mlp.down, Some(i)));
                }
                MoeOrMlp::Moe(moe) => {
                    // Experts stored as Option are skipped when absent —
                    // presumably distributed across ranks; TODO confirm.
                    for mlp in moe.experts.iter_mut().filter_map(|e| e.as_mut()) {
                        tensors.push((&mut mlp.gate, Some(i)));
                        tensors.push((&mut mlp.up, Some(i)));
                        tensors.push((&mut mlp.down, Some(i)));
                    }
                    if let Some(mlp) = &mut moe.shared_experts {
                        tensors.push((&mut mlp.gate, Some(i)));
                        tensors.push((&mut mlp.up, Some(i)));
                        tensors.push((&mut mlp.down, Some(i)));
                    }
                }
            }
        }
        (tensors, &*self.mapper)
    }

    /// Like [`Self::get_layers`], but restricted to the MLP/MoE expert
    /// projections (plus the LM head) — attention projections are excluded.
    fn get_layers_moe_experts_only(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((&mut self.lm_head, None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            match &mut layer.moe_or_mlp {
                MoeOrMlp::Mlp(mlp) => {
                    tensors.push((&mut mlp.gate, Some(i)));
                    tensors.push((&mut mlp.up, Some(i)));
                    tensors.push((&mut mlp.down, Some(i)));
                }
                MoeOrMlp::Moe(moe) => {
                    for mlp in moe.experts.iter_mut().filter_map(|e| e.as_mut()) {
                        tensors.push((&mut mlp.gate, Some(i)));
                        tensors.push((&mut mlp.up, Some(i)));
                        tensors.push((&mut mlp.down, Some(i)));
                    }
                    if let Some(mlp) = &mut moe.shared_experts {
                        tensors.push((&mut mlp.gate, Some(i)));
                        tensors.push((&mut mlp.up, Some(i)));
                        tensors.push((&mut mlp.down, Some(i)));
                    }
                }
            }
        }
        (tensors, &*self.mapper)
    }

    /// Tensors that are NOT quantized by ISQ and must be carried over
    /// verbatim when serializing: embeddings, norms, and MoE router gates.
    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        let uvb_m = uvb.pp("model");
        uvb_m.pp("embed_tokens").add(&self.embed_tokens);
        uvb_m.pp("norm").add(&self.norm);

        for (layer_idx, layer) in self.layers.iter().enumerate() {
            let uvb_l = uvb_m.pp("layers").pp(layer_idx);
            uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
            uvb_l
                .pp("post_attention_layernorm")
                .add(&layer.post_attention_layernorm);

            uvb_l
                .pp("self_attn")
                .pp("kv_a_layernorm")
                .add(&layer.attn.kv_a_layernorm);

            // The MoE router gate weight stays in full precision.
            match &layer.moe_or_mlp {
                MoeOrMlp::Moe(moe) => {
                    uvb_l
                        .pp("mlp")
                        .pp("gate")
                        .add_tensor("weight", moe.gate.weight.clone());
                }
                MoeOrMlp::Mlp(_) => (),
            }

            match &layer.attn.q {
                QProj::Plain(_) => (),
                QProj::Lora { a: _, norm, b: _ } => {
                    uvb_l.pp("self_attn").pp("q_a_layernorm").add(norm);
                }
            }
        }

        uvb.to_safetensors()
    }

    /// Residual tensors for experts-only ISQ: same as [`Self::residual_tensors`]
    /// but additionally includes all attention projections, since those are
    /// left unquantized in this mode.
    fn residual_tensors_moe_experts_only(&self) -> Option<Vec<(String, Tensor)>> {
        let uvb = UnVarBuilder::new();

        let uvb_m = uvb.pp("model");
        uvb_m.pp("embed_tokens").add(&self.embed_tokens);
        uvb_m.pp("norm").add(&self.norm);

        for (layer_idx, layer) in self.layers.iter().enumerate() {
            let uvb_l = uvb_m.pp("layers").pp(layer_idx);
            uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
            uvb_l
                .pp("post_attention_layernorm")
                .add(&layer.post_attention_layernorm);

            uvb_l
                .pp("self_attn")
                .pp("kv_a_layernorm")
                .add(&layer.attn.kv_a_layernorm);

            match &layer.moe_or_mlp {
                MoeOrMlp::Moe(moe) => {
                    uvb_l
                        .pp("mlp")
                        .pp("gate")
                        .add_tensor("weight", moe.gate.weight.clone());
                }
                MoeOrMlp::Mlp(_) => (),
            }

            // Attention projections are residual (unquantized) in this mode.
            match &layer.attn.q {
                QProj::Plain(q) => {
                    uvb_l.pp("self_attn").pp("q_proj").add(q);
                }
                QProj::Lora { a, norm, b } => {
                    uvb_l.pp("self_attn").pp("q_a_proj").add(a);
                    uvb_l.pp("self_attn").pp("q_a_layernorm").add(norm);
                    uvb_l.pp("self_attn").pp("q_b_proj").add(b);
                }
            }
            uvb_l
                .pp("self_attn")
                .pp("kv_a_proj_with_mqa")
                .add(&layer.attn.kv_a_proj_with_mqa);
            uvb_l
                .pp("self_attn")
                .pp("kv_b_proj")
                .add(&layer.attn.kv_b_proj);
            uvb_l.pp("self_attn").pp("o_proj").add(&layer.attn.o_proj);
        }

        Some(uvb.to_safetensors())
    }
}
1175
1176impl NormalModel for DeepSeekV3 {
1177    fn forward(
1178        &self,
1179        input_ids: &Tensor,
1180        seqlen_offsets: &[usize],
1181        context_lens: Vec<(usize, usize)>,
1182        _position_ids: Vec<usize>,
1183        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
1184        flash_params: &FlashParams,
1185    ) -> Result<Tensor> {
1186        self.forward(
1187            input_ids,
1188            seqlen_offsets,
1189            context_lens,
1190            metadata,
1191            flash_params,
1192        )
1193    }
1194    fn xlora_forward(
1195        &self,
1196        _input_ids: &Tensor,
1197        _input_ids_full: &Tensor,
1198        _seqlen_offsets: &[usize],
1199        _seqlen_offsets_full: &[usize],
1200        _no_kv_cache: bool,
1201        _non_granular_state: &Option<crate::xlora_models::NonGranularState>,
1202        _context_lens: Vec<(usize, usize)>,
1203        _position_ids: Vec<usize>,
1204        _flash_params: &FlashParams,
1205        _flash_params_full: &FlashParams,
1206    ) -> Result<Tensor> {
1207        unimplemented!()
1208    }
1209    fn cache(&self) -> &EitherCache {
1210        &self.cache
1211    }
1212    fn cache_mut(&mut self) -> &mut EitherCache {
1213        &mut self.cache
1214    }
1215    fn device(&self) -> &Device {
1216        &self.device
1217    }
1218    fn is_xlora(&self) -> bool {
1219        false
1220    }
1221    fn max_seq_len(&self) -> usize {
1222        self.max_seq_len
1223    }
1224    fn config(&self) -> &ModelConfigMetadata {
1225        &self.cfg
1226    }
1227}
1228
1229impl AnyMoeBaseModelMixin for DeepSeekV3 {}