// mistralrs_core/models/deepseek3.rs

1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use std::{collections::HashMap, sync::Arc};
4
5use candle_core::{Context, DType, Device, IndexOp, Result, Tensor, D};
6use candle_nn::{Embedding, Module};
7use mistralrs_quant::{
8    distributed::DistributedOperation, ColumnParallelLayer, QuantMethod, QuantizedConfig,
9    ReplicatedLayer, RowParallelLayer, ShardedVarBuilder, SumAllReduce,
10};
11use serde::Deserialize;
12
13use crate::{
14    amoe::AnyMoeBaseModelMixin,
15    attention::SdpaParams,
16    device_map::DeviceMapper,
17    layers::{
18        embedding, Activation, CausalMasker, DeepSeekV2RopeConfig, DeepSeekV2RopeScaling,
19        DeepSeekV2RotaryEmbedding, Mlp, RmsNorm, Sdpa,
20    },
21    layers_masker::{masked_fill, PastKvLenCache},
22    ops::{BincountOp, NonZeroOp, SplitOp, TopKLastDimOp, TopKOutput},
23    paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
24    pipeline::{
25        extract_logits,
26        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
27        EitherCache, IsqModel, KvCache, NormalCache, NormalLoadingMetadata, NormalModel,
28    },
29    serde_default_fn,
30    utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
31};
32
// Serde default helpers: each invocation expands to a free function (named
// after the field) that the `#[serde(default = "...")]` attributes on
// `DeepSeekV3Config` point at when the key is absent from config.json.
serde_default_fn!(f64, routed_scaling_factor, 1.0);
serde_default_fn!(TopkMethod, topk_method, TopkMethod::Greedy);
serde_default_fn!(usize, moe_layer_freq, 1);
serde_default_fn!(usize, first_k_dense_replace, 0);
serde_default_fn!(ScoringFunc, scoring_func, ScoringFunc::Softmax);
serde_default_fn!(Activation, hidden_act, Activation::Silu);
serde_default_fn!(bool, tie_word_embeddings, false);
serde_default_fn!(bool, use_flash_attn_default, false);
41
/// Expert-selection strategy used by the MoE gate (config `topk_method`).
#[derive(Deserialize, Clone, Debug)]
enum TopkMethod {
    /// Aux-loss-free top-k: group-limited selection using a learned
    /// per-expert score-correction bias (`e_score_correction_bias`).
    #[serde(rename = "noaux_tc")]
    NoAuxTc,
    /// Plain per-token top-k over all routed experts.
    #[serde(rename = "greedy")]
    Greedy,
    /// Top-k restricted to the best-scoring expert groups.
    #[serde(rename = "group_limited_greedy")]
    GroupLimitedGreedy,
}
51
/// Scoring function applied to the gate logits before expert selection
/// (config `scoring_func`). Sigmoid scores are renormalized after top-k.
#[derive(Deserialize, Clone, Debug)]
enum ScoringFunc {
    #[serde(rename = "softmax")]
    Softmax,
    #[serde(rename = "sigmoid")]
    Sigmoid,
}
59
/// DeepSeek-V3 model configuration, deserialized from the HuggingFace
/// `config.json`. Field names mirror the upstream Python configuration.
#[derive(Deserialize, Clone, Debug)]
pub struct DeepSeekV3Config {
    pub(crate) vocab_size: usize,
    pub(crate) hidden_size: usize,
    /// Intermediate size of the dense (non-MoE) MLP layers.
    pub(crate) intermediate_size: usize,
    /// Intermediate size of each routed expert; shared experts use
    /// `moe_intermediate_size * n_shared_experts`.
    pub(crate) moe_intermediate_size: usize,
    pub(crate) num_hidden_layers: usize,
    pub(crate) num_attention_heads: usize,
    /// Number of always-active shared experts; `None` disables the shared MLP.
    pub(crate) n_shared_experts: Option<usize>,
    /// Number of routed experts; `None` makes every layer a dense MLP.
    pub(crate) n_routed_experts: Option<usize>,
    /// Scale multiplied into the selected top-k gate weights.
    #[serde(default = "routed_scaling_factor")]
    pub(crate) routed_scaling_factor: f64,
    /// Expert-selection strategy for the MoE gate.
    #[serde(default = "topk_method")]
    topk_method: TopkMethod,
    /// Experts activated per token (the gate's top-k).
    pub(crate) num_experts_per_tok: Option<usize>,
    /// A layer uses MoE only when `layer_idx % moe_layer_freq == 0`
    /// (in addition to the `first_k_dense_replace` condition).
    #[serde(default = "moe_layer_freq")]
    pub(crate) moe_layer_freq: usize,
    /// The first `first_k_dense_replace` layers always use a dense MLP.
    #[serde(default = "first_k_dense_replace")]
    pub(crate) first_k_dense_replace: usize,
    #[serde(default = "scoring_func")]
    scoring_func: ScoringFunc,
    #[serde(default = "hidden_act")]
    pub(crate) hidden_act: Activation,
    pub(crate) max_position_embeddings: usize,
    pub(crate) rms_norm_eps: f64,
    /// When true, the LM head reuses the token-embedding weights.
    #[serde(default = "tie_word_embeddings")]
    pub(crate) tie_word_embeddings: bool,
    pub(crate) rope_theta: f32,
    pub(crate) rope_scaling: Option<DeepSeekV2RopeScaling>,
    pub(crate) attention_bias: bool,
    /// When `Some`, the query projection is low-rank factored
    /// (`q_a_proj` -> layernorm -> `q_b_proj`).
    pub(crate) q_lora_rank: Option<usize>,
    /// Rotary (position-encoded) part of each attention head.
    pub(crate) qk_rope_head_dim: usize,
    /// Rank of the compressed KV latent.
    pub(crate) kv_lora_rank: usize,
    /// Per-head value dimension.
    pub(crate) v_head_dim: usize,
    /// Non-rotary part of each attention head.
    pub(crate) qk_nope_head_dim: usize,
    #[serde(default = "use_flash_attn_default")]
    pub(crate) use_flash_attn: bool,
    pub(crate) quantization_config: Option<QuantizedConfig>,
    /// Routed experts are partitioned into `n_group` groups for
    /// group-limited routing.
    pub(crate) n_group: usize,
    /// Number of expert groups each token may route into.
    pub(crate) topk_group: usize,
}
101
102impl DeepSeekV3Config {
103    pub(crate) fn q_head_dim(&self) -> usize {
104        self.qk_rope_head_dim + self.qk_nope_head_dim
105    }
106
107    fn softmax_scale(&self) -> f32 {
108        let mut softmax_scale = 1.0 / (self.q_head_dim() as f32).sqrt();
109        if let Some(DeepSeekV2RopeScaling::Yarn {
110            mscale_all_dim,
111            factor,
112            ..
113        }) = self.rope_scaling
114        {
115            let mscale = DeepSeekV2RotaryEmbedding::yarn_get_mscale(factor, mscale_all_dim);
116            softmax_scale = softmax_scale * mscale * mscale;
117        }
118        softmax_scale
119    }
120}
121
/// Query projection: either a single linear layer, or the low-rank
/// factorization `b(norm(a(x)))` used when `q_lora_rank` is set.
enum QProj {
    Plain(Arc<dyn QuantMethod>),
    Lora {
        a: Arc<dyn QuantMethod>,
        norm: RmsNorm,
        b: Arc<dyn QuantMethod>,
    },
}
130
131impl QProj {
132    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
133        match self {
134            Self::Lora { a, norm, b } => {
135                b.forward_autocast(&norm.forward(&a.forward_autocast(xs)?)?)
136            }
137            Self::Plain(lin) => lin.forward_autocast(xs),
138        }
139    }
140}
141
/// Multi-head latent attention (MLA) block: the query is optionally low-rank
/// factored, and K/V are compressed by `kv_a_proj_with_mqa` into a
/// `kv_lora_rank + qk_rope_head_dim` latent and re-expanded by `kv_b_proj`.
struct Attention {
    q: QProj,
    kv_a_proj_with_mqa: Arc<dyn QuantMethod>,
    kv_a_layernorm: RmsNorm,
    kv_b_proj: Arc<dyn QuantMethod>,
    o_proj: Arc<dyn QuantMethod>,
    rotary_emb: Arc<DeepSeekV2RotaryEmbedding>,
    cfg: DeepSeekV3Config,
    // Per-head query dim = qk_nope_head_dim + qk_rope_head_dim.
    q_head_dim: usize,
    paged_attn: Option<PagedAttention>,
    sdpa_params: SdpaParams,
    // Heads handled by this rank: total heads / tensor-parallel world size.
    num_attention_heads: usize,
}
155
impl Attention {
    /// Load the MLA attention block for `layer_idx` from `vb`.
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<DeepSeekV2RotaryEmbedding>,
        cfg: &DeepSeekV3Config,
        vb: ShardedVarBuilder,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        paged_attn: Option<PagedAttention>,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let q_head_dim = cfg.q_head_dim();
        // Query projection: low-rank (q_a -> norm -> q_b) when q_lora_rank is
        // set, otherwise a single column-parallel linear.
        let q = match cfg.q_lora_rank {
            Some(lora_rank) => {
                let a = ReplicatedLayer::new(
                    cfg.hidden_size,
                    lora_rank,
                    &cfg.quantization_config,
                    cfg.attention_bias,
                    mapper.set_device(layer_idx, vb.pp("q_a_proj"), loading_isq),
                )?;
                let norm = RmsNorm::new(
                    lora_rank,
                    cfg.rms_norm_eps,
                    mapper.set_device(layer_idx, vb.pp("q_a_layernorm"), false),
                )?;
                let b = ColumnParallelLayer::new(
                    lora_rank,
                    cfg.num_attention_heads * q_head_dim,
                    &cfg.quantization_config,
                    false,
                    comm,
                    mapper.set_device(layer_idx, vb.pp("q_b_proj"), loading_isq),
                )?;
                QProj::Lora { a, norm, b }
            }
            None => QProj::Plain(ColumnParallelLayer::new(
                cfg.hidden_size,
                cfg.num_attention_heads * q_head_dim,
                &cfg.quantization_config,
                false,
                comm,
                mapper.set_device(layer_idx, vb.pp("q_proj"), loading_isq),
            )?),
        };

        // Projects hidden states down to the compressed KV latent plus the
        // single shared rotary key part (kv_lora_rank + qk_rope_head_dim).
        let kv_a_proj_with_mqa = ReplicatedLayer::new(
            cfg.hidden_size,
            cfg.kv_lora_rank + cfg.qk_rope_head_dim,
            &cfg.quantization_config,
            cfg.attention_bias,
            mapper.set_device(layer_idx, vb.pp("kv_a_proj_with_mqa"), loading_isq),
        )?;
        let kv_a_layernorm = RmsNorm::new(
            cfg.kv_lora_rank,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("kv_a_layernorm"), false),
        )?;
        // Expands the latent to per-head non-rotary K (qk_nope_head_dim, i.e.
        // q_head_dim - qk_rope_head_dim) plus per-head V (v_head_dim).
        let kv_b_proj = ColumnParallelLayer::new(
            cfg.kv_lora_rank,
            cfg.num_attention_heads * (q_head_dim - cfg.qk_rope_head_dim + cfg.v_head_dim),
            &cfg.quantization_config,
            false,
            comm,
            mapper.set_device(layer_idx, vb.pp("kv_b_proj"), loading_isq),
        )?;

        let o_proj = RowParallelLayer::new(
            cfg.num_attention_heads * cfg.v_head_dim,
            cfg.hidden_size,
            &cfg.quantization_config,
            cfg.attention_bias,
            comm,
            mapper.set_device(layer_idx, vb.pp("o_proj"), loading_isq),
        )?;

        Ok(Self {
            q,
            kv_a_proj_with_mqa,
            kv_a_layernorm,
            kv_b_proj,
            o_proj,
            rotary_emb,
            cfg: cfg.clone(),
            q_head_dim,
            paged_attn,
            // Heads are split evenly across tensor-parallel ranks.
            num_attention_heads: cfg.num_attention_heads / comm.world_size(),
            sdpa_params: SdpaParams {
                // K is repeated to the full local head count in forward(),
                // so there are no GQA groups at the SDPA level.
                n_kv_groups: 1,
                use_flash_attn: cfg.use_flash_attn,
                softcap: None,
                softmax_scale: cfg.softmax_scale(),
                sliding_window: None,
            },
        })
    }

    /// One attention pass over `xs` of shape `(bs, seq_len, hidden)`.
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (bs, seq_len, _) = xs.dims3()?;

        // Project and split queries into non-rotary / rotary parts:
        // (bs, heads, seq, q_head_dim) -> nope + pe.
        let mut q = self.q.forward(xs)?;
        q = q
            .reshape((bs, seq_len, self.num_attention_heads, self.q_head_dim))?
            .transpose(1, 2)?;
        let q_split = q.split(
            &[self.cfg.qk_nope_head_dim, self.cfg.qk_rope_head_dim],
            D::Minus1,
        )?;
        let q_nope = q_split[0].clone();
        let mut q_pe = q_split[1].clone();

        // Compress hidden states to the KV latent and the shared rotary key.
        let mut compressed_kv = self.kv_a_proj_with_mqa.forward_autocast(xs)?;
        let ckv_split = compressed_kv.split(
            &[self.cfg.kv_lora_rank, self.cfg.qk_rope_head_dim],
            D::Minus1,
        )?;
        compressed_kv = ckv_split[0].clone();
        let mut k_pe = ckv_split[1].clone();
        // The rotary key is a single "head" shared by all query heads (MQA).
        k_pe = k_pe
            .reshape((bs, seq_len, 1, self.cfg.qk_rope_head_dim))?
            .transpose(1, 2)?;
        // Re-expand the normalized latent to per-head K (nope) and V.
        let mut kv = self
            .kv_b_proj
            .forward_autocast(&self.kv_a_layernorm.forward(&compressed_kv)?)?;
        kv = kv
            .reshape((
                bs,
                seq_len,
                self.num_attention_heads,
                self.cfg.qk_nope_head_dim + self.cfg.v_head_dim,
            ))?
            .transpose(1, 2)?;

        let kv_split = kv.split(&[self.cfg.qk_nope_head_dim, self.cfg.v_head_dim], D::Minus1)?;
        let k_nope = kv_split[0].clone();
        let mut v = kv_split[1].clone();

        // Apply RoPE only to the rotary parts.
        (q_pe, k_pe) = self.rotary_emb.forward(&q_pe, &k_pe, seqlen_offsets)?;

        // Reassemble full-dim Q and K; the shared rotary key is broadcast to
        // every local head.
        let q = Tensor::cat(&[&q_nope, &q_pe], D::Minus1)?.contiguous()?;
        let mut k = Tensor::cat(
            &[&k_nope, &k_pe.repeat((1, self.num_attention_heads, 1, 1))?],
            D::Minus1,
        )?
        .contiguous()?;

        let mut attn_out = match &self.paged_attn {
            Some(paged_attn) => match metadata {
                Some(((key_cache, value_cache), input_metadata)) => {
                    // Paged KV cache stores q_head_dim-sized values; pad V up
                    // and narrow the output back to v_head_dim afterwards.
                    let v = v
                        .pad_with_zeros(D::Minus1, 0, self.q_head_dim - self.cfg.v_head_dim)?
                        .contiguous()?;
                    paged_attn
                        .forward(
                            &q,
                            &k,
                            &v,
                            attention_mask,
                            Some(key_cache),
                            Some(value_cache),
                            input_metadata,
                            &self.sdpa_params,
                            Some(flash_params),
                        )?
                        .narrow(D::Minus1, 0, self.cfg.v_head_dim)?
                }
                None => {
                    // If we don't have metadata, we are most likely generating an imatrix so we don't want to populate that.
                    // Generating the dummy metadata with the assumption that we are not generating text (only processing prompts).
                    let input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
                    // Sanity check.
                    assert!(attention_mask.is_some());
                    let v = v
                        .pad_with_zeros(D::Minus1, 0, self.q_head_dim - self.cfg.v_head_dim)?
                        .contiguous()?;
                    paged_attn
                        .forward(
                            &q,
                            &k,
                            &v,
                            attention_mask,
                            None,
                            None,
                            &input_metadata,
                            &self.sdpa_params,
                            Some(flash_params),
                        )?
                        .narrow(D::Minus1, 0, self.cfg.v_head_dim)?
                }
            },
            None => {
                // Eager path: append to the normal KV cache, then SDPA.
                (k, v) = kv_cache.append(&k, &v)?;

                Sdpa.run_attention(
                    &q,
                    &k,
                    &v,
                    attention_mask,
                    Some(flash_params),
                    &self.sdpa_params,
                )?
            }
        };

        // NOTE(review): output layout differs between the masked (prompt) and
        // mask-free (decode) paths — the masked path is (bs, heads, seq, v_dim)
        // and needs the transpose; presumably the mask-free path is already
        // token-major. Confirm against Sdpa/PagedAttention conventions.
        attn_out = if attention_mask.is_some() {
            attn_out.transpose(1, 2)?.reshape((bs, seq_len, ()))?
        } else {
            attn_out.reshape((bs, seq_len, ()))?
        };

        self.o_proj.forward_autocast(&attn_out)
    }
}
378
/// One routed MoE expert: a gated MLP computing `down(act(gate(x)) * up(x))`.
struct Expert {
    gate: Arc<dyn QuantMethod>,
    up: Arc<dyn QuantMethod>,
    down: Arc<dyn QuantMethod>,
    act: Activation,
}
385
386impl Expert {
387    fn new(
388        cfg: &DeepSeekV3Config,
389        vb: ShardedVarBuilder,
390        hidden_size: Option<usize>,
391        intermediate_size: Option<usize>,
392    ) -> Result<Self> {
393        let hidden_size = hidden_size.unwrap_or(cfg.hidden_size);
394        let intermediate_size = intermediate_size.unwrap_or(cfg.intermediate_size);
395
396        Ok(Self {
397            gate: ReplicatedLayer::new(
398                hidden_size,
399                intermediate_size,
400                &cfg.quantization_config,
401                false,
402                vb.pp("gate_proj"),
403            )?,
404            up: ReplicatedLayer::new(
405                hidden_size,
406                intermediate_size,
407                &cfg.quantization_config,
408                false,
409                vb.pp("up_proj"),
410            )?,
411            down: ReplicatedLayer::new(
412                intermediate_size,
413                hidden_size,
414                &cfg.quantization_config,
415                false,
416                vb.pp("down_proj"),
417            )?,
418            act: cfg.hidden_act,
419        })
420    }
421
422    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
423        let original_dtype = xs.dtype();
424        let mut xs = xs.clone();
425        if let Some(t) = self.gate.quantized_act_type() {
426            xs = xs.to_dtype(t)?;
427        }
428        let lhs = self.gate.forward(&xs)?;
429        let rhs = self.up.forward(&xs)?;
430        let mut res = self.down.forward(&candle_nn::ops::mul_and_act(
431            &lhs,
432            &rhs,
433            self.act.try_into()?,
434        )?)?;
435        if self.gate.quantized_act_type().is_some() {
436            res = res.to_dtype(original_dtype)?;
437        }
438        Ok(res)
439    }
440}
441
/// Router for the MoE layer: scores every routed expert per token and picks
/// the top-k, optionally with group-limited selection.
struct MoeGate {
    // Router weight, shape (n_routed_experts, hidden_size).
    weight: Tensor,
    cfg: DeepSeekV3Config,
    // Experts selected per token (num_experts_per_tok).
    top_k: usize,
    n_routed_experts: usize,
    // Only present for the `noaux_tc` top-k method; kept in f32.
    e_score_correction_bias: Option<Tensor>,
}
449
450impl MoeGate {
451    fn new(cfg: &DeepSeekV3Config, vb: ShardedVarBuilder, n_routed_experts: usize) -> Result<Self> {
452        let weight = vb.get((n_routed_experts, cfg.hidden_size), "weight")?;
453        let e_score_correction_bias = if matches!(cfg.topk_method, TopkMethod::NoAuxTc) {
454            Some(vb.get_with_hints_dtype(
455                n_routed_experts,
456                "e_score_correction_bias",
457                Default::default(),
458                DType::F32,
459            )?)
460        } else {
461            None
462        };
463        Ok(Self {
464            weight,
465            cfg: cfg.clone(),
466            top_k: cfg.num_experts_per_tok.unwrap(),
467            n_routed_experts,
468            e_score_correction_bias,
469        })
470    }
471
472    /// (topk_idx, topk_weight)
473    fn forward(&self, xs: &Tensor) -> Result<(Tensor, Tensor)> {
474        let (bs, seq_len, h) = xs.dims3()?;
475        // Compute gating score
476        let xs = xs.reshape(((), h))?;
477        let logits = xs
478            .to_dtype(DType::F32)?
479            .broadcast_matmul(&self.weight.t()?.to_dtype(DType::F32)?)?;
480        let scores = match self.cfg.scoring_func {
481            ScoringFunc::Softmax => candle_nn::ops::softmax_last_dim(&logits)?,
482            ScoringFunc::Sigmoid => candle_nn::ops::sigmoid(&logits)?,
483        };
484
485        // Select top-k experts
486        let (mut topk_weight, topk_idx) = match self.cfg.topk_method {
487            TopkMethod::Greedy => {
488                let TopKOutput { values, indices } = scores.topk_unsorted(self.top_k)?;
489                (values, indices)
490            }
491            TopkMethod::NoAuxTc => {
492                let Some(e_score_correction_bias) = &self.e_score_correction_bias else {
493                    candle_core::bail!("Expected e_score_correction_bias")
494                };
495                let scores_for_choice = scores
496                    .reshape((bs * seq_len, ()))?
497                    .broadcast_add(&e_score_correction_bias.unsqueeze(0)?)?;
498                // (n, n_group)
499                let group_scores = scores_for_choice
500                    .reshape((bs * seq_len, self.cfg.n_group, ()))?
501                    .topk(2)?
502                    .values
503                    .sum(D::Minus1)?;
504                // (n, topk_group)
505                let group_idx = group_scores.topk(self.cfg.topk_group)?.indices;
506                // (n, n_group)
507                let mut group_mask = group_scores.zeros_like()?;
508                // (n, n_group)
509                group_mask = group_mask.scatter_add(
510                    &group_idx,
511                    &group_idx.ones_like()?.to_dtype(group_mask.dtype())?,
512                    1,
513                )?;
514                // (n, e)
515                let score_mask = group_mask
516                    .unsqueeze(D::Minus1)?
517                    .expand((
518                        bs * seq_len,
519                        self.cfg.n_group,
520                        self.n_routed_experts / self.cfg.n_group,
521                    ))?
522                    .reshape((bs * seq_len, ()))?;
523                // (n, e)
524                // Invert the mask
525                let tmp_scores = scores_for_choice.broadcast_mul(&score_mask)?;
526                let topk_idx = tmp_scores.topk(self.top_k)?.indices;
527                (scores.gather(&topk_idx, 1)?, topk_idx)
528            }
529            TopkMethod::GroupLimitedGreedy => {
530                // (n, n_group)
531                let group_scores = scores
532                    .reshape((bs * seq_len, self.cfg.n_group, ()))?
533                    .max(D::Minus1)?;
534                // (n, topk_group)
535                let group_idx = scores.topk_unsorted(self.cfg.topk_group)?.indices;
536                // (n, n_group)
537                let mut group_mask = group_scores.zeros_like()?;
538                // (n, n_group)
539                group_mask = group_mask.scatter_add(
540                    &group_idx,
541                    &group_idx.ones_like()?.to_dtype(group_mask.dtype())?,
542                    1,
543                )?;
544                // (n, e)
545                let score_mask = group_mask
546                    .unsqueeze(D::Minus1)?
547                    .expand((
548                        bs * seq_len,
549                        self.cfg.n_group,
550                        self.n_routed_experts / self.cfg.n_group,
551                    ))?
552                    .reshape((bs, seq_len, ()))?;
553                // (n, e)
554                // Invert the mask
555                let tmp_scores = masked_fill(&score_mask, &(1. - &score_mask.ne(0.)?)?, 0.)?;
556                let TopKOutput { values, indices } = tmp_scores.topk_unsorted(self.top_k)?;
557                (values, indices)
558            }
559        };
560
561        if matches!(self.cfg.scoring_func, ScoringFunc::Sigmoid) {
562            let denmoninator = (topk_weight.sum_keepdim(D::Minus1)? + 1e-20)?;
563            topk_weight = topk_weight.broadcast_div(&denmoninator)?;
564        }
565
566        // Must multiply the scaling factor
567        topk_weight = (topk_weight * self.cfg.routed_scaling_factor)?;
568
569        Ok((topk_idx, topk_weight))
570    }
571}
572
/// Mixture-of-experts FFN layer. Routed experts are partitioned across
/// tensor-parallel ranks: this rank instantiates only the experts in
/// `[experts_start_idx, experts_end_idx)` (the rest are `None`), and the
/// per-rank partial outputs are combined with a sum all-reduce.
struct Moe {
    experts: Vec<Option<Expert>>,
    shared_experts: Option<Mlp>,
    gate: MoeGate,
    all_reduce: SumAllReduce,
    experts_start_idx: usize,
    experts_end_idx: usize,
    world_size: usize,
}
582
impl Moe {
    /// Load the MoE layer for `layer_idx`, instantiating only the experts
    /// assigned to this rank's contiguous shard.
    #[allow(clippy::too_many_arguments)]
    fn new(
        cfg: &DeepSeekV3Config,
        vb: ShardedVarBuilder,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        n_shared_experts: Option<usize>,
        n_routed_experts: usize,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let mut experts = Vec::with_capacity(n_routed_experts);
        // Each rank owns a contiguous slice of the expert list.
        let n_local_experts = n_routed_experts / comm.world_size();
        let experts_start_idx = comm.rank() * n_local_experts;
        let experts_end_idx = experts_start_idx + n_local_experts;
        for i in 0..n_routed_experts {
            if i >= experts_start_idx && i < experts_end_idx {
                let vb_e = vb.pp("experts").pp(i);
                experts.push(Some(Expert::new(
                    cfg,
                    mapper.set_device(layer_idx, vb_e, loading_isq),
                    None,
                    Some(cfg.moe_intermediate_size),
                )?));
            } else {
                // Expert belongs to another rank; keep a placeholder so
                // indices still line up with global expert ids.
                experts.push(None);
            }
        }
        // Shared experts are fused into one MLP of n_shared_experts times
        // the per-expert intermediate size.
        let shared_experts = if let Some(n_shared_experts) = n_shared_experts {
            let intermediate_size = cfg.moe_intermediate_size * n_shared_experts;
            Some(Mlp::new(
                mapper.set_device(layer_idx, vb.pp("shared_experts"), loading_isq),
                cfg.hidden_size,
                intermediate_size,
                &cfg.quantization_config,
                cfg.hidden_act,
                comm,
            )?)
        } else {
            None
        };
        let gate = MoeGate::new(
            cfg,
            mapper.set_device(layer_idx, vb.pp("gate"), false),
            n_routed_experts,
        )?;
        Ok(Self {
            experts,
            shared_experts,
            gate,
            all_reduce: SumAllReduce::new(comm),
            experts_end_idx,
            experts_start_idx,
            world_size: comm.world_size(),
        })
    }

    /// Dispatch flattened tokens `xs` (n_tokens, hidden) to this rank's
    /// experts and accumulate their weighted outputs; cross-rank partial
    /// sums are combined via all-reduce.
    fn moe_infer(&self, xs: &Tensor, topk_ids: &Tensor, topk_weight: &Tensor) -> Result<Tensor> {
        let mut y = xs.zeros_like()?;
        // Tokens routed to each expert, indexed by global expert id.
        let counts = topk_ids
            .flatten_all()?
            .bincount(self.experts.len() as u32)?;
        // Only iterate over this rank's expert shard.
        for (i, count) in counts
            .iter()
            .enumerate()
            .take(self.experts_end_idx)
            .skip(self.experts_start_idx)
        {
            if *count == 0 {
                continue;
            }
            // Rows (token, slot) where expert i was selected.
            let idx_top = topk_ids.eq(i as f64)?.nonzero()?.t()?;
            let idx = &idx_top.i(0)?.contiguous()?;
            let top = &idx_top.i(1)?.contiguous()?;

            let expert = self.experts[i]
                .as_ref()
                .context("Expert is not present for this rank.")?;

            // y[idx] += expert(xs[idx]) * topk_weight[idx, top]
            y = y.index_add(
                idx,
                &expert.forward(&xs.index_select(idx, 0)?)?.broadcast_mul(
                    &topk_weight
                        .index_select(idx, 0)?
                        .gather(&top.unsqueeze(1)?, 1)?
                        .squeeze(1)?
                        .unsqueeze(D::Minus1)?
                        .to_dtype(xs.dtype())?,
                )?,
                0,
            )?;
        }

        if self.world_size > 1 {
            // Combine partial sums from the other ranks' expert shards.
            y = self.all_reduce.sum_all_reduce(&y)?;
        }

        Ok(y)
    }

    /// Full MoE forward: gate, routed-expert mixture, plus the always-active
    /// shared experts applied to the original input.
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let identity = xs.clone();
        let orig_shape = xs.shape();
        let (topk_idx, topk_weight) = self.gate.forward(xs)?;
        let xs = xs.reshape(((), xs.dim(D::Minus1)?))?;

        let mut y = self
            .moe_infer(&xs, &topk_idx, &topk_weight)?
            .reshape(orig_shape)?;
        if let Some(ref shared_experts) = self.shared_experts {
            y = (y + shared_experts.forward(&identity)?)?;
        }
        Ok(y)
    }
}
699
/// FFN of one decoder layer: either a mixture-of-experts or a dense MLP.
enum MoeOrMlp {
    Moe(Moe),
    Mlp(Mlp),
}
704
705impl MoeOrMlp {
706    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
707        match self {
708            Self::Mlp(mlp) => mlp.forward(xs),
709            Self::Moe(moe) => moe.forward(xs),
710        }
711    }
712}
713
/// One pre-norm transformer decoder block: attention and an FFN
/// (dense MLP or MoE), each preceded by an RMS norm with residual adds.
struct DecoderLayer {
    input_layernorm: RmsNorm,
    post_attention_layernorm: RmsNorm,
    attn: Attention,
    moe_or_mlp: MoeOrMlp,
}
720
721impl DecoderLayer {
722    #[allow(clippy::too_many_arguments)]
723    fn new(
724        rotary_emb: Arc<DeepSeekV2RotaryEmbedding>,
725        cfg: &DeepSeekV3Config,
726        vb: ShardedVarBuilder,
727        mapper: &dyn DeviceMapper,
728        layer_idx: usize,
729        loading_isq: bool,
730        paged_attn: Option<PagedAttention>,
731        comm: &Arc<mistralrs_quant::Comm>,
732    ) -> Result<Self> {
733        let attn = Attention::new(
734            rotary_emb,
735            cfg,
736            vb.pp("self_attn"),
737            mapper,
738            layer_idx,
739            loading_isq,
740            paged_attn,
741            comm,
742        )?;
743        let input_layernorm = RmsNorm::new(
744            cfg.hidden_size,
745            cfg.rms_norm_eps,
746            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
747        )?;
748        let post_attention_layernorm = RmsNorm::new(
749            cfg.hidden_size,
750            cfg.rms_norm_eps,
751            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
752        )?;
753        let moe_or_mlp = if cfg.n_routed_experts.is_some()
754            && layer_idx >= cfg.first_k_dense_replace
755            && layer_idx % cfg.moe_layer_freq == 0
756        {
757            MoeOrMlp::Moe(Moe::new(
758                cfg,
759                vb.pp("mlp"),
760                mapper,
761                layer_idx,
762                loading_isq,
763                cfg.n_shared_experts,
764                cfg.n_routed_experts.unwrap(),
765                comm,
766            )?)
767        } else {
768            MoeOrMlp::Mlp(Mlp::new(
769                mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq),
770                cfg.hidden_size,
771                cfg.intermediate_size,
772                &cfg.quantization_config,
773                cfg.hidden_act,
774                comm,
775            )?)
776        };
777
778        Ok(Self {
779            input_layernorm,
780            post_attention_layernorm,
781            attn,
782            moe_or_mlp,
783        })
784    }
785
786    fn forward(
787        &self,
788        xs: &Tensor,
789        attention_mask: Option<&Tensor>,
790        seqlen_offsets: &[usize],
791        kv_cache: &mut KvCache,
792        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
793        flash_params: &FlashParams,
794    ) -> Result<Tensor> {
795        let residual = xs;
796        let xs = self.input_layernorm.forward(xs)?;
797        let xs = self.attn.forward(
798            &xs,
799            attention_mask,
800            seqlen_offsets,
801            kv_cache,
802            metadata,
803            flash_params,
804        )?;
805        let xs = (xs + residual)?;
806        let residual = &xs;
807        let xs = self
808            .moe_or_mlp
809            .forward(&xs.apply(&self.post_attention_layernorm)?)?;
810        residual + xs
811    }
812}
813
/// DeepSeek-V3 causal language model: token embeddings, a stack of decoder
/// layers (dense and/or MoE), a final RMS norm, and the LM head.
pub struct DeepSeekV3 {
    lm_head: Arc<dyn QuantMethod>,
    embed_tokens: Embedding,
    norm: RmsNorm,
    layers: Vec<DecoderLayer>,
    cache: EitherCache,
    device: Device,
    max_seq_len: usize,
    cfg: ModelConfigMetadata,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
}
825
826impl DeepSeekV3 {
827    pub fn new(
828        cfg: &DeepSeekV3Config,
829        vb: ShardedVarBuilder,
830        _is_gptx: bool,
831        normal_loading_metadata: NormalLoadingMetadata,
832        attention_mechanism: AttentionImplementation,
833    ) -> Result<Self> {
834        let vb_m = vb.pp("model");
835
836        let mapper = normal_loading_metadata.mapper;
837
838        let embed_tokens = embedding(
839            cfg.vocab_size,
840            cfg.hidden_size,
841            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
842        )?;
843        let lm_head = if !cfg.tie_word_embeddings {
844            ReplicatedLayer::new(
845                cfg.hidden_size,
846                cfg.vocab_size,
847                &None,
848                false,
849                mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
850            )?
851        } else {
852            ReplicatedLayer::from_linear(candle_nn::Linear::new(
853                mapper.cast_nm_device(
854                    embed_tokens.embeddings(),
855                    normal_loading_metadata.loading_isq,
856                )?,
857                None,
858            ))?
859        };
860        let norm = RmsNorm::new(
861            cfg.hidden_size,
862            cfg.rms_norm_eps,
863            mapper.set_nm_device(vb_m.pp("norm"), false),
864        )?;
865
866        let mut ropes = HashMap::new();
867        let rope_cfg = DeepSeekV2RopeConfig {
868            rope_scaling: cfg.rope_scaling.clone(),
869            max_position_embeddings: cfg.max_position_embeddings,
870            rope_theta: cfg.rope_theta,
871            qk_rope_head_dim: cfg.qk_rope_head_dim,
872        };
873        for i in 0..cfg.num_hidden_layers {
874            let device = mapper
875                .device_for(i, false)
876                .unwrap_or(&normal_loading_metadata.real_device);
877            ropes.insert(
878                device.location(),
879                Arc::new(DeepSeekV2RotaryEmbedding::new(
880                    &rope_cfg,
881                    vb.dtype(),
882                    device,
883                )?),
884            );
885        }
886
887        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
888        let vb_l = vb_m.pp("layers");
889        for layer_idx in NiceProgressBar::<_, 'b'>(
890            0..cfg.num_hidden_layers,
891            "Loading repeating layers",
892            &normal_loading_metadata.multi_progress,
893        ) {
894            let device = mapper
895                .device_for(layer_idx, false)
896                .unwrap_or(&normal_loading_metadata.real_device);
897            let rotary_emb = ropes
898                .get(&device.location())
899                .expect("No RoPE for device location!")
900                .clone();
901            let paged_attn = match &attention_mechanism {
902                AttentionImplementation::Eager => None,
903                AttentionImplementation::PagedAttention => Some(
904                    PagedAttention::new(cfg.v_head_dim, device, None)
905                        .expect("Failed to create PagedAttention"),
906                ),
907            };
908            let comm = mapper.get_comm_for(layer_idx)?;
909            let layer = DecoderLayer::new(
910                rotary_emb.clone(),
911                cfg,
912                vb_l.pp(layer_idx),
913                &*mapper,
914                layer_idx,
915                normal_loading_metadata.loading_isq,
916                paged_attn,
917                &comm,
918            )?;
919            layers.push(layer)
920        }
921
922        Ok(Self {
923            lm_head,
924            embed_tokens,
925            norm,
926            layers,
927            cache: EitherCache::Normal(NormalCache::new(
928                cfg.num_hidden_layers,
929                cfg.max_position_embeddings,
930            )),
931            device: normal_loading_metadata.real_device.clone(),
932            max_seq_len: cfg.max_position_embeddings,
933            cfg: ModelConfigMetadata {
934                max_seq_len: cfg.max_position_embeddings,
935                num_layers: cfg.num_hidden_layers,
936                hidden_size: cfg.hidden_size,
937                num_kv_heads: (cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size())
938                    .max(1),
939                num_attn_heads: (cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size())
940                    .max(1),
941                sliding_window: None,
942                k_head_dim: cfg.q_head_dim(),
943                v_head_dim: if matches!(
944                    attention_mechanism,
945                    AttentionImplementation::PagedAttention
946                ) {
947                    cfg.q_head_dim()
948                } else {
949                    cfg.v_head_dim
950                },
951            },
952            mapper,
953        })
954    }
955
956    #[allow(clippy::too_many_arguments)]
957    pub fn forward(
958        &self,
959        input_ids: &Tensor,
960        seqlen_offsets: &[usize],
961        context_lens: Vec<(usize, usize)>,
962        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
963        flash_params: &FlashParams,
964    ) -> Result<Tensor> {
965        let mut xs = self.embed_tokens.forward(input_ids)?;
966        let cache = &mut self.cache.normal().0;
967        let attention_mask = CausalMasker.make_causal_mask_matrix(
968            input_ids,
969            metadata
970                .as_ref()
971                .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
972                .unwrap_or(cache as &dyn PastKvLenCache),
973            xs.dtype(),
974            self.cfg.num_attn_heads,
975        )?;
976        // PagedAttention prompt chunking
977        let attention_mask = attention_mask.filter(|_| {
978            metadata
979                .as_ref()
980                .map(|(_, meta)| meta.is_first_prompt_chunk)
981                .unwrap_or(true)
982        });
983        for (i, layer) in self.layers.iter().enumerate() {
984            xs = self.mapper.map(xs, i)?;
985            xs = layer.forward(
986                &xs,
987                attention_mask
988                    .as_ref()
989                    .map(|m| m.to_device(xs.device()).unwrap())
990                    .as_ref(),
991                seqlen_offsets,
992                &mut cache[i],
993                metadata
994                    .as_ref()
995                    .map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
996                flash_params,
997            )?;
998        }
999        let xs = xs.to_device(&self.device)?;
1000        let xs = xs.apply(&self.norm)?;
1001        extract_logits(&self.lm_head.forward_autocast(&xs)?, context_lens)
1002    }
1003}
1004
/// ISQ (in-situ quantization) support: exposes mutable handles to the
/// quantizable linear layers and emits the non-quantized "residual" tensors
/// (norms, gates, embeddings) needed to serialize the model.
impl IsqModel for DeepSeekV3 {
    /// Collect every quantizable projection paired with its owning layer index
    /// (`None` for the lm_head, which is not device-mapped per layer).
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((&mut self.lm_head, None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            // Q projection is either a single matrix or a LoRA-style A/B pair
            // (the layernorm between A and B is residual, not quantizable).
            match &mut layer.attn.q {
                QProj::Plain(q) => {
                    tensors.push((q, Some(i)));
                }
                QProj::Lora { a, norm: _, b } => {
                    tensors.push((a, Some(i)));
                    tensors.push((b, Some(i)));
                }
            }
            tensors.push((&mut layer.attn.kv_a_proj_with_mqa, Some(i)));
            tensors.push((&mut layer.attn.kv_b_proj, Some(i)));
            tensors.push((&mut layer.attn.o_proj, Some(i)));
            match &mut layer.moe_or_mlp {
                MoeOrMlp::Mlp(mlp) => {
                    tensors.push((&mut mlp.gate, Some(i)));
                    tensors.push((&mut mlp.up, Some(i)));
                    tensors.push((&mut mlp.down, Some(i)));
                }
                MoeOrMlp::Moe(moe) => {
                    // Only experts resident on this rank are Some(_).
                    for mlp in moe.experts.iter_mut().filter_map(|e| e.as_mut()) {
                        tensors.push((&mut mlp.gate, Some(i)));
                        tensors.push((&mut mlp.up, Some(i)));
                        tensors.push((&mut mlp.down, Some(i)));
                    }
                    if let Some(mlp) = &mut moe.shared_experts {
                        tensors.push((&mut mlp.gate, Some(i)));
                        tensors.push((&mut mlp.up, Some(i)));
                        tensors.push((&mut mlp.down, Some(i)));
                    }
                }
            }
        }
        (tensors, &*self.mapper)
    }

    /// Like [`Self::get_layers`], but restricted to MLP/MoE expert weights
    /// (attention projections are excluded; lm_head is still included).
    fn get_layers_moe_experts_only(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((&mut self.lm_head, None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            match &mut layer.moe_or_mlp {
                MoeOrMlp::Mlp(mlp) => {
                    tensors.push((&mut mlp.gate, Some(i)));
                    tensors.push((&mut mlp.up, Some(i)));
                    tensors.push((&mut mlp.down, Some(i)));
                }
                MoeOrMlp::Moe(moe) => {
                    for mlp in moe.experts.iter_mut().filter_map(|e| e.as_mut()) {
                        tensors.push((&mut mlp.gate, Some(i)));
                        tensors.push((&mut mlp.up, Some(i)));
                        tensors.push((&mut mlp.down, Some(i)));
                    }
                    if let Some(mlp) = &mut moe.shared_experts {
                        tensors.push((&mut mlp.gate, Some(i)));
                        tensors.push((&mut mlp.up, Some(i)));
                        tensors.push((&mut mlp.down, Some(i)));
                    }
                }
            }
        }
        (tensors, &*self.mapper)
    }

    /// Tensors kept in full precision when quantizing everything returned by
    /// [`Self::get_layers`]: embeddings, norms, and the MoE router gate.
    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        let uvb_m = uvb.pp("model");
        uvb_m.pp("embed_tokens").add(&self.embed_tokens);
        uvb_m.pp("norm").add(&self.norm);

        for (layer_idx, layer) in self.layers.iter().enumerate() {
            let uvb_l = uvb_m.pp("layers").pp(layer_idx);
            uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
            uvb_l
                .pp("post_attention_layernorm")
                .add(&layer.post_attention_layernorm);

            uvb_l
                .pp("self_attn")
                .pp("kv_a_layernorm")
                .add(&layer.attn.kv_a_layernorm);

            match &layer.moe_or_mlp {
                MoeOrMlp::Moe(moe) => {
                    // The router gate weight is stored raw, not as a quant layer.
                    uvb_l
                        .pp("mlp")
                        .pp("gate")
                        .add_tensor("weight", moe.gate.weight.clone());
                }
                MoeOrMlp::Mlp(_) => (),
            }

            match &layer.attn.q {
                QProj::Plain(_) => (),
                QProj::Lora { a: _, norm, b: _ } => {
                    uvb_l.pp("self_attn").pp("q_a_layernorm").add(norm);
                }
            }
        }

        uvb.to_safetensors()
    }

    /// Residual tensors for expert-only quantization: everything above PLUS all
    /// attention projections, since those stay unquantized in this mode.
    fn residual_tensors_moe_experts_only(&self) -> Option<Vec<(String, Tensor)>> {
        let uvb = UnVarBuilder::new();

        let uvb_m = uvb.pp("model");
        uvb_m.pp("embed_tokens").add(&self.embed_tokens);
        uvb_m.pp("norm").add(&self.norm);

        for (layer_idx, layer) in self.layers.iter().enumerate() {
            let uvb_l = uvb_m.pp("layers").pp(layer_idx);
            uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
            uvb_l
                .pp("post_attention_layernorm")
                .add(&layer.post_attention_layernorm);

            uvb_l
                .pp("self_attn")
                .pp("kv_a_layernorm")
                .add(&layer.attn.kv_a_layernorm);

            match &layer.moe_or_mlp {
                MoeOrMlp::Moe(moe) => {
                    uvb_l
                        .pp("mlp")
                        .pp("gate")
                        .add_tensor("weight", moe.gate.weight.clone());
                }
                MoeOrMlp::Mlp(_) => (),
            }

            match &layer.attn.q {
                QProj::Plain(q) => {
                    uvb_l.pp("self_attn").pp("q_proj").add(q);
                }
                QProj::Lora { a, norm, b } => {
                    uvb_l.pp("self_attn").pp("q_a_proj").add(a);
                    uvb_l.pp("self_attn").pp("q_a_layernorm").add(norm);
                    uvb_l.pp("self_attn").pp("q_b_proj").add(b);
                }
            }
            uvb_l
                .pp("self_attn")
                .pp("kv_a_proj_with_mqa")
                .add(&layer.attn.kv_a_proj_with_mqa);
            uvb_l
                .pp("self_attn")
                .pp("kv_b_proj")
                .add(&layer.attn.kv_b_proj);
            uvb_l.pp("self_attn").pp("o_proj").add(&layer.attn.o_proj);
        }

        Some(uvb.to_safetensors())
    }
}
1175
1176impl NormalModel for DeepSeekV3 {
1177    fn forward(
1178        &self,
1179        input_ids: &Tensor,
1180        seqlen_offsets: &[usize],
1181        context_lens: Vec<(usize, usize)>,
1182        _position_ids: Vec<usize>,
1183        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
1184        flash_params: &FlashParams,
1185    ) -> Result<Tensor> {
1186        self.forward(
1187            input_ids,
1188            seqlen_offsets,
1189            context_lens,
1190            metadata,
1191            flash_params,
1192        )
1193    }
1194    fn xlora_forward(
1195        &self,
1196        _input_ids: &Tensor,
1197        _input_ids_full: &Tensor,
1198        _seqlen_offsets: &[usize],
1199        _seqlen_offsets_full: &[usize],
1200        _no_kv_cache: bool,
1201        _non_granular_state: &Option<crate::xlora_models::NonGranularState>,
1202        _context_lens: Vec<(usize, usize)>,
1203        _position_ids: Vec<usize>,
1204        _flash_params: &FlashParams,
1205        _flash_params_full: &FlashParams,
1206    ) -> Result<Tensor> {
1207        unimplemented!()
1208    }
1209    fn cache(&self) -> &EitherCache {
1210        &self.cache
1211    }
1212    fn cache_mut(&mut self) -> &mut EitherCache {
1213        &mut self.cache
1214    }
1215    fn device(&self) -> &Device {
1216        &self.device
1217    }
1218    fn is_xlora(&self) -> bool {
1219        false
1220    }
1221    fn max_seq_len(&self) -> usize {
1222        self.max_seq_len
1223    }
1224    fn config(&self) -> &ModelConfigMetadata {
1225        &self.cfg
1226    }
1227}
1228
// Empty impl: DeepSeekV3 relies entirely on the trait's default method implementations.
impl AnyMoeBaseModelMixin for DeepSeekV3 {}