mistralrs_core/models/
deepseek2.rs

1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use std::{collections::HashMap, sync::Arc};
4
5use candle_core::{Context, DType, Device, IndexOp, Result, Tensor, D};
6use candle_nn::{Embedding, Module};
7use mistralrs_quant::{
8    distributed::DistributedOperation, ColumnParallelLayer, QuantMethod, QuantizedConfig,
9    ReplicatedLayer, RowParallelLayer, ShardedVarBuilder, SumAllReduce,
10};
11use serde::Deserialize;
12
13use crate::{
14    amoe::AnyMoeBaseModelMixin,
15    attention::SdpaParams,
16    device_map::DeviceMapper,
17    layers::{
18        embedding, Activation, CausalMasker, DeepSeekV2RopeConfig, DeepSeekV2RopeScaling,
19        DeepSeekV2RotaryEmbedding, Mlp, RmsNorm, Sdpa,
20    },
21    layers_masker::{masked_fill, PastKvLenCache},
22    ops::{BincountOp, NonZeroOp, SplitOp, TopKLastDimOp, TopKOutput},
23    paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
24    pipeline::{
25        extract_logits,
26        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
27        EitherCache, IsqModel, KvCache, NormalCache, NormalLoadingMetadata, NormalModel,
28    },
29    serde_default_fn,
30    utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
31};
32
// Serde default-value providers for `DeepSeekV2Config` fields; each macro call
// expands to a free function returning the given default, referenced via
// `#[serde(default = "...")]` on the corresponding config field below.
serde_default_fn!(f64, routed_scaling_factor, 1.0);
serde_default_fn!(TopkMethod, topk_method, TopkMethod::Greedy);
serde_default_fn!(usize, moe_layer_freq, 1);
serde_default_fn!(usize, first_k_dense_replace, 0);
serde_default_fn!(bool, norm_topk_prob, false);
serde_default_fn!(ScoringFunc, scoring_func, ScoringFunc::Softmax);
serde_default_fn!(Activation, hidden_act, Activation::Silu);
serde_default_fn!(bool, tie_word_embeddings, false);
serde_default_fn!(bool, use_flash_attn_default, false);
42
/// Strategy the MoE gate uses to pick experts for each token.
#[derive(Deserialize, Clone, Debug)]
enum TopkMethod {
    /// Plain top-k over all routed experts.
    #[serde(rename = "greedy")]
    Greedy,
    /// First select `topk_group` expert groups, then top-k within those groups.
    #[serde(rename = "group_limited_greedy")]
    GroupLimitedGreedy,
}
50
/// How raw router logits are turned into gating scores.
#[derive(Deserialize, Clone, Debug)]
enum ScoringFunc {
    /// Softmax over the expert dimension.
    #[serde(rename = "softmax")]
    Softmax,
}
56
/// Deserialized DeepSeek-V2 model configuration (HF `config.json` layout).
#[derive(Deserialize, Clone, Debug)]
pub struct DeepSeekV2Config {
    pub(crate) vocab_size: usize,
    pub(crate) hidden_size: usize,
    /// Intermediate size of the dense (non-MoE) MLP layers.
    pub(crate) intermediate_size: usize,
    /// Intermediate size of each routed/shared MoE expert.
    pub(crate) moe_intermediate_size: usize,
    pub(crate) num_hidden_layers: usize,
    pub(crate) num_attention_heads: usize,
    /// Number of always-active shared experts; `None` disables the shared path.
    pub(crate) n_shared_experts: Option<usize>,
    /// Number of routed experts; `None` means every layer uses a dense MLP.
    pub(crate) n_routed_experts: Option<usize>,
    /// Multiplier applied to routed-expert weights when they are not renormalized.
    #[serde(default = "routed_scaling_factor")]
    pub(crate) routed_scaling_factor: f64,
    #[serde(default = "topk_method")]
    topk_method: TopkMethod,
    pub(crate) num_experts_per_tok: Option<usize>,
    /// A layer uses MoE only when `layer_idx % moe_layer_freq == 0`.
    #[serde(default = "moe_layer_freq")]
    pub(crate) moe_layer_freq: usize,
    #[serde(default = "first_k_dense_replace")]
    pub(crate) first_k_dense_replace: usize,
    // k dense layers
    /// If true (and top_k > 1), renormalize top-k routing weights to sum to 1.
    #[serde(default = "norm_topk_prob")]
    pub(crate) norm_topk_prob: bool,
    #[serde(default = "scoring_func")]
    scoring_func: ScoringFunc,
    #[serde(default = "hidden_act")]
    pub(crate) hidden_act: Activation,
    pub(crate) max_position_embeddings: usize,
    pub(crate) rms_norm_eps: f64,
    #[serde(default = "tie_word_embeddings")]
    pub(crate) tie_word_embeddings: bool,
    pub(crate) rope_theta: f32,
    pub(crate) rope_scaling: Option<DeepSeekV2RopeScaling>,
    pub(crate) attention_bias: bool,
    /// Rank of the low-rank query projection; `None` uses a plain `q_proj`.
    pub(crate) q_lora_rank: Option<usize>,
    /// Rotary ("pe") part of each attention head.
    pub(crate) qk_rope_head_dim: usize,
    /// Rank of the compressed KV latent (MLA).
    pub(crate) kv_lora_rank: usize,
    pub(crate) v_head_dim: usize,
    /// Non-rotary ("nope") part of each attention head.
    pub(crate) qk_nope_head_dim: usize,
    #[serde(default = "use_flash_attn_default")]
    pub(crate) use_flash_attn: bool,
    pub(crate) quantization_config: Option<QuantizedConfig>,
    /// Number of expert groups for group-limited greedy routing.
    pub(crate) n_group: usize,
    /// Number of groups to keep in group-limited greedy routing.
    pub(crate) topk_group: usize,
}
101
102impl DeepSeekV2Config {
103    pub(crate) fn q_head_dim(&self) -> usize {
104        self.qk_rope_head_dim + self.qk_nope_head_dim
105    }
106
107    fn softmax_scale(&self) -> f32 {
108        let mut softmax_scale = 1.0 / (self.q_head_dim() as f32).sqrt();
109        if let Some(DeepSeekV2RopeScaling::Yarn {
110            mscale_all_dim,
111            factor,
112            ..
113        }) = self.rope_scaling
114        {
115            let mscale = DeepSeekV2RotaryEmbedding::yarn_get_mscale(factor, mscale_all_dim);
116            softmax_scale = softmax_scale * mscale * mscale;
117        }
118        softmax_scale
119    }
120}
121
/// Query projection: either a single linear layer, or a low-rank
/// factorization (`q_a_proj` -> RmsNorm -> `q_b_proj`) used when
/// `q_lora_rank` is set in the config.
enum QProj {
    Plain(Arc<dyn QuantMethod>),
    Lora {
        a: Arc<dyn QuantMethod>,
        norm: RmsNorm,
        b: Arc<dyn QuantMethod>,
    },
}
130
131impl QProj {
132    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
133        match self {
134            Self::Lora { a, norm, b } => {
135                b.forward_autocast(&norm.forward(&a.forward_autocast(xs)?)?)
136            }
137            Self::Plain(lin) => lin.forward_autocast(xs),
138        }
139    }
140}
141
/// Multi-head latent attention (MLA): queries via `QProj`, keys/values via a
/// compressed low-rank KV pair (`kv_a` / `kv_b`), with each head split into a
/// rotary (`qk_rope_head_dim`) and non-rotary (`qk_nope_head_dim`) part.
struct Attention {
    q: QProj,
    /// Projects hidden states to the KV latent plus one shared rotary key.
    kv_a_proj_with_mqa: Arc<dyn QuantMethod>,
    kv_a_layernorm: RmsNorm,
    /// Expands the normalized KV latent to per-head nope-keys and values.
    kv_b_proj: Arc<dyn QuantMethod>,
    o_proj: Arc<dyn QuantMethod>,
    rotary_emb: Arc<DeepSeekV2RotaryEmbedding>,
    cfg: DeepSeekV2Config,
    /// `qk_nope_head_dim + qk_rope_head_dim`.
    q_head_dim: usize,
    paged_attn: Option<PagedAttention>,
    sdpa_params: SdpaParams,
    /// Heads local to this rank (total heads / tensor-parallel world size).
    num_attention_heads: usize,
}
155
impl Attention {
    /// Build one MLA attention block for `layer_idx`, sharding the
    /// column/row-parallel projections across `comm`.
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<DeepSeekV2RotaryEmbedding>,
        cfg: &DeepSeekV2Config,
        vb: ShardedVarBuilder,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        paged_attn: Option<PagedAttention>,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let q_head_dim = cfg.q_head_dim();
        // Low-rank query path (q_a -> norm -> q_b) when q_lora_rank is set,
        // otherwise a single column-parallel q_proj.
        let q = match cfg.q_lora_rank {
            Some(lora_rank) => {
                let a = ReplicatedLayer::new(
                    cfg.hidden_size,
                    lora_rank,
                    &cfg.quantization_config,
                    cfg.attention_bias,
                    mapper.set_device(layer_idx, vb.pp("q_a_proj"), loading_isq),
                )?;
                let norm = RmsNorm::new(
                    lora_rank,
                    cfg.rms_norm_eps,
                    mapper.set_device(layer_idx, vb.pp("q_a_layernorm"), false),
                )?;
                let b = ColumnParallelLayer::new(
                    lora_rank,
                    cfg.num_attention_heads * q_head_dim,
                    &cfg.quantization_config,
                    false,
                    comm,
                    mapper.set_device(layer_idx, vb.pp("q_b_proj"), loading_isq),
                )?;
                QProj::Lora { a, norm, b }
            }
            None => QProj::Plain(ColumnParallelLayer::new(
                cfg.hidden_size,
                cfg.num_attention_heads * q_head_dim,
                &cfg.quantization_config,
                false,
                comm,
                mapper.set_device(layer_idx, vb.pp("q_proj"), loading_isq),
            )?),
        };

        // Compress hidden states into kv_lora_rank latent + shared rope key.
        let kv_a_proj_with_mqa = ReplicatedLayer::new(
            cfg.hidden_size,
            cfg.kv_lora_rank + cfg.qk_rope_head_dim,
            &cfg.quantization_config,
            cfg.attention_bias,
            mapper.set_device(layer_idx, vb.pp("kv_a_proj_with_mqa"), loading_isq),
        )?;
        let kv_a_layernorm = RmsNorm::new(
            cfg.kv_lora_rank,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("kv_a_layernorm"), false),
        )?;
        // Per-head output is the nope-key part plus the value part.
        let kv_b_proj = ColumnParallelLayer::new(
            cfg.kv_lora_rank,
            cfg.num_attention_heads * (q_head_dim - cfg.qk_rope_head_dim + cfg.v_head_dim),
            &cfg.quantization_config,
            false,
            comm,
            mapper.set_device(layer_idx, vb.pp("kv_b_proj"), loading_isq),
        )?;

        let o_proj = RowParallelLayer::new(
            cfg.num_attention_heads * cfg.v_head_dim,
            cfg.hidden_size,
            &cfg.quantization_config,
            cfg.attention_bias,
            comm,
            mapper.set_device(layer_idx, vb.pp("o_proj"), loading_isq),
        )?;

        Ok(Self {
            q,
            kv_a_proj_with_mqa,
            kv_a_layernorm,
            kv_b_proj,
            o_proj,
            rotary_emb,
            cfg: cfg.clone(),
            q_head_dim,
            paged_attn,
            // Heads are split across tensor-parallel ranks.
            num_attention_heads: cfg.num_attention_heads / comm.world_size(),
            sdpa_params: SdpaParams {
                // Keys are expanded to all heads below, so no KV grouping.
                n_kv_groups: 1,
                use_flash_attn: cfg.use_flash_attn,
                softcap: None,
                softmax_scale: cfg.softmax_scale(),
                sliding_window: None,
            },
        })
    }

    /// Run MLA attention over `xs` of shape `(bs, seq_len, hidden)`.
    ///
    /// `metadata` carries the paged-attention KV cache tensors; `kv_cache` is
    /// used only on the eager (non-paged) path.
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (bs, seq_len, _) = xs.dims3()?;

        // Queries: project, then split each head into non-rotary ("nope")
        // and rotary ("pe") parts.
        let mut q = self.q.forward(xs)?;
        q = q
            .reshape((bs, seq_len, self.num_attention_heads, self.q_head_dim))?
            .transpose(1, 2)?;
        let q_split = q.split(
            &[self.cfg.qk_nope_head_dim, self.cfg.qk_rope_head_dim],
            D::Minus1,
        )?;
        let q_nope = q_split[0].clone();
        let mut q_pe = q_split[1].clone();

        // Keys/values: compress into the KV latent plus one shared rotary key
        // (MQA-style: a single k_pe head, broadcast to all heads below).
        let mut compressed_kv = self.kv_a_proj_with_mqa.forward_autocast(xs)?;
        let ckv_split = compressed_kv.split(
            &[self.cfg.kv_lora_rank, self.cfg.qk_rope_head_dim],
            D::Minus1,
        )?;
        compressed_kv = ckv_split[0].clone();
        let mut k_pe = ckv_split[1].clone();
        k_pe = k_pe
            .reshape((bs, seq_len, 1, self.cfg.qk_rope_head_dim))?
            .transpose(1, 2)?;
        // Expand the normalized latent into per-head nope-keys and values.
        let mut kv = self
            .kv_b_proj
            .forward_autocast(&self.kv_a_layernorm.forward(&compressed_kv)?)?;
        kv = kv
            .reshape((
                bs,
                seq_len,
                self.num_attention_heads,
                self.cfg.qk_nope_head_dim + self.cfg.v_head_dim,
            ))?
            .transpose(1, 2)?;

        let kv_split = kv.split(&[self.cfg.qk_nope_head_dim, self.cfg.v_head_dim], D::Minus1)?;
        let k_nope = kv_split[0].clone();
        let mut v = kv_split[1].clone();

        // RoPE applies only to the rotary parts.
        (q_pe, k_pe) = self.rotary_emb.forward(&q_pe, &k_pe, seqlen_offsets)?;

        // Recombine into full (bs, heads, seq, q_head_dim) q/k tensors.
        let q = Tensor::cat(&[&q_nope, &q_pe], D::Minus1)?.contiguous()?;
        let mut k = Tensor::cat(
            &[&k_nope, &k_pe.repeat((1, self.num_attention_heads, 1, 1))?],
            D::Minus1,
        )?
        .contiguous()?;

        let mut attn_out = match &self.paged_attn {
            Some(paged_attn) => match metadata {
                Some(((key_cache, value_cache), input_metadata)) => {
                    // The paged cache stores v at q_head_dim width: pad with
                    // zeros here, and narrow the output back to v_head_dim.
                    let v = v
                        .pad_with_zeros(D::Minus1, 0, self.q_head_dim - self.cfg.v_head_dim)?
                        .contiguous()?;
                    paged_attn
                        .forward(
                            &q,
                            &k,
                            &v,
                            attention_mask,
                            Some(key_cache),
                            Some(value_cache),
                            input_metadata,
                            &self.sdpa_params,
                            Some(flash_params),
                        )?
                        .narrow(D::Minus1, 0, self.cfg.v_head_dim)?
                }
                None => {
                    // If we don't have metadata, we are most likely generating an imatrix so we don't want to populate that.
                    // Generating the dummy metadata with the assumption that we are not generating text (only processing prompts).
                    let input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
                    // Sanity check.
                    assert!(attention_mask.is_some());
                    let v = v
                        .pad_with_zeros(D::Minus1, 0, self.q_head_dim - self.cfg.v_head_dim)?
                        .contiguous()?;
                    paged_attn
                        .forward(
                            &q,
                            &k,
                            &v,
                            attention_mask,
                            None,
                            None,
                            &input_metadata,
                            &self.sdpa_params,
                            Some(flash_params),
                        )?
                        .narrow(D::Minus1, 0, self.cfg.v_head_dim)?
                }
            },
            None => {
                // Eager path: append to the normal KV cache, then SDPA.
                (k, v) = kv_cache.append(&k, &v)?;

                Sdpa.run_attention(
                    &q,
                    &k,
                    &v,
                    attention_mask,
                    Some(flash_params),
                    &self.sdpa_params,
                )?
            }
        };

        // NOTE(review): with a mask the attention output appears to be
        // (bs, heads, seq, v_dim) and needs the transpose; the maskless path
        // is assumed to already be flattenable — confirm against the
        // Sdpa/PagedAttention output layouts.
        attn_out = if attention_mask.is_some() {
            attn_out.transpose(1, 2)?.reshape((bs, seq_len, ()))?
        } else {
            attn_out.reshape((bs, seq_len, ()))?
        };

        self.o_proj.forward_autocast(&attn_out)
    }
}
378
/// A single routed MoE expert: gate/up/down projections with a gated
/// activation (SwiGLU-style when `act` is SiLU).
struct Expert {
    gate: Arc<dyn QuantMethod>,
    up: Arc<dyn QuantMethod>,
    down: Arc<dyn QuantMethod>,
    act: Activation,
}
385
386impl Expert {
387    fn new(
388        cfg: &DeepSeekV2Config,
389        vb: ShardedVarBuilder,
390        hidden_size: Option<usize>,
391        intermediate_size: Option<usize>,
392    ) -> Result<Self> {
393        let hidden_size = hidden_size.unwrap_or(cfg.hidden_size);
394        let intermediate_size = intermediate_size.unwrap_or(cfg.intermediate_size);
395
396        Ok(Self {
397            gate: ReplicatedLayer::new(
398                hidden_size,
399                intermediate_size,
400                &cfg.quantization_config,
401                false,
402                vb.pp("gate_proj"),
403            )?,
404            up: ReplicatedLayer::new(
405                hidden_size,
406                intermediate_size,
407                &cfg.quantization_config,
408                false,
409                vb.pp("up_proj"),
410            )?,
411            down: ReplicatedLayer::new(
412                intermediate_size,
413                hidden_size,
414                &cfg.quantization_config,
415                false,
416                vb.pp("down_proj"),
417            )?,
418            act: cfg.hidden_act,
419        })
420    }
421
422    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
423        let original_dtype = xs.dtype();
424        let mut xs = xs.clone();
425        if let Some(t) = self.gate.quantized_act_type() {
426            xs = xs.to_dtype(t)?;
427        }
428        let lhs = self.gate.forward(&xs)?;
429        let rhs = self.up.forward(&xs)?;
430        let mut res = self.down.forward(&candle_nn::ops::mul_and_act(
431            &lhs,
432            &rhs,
433            self.act.try_into()?,
434        )?)?;
435        if self.gate.quantized_act_type().is_some() {
436            res = res.to_dtype(original_dtype)?;
437        }
438        Ok(res)
439    }
440}
441
/// MoE router: scores every token against each routed expert and selects the
/// top-k experts (optionally group-limited) along with their weights.
struct MoeGate {
    /// Routing weight matrix of shape `(n_routed_experts, hidden_size)`.
    weight: Tensor,
    cfg: DeepSeekV2Config,
    /// Experts selected per token (`num_experts_per_tok`).
    top_k: usize,
    n_routed_experts: usize,
}
448
449impl MoeGate {
450    fn new(cfg: &DeepSeekV2Config, vb: ShardedVarBuilder, n_routed_experts: usize) -> Result<Self> {
451        let weight = vb.get((n_routed_experts, cfg.hidden_size), "weight")?;
452        Ok(Self {
453            weight,
454            cfg: cfg.clone(),
455            top_k: cfg.num_experts_per_tok.unwrap(),
456            n_routed_experts,
457        })
458    }
459
460    /// (topk_idx, topk_weight)
461    fn forward(&self, xs: &Tensor) -> Result<(Tensor, Tensor)> {
462        let (bs, seq_len, h) = xs.dims3()?;
463        // Compute gating score
464        let xs = xs.reshape(((), h))?;
465        let logits = xs
466            .to_dtype(DType::F32)?
467            .broadcast_matmul(&self.weight.t()?.to_dtype(DType::F32)?)?;
468        let scores = match self.cfg.scoring_func {
469            ScoringFunc::Softmax => candle_nn::ops::softmax_last_dim(&logits)?,
470        };
471
472        // Select top-k experts
473        let (mut topk_weight, topk_idx) = match self.cfg.topk_method {
474            TopkMethod::Greedy => {
475                let TopKOutput { values, indices } = scores.topk_unsorted(self.top_k)?;
476                (values, indices)
477            }
478            TopkMethod::GroupLimitedGreedy => {
479                // (n, n_group)
480                let group_scores = scores
481                    .reshape((bs * seq_len, self.cfg.n_group, ()))?
482                    .max(D::Minus1)?;
483                // (n, topk_group)
484                let group_idx = scores.topk_unsorted(self.cfg.topk_group)?.indices;
485                // (n, n_group)
486                let mut group_mask = group_scores.zeros_like()?;
487                // (n, n_group)
488                group_mask = group_mask.scatter_add(
489                    &group_idx,
490                    &group_idx.ones_like()?.to_dtype(group_mask.dtype())?,
491                    1,
492                )?;
493                // (n, e)
494                let score_mask = group_mask
495                    .unsqueeze(D::Minus1)?
496                    .expand((
497                        bs * seq_len,
498                        self.cfg.n_group,
499                        self.n_routed_experts / self.cfg.n_group,
500                    ))?
501                    .reshape((bs, seq_len, ()))?;
502                // (n, e)
503                // Invert the mask
504                let tmp_scores = masked_fill(&score_mask, &(1. - &score_mask.ne(0.)?)?, 0.)?;
505                let TopKOutput { values, indices } = tmp_scores.topk_unsorted(self.top_k)?;
506                (values, indices)
507            }
508        };
509
510        if self.top_k > 1 && self.cfg.norm_topk_prob {
511            let denmoninator = (topk_weight.sum_keepdim(D::Minus1)? + 1e-20)?;
512            topk_weight = (topk_weight / denmoninator)?;
513        } else {
514            topk_weight = (topk_weight * self.cfg.routed_scaling_factor)?;
515        }
516        Ok((topk_idx, topk_weight))
517    }
518}
519
/// Mixture-of-experts feed-forward layer with expert parallelism: this rank
/// materializes only the experts in `[experts_start_idx, experts_end_idx)`,
/// and partial outputs are combined across ranks with a sum all-reduce.
struct Moe {
    /// One slot per routed expert; `None` for experts owned by other ranks.
    experts: Vec<Option<Expert>>,
    /// Always-active shared-expert MLP, if configured.
    shared_experts: Option<Mlp>,
    gate: MoeGate,
    all_reduce: SumAllReduce,
    experts_start_idx: usize,
    experts_end_idx: usize,
    world_size: usize,
}
529
impl Moe {
    /// Build the MoE layer for `layer_idx`. Only `n_routed_experts /
    /// world_size` experts are loaded on this rank; the rest stay `None`.
    #[allow(clippy::too_many_arguments)]
    fn new(
        cfg: &DeepSeekV2Config,
        vb: ShardedVarBuilder,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        n_shared_experts: Option<usize>,
        n_routed_experts: usize,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let mut experts = Vec::with_capacity(n_routed_experts);
        let n_local_experts = n_routed_experts / comm.world_size();
        let experts_start_idx = comm.rank() * n_local_experts;
        let experts_end_idx = experts_start_idx + n_local_experts;
        for i in 0..n_routed_experts {
            if i >= experts_start_idx && i < experts_end_idx {
                let vb_e = vb.pp("experts").pp(i);
                experts.push(Some(Expert::new(
                    cfg,
                    mapper.set_device(layer_idx, vb_e, loading_isq),
                    None,
                    Some(cfg.moe_intermediate_size),
                )?));
            } else {
                experts.push(None);
            }
        }
        // Shared experts are fused into one MLP whose intermediate size is
        // scaled by the shared-expert count.
        let shared_experts = if let Some(n_shared_experts) = n_shared_experts {
            let intermediate_size = cfg.moe_intermediate_size * n_shared_experts;
            Some(Mlp::new(
                mapper.set_device(layer_idx, vb.pp("shared_experts"), loading_isq),
                cfg.hidden_size,
                intermediate_size,
                &cfg.quantization_config,
                cfg.hidden_act,
                comm,
            )?)
        } else {
            None
        };
        let gate = MoeGate::new(
            cfg,
            mapper.set_device(layer_idx, vb.pp("gate"), false),
            n_routed_experts,
        )?;
        Ok(Self {
            experts,
            shared_experts,
            gate,
            all_reduce: SumAllReduce::new(comm),
            experts_end_idx,
            experts_start_idx,
            world_size: comm.world_size(),
        })
    }

    /// Run this rank's experts over the flattened tokens `xs` (n, hidden),
    /// scaling each token's expert output by its routing weight, then sum the
    /// partial results across ranks.
    fn moe_infer(&self, xs: &Tensor, topk_ids: &Tensor, topk_weight: &Tensor) -> Result<Tensor> {
        let mut y = xs.zeros_like()?;
        // Tokens routed to each expert; lets us skip idle experts entirely.
        let counts = topk_ids
            .flatten_all()?
            .bincount(self.experts.len() as u32)?;
        // take(end).skip(start) restricts the loop to this rank's expert range.
        for (i, count) in counts
            .iter()
            .enumerate()
            .take(self.experts_end_idx)
            .skip(self.experts_start_idx)
        {
            if *count == 0 {
                continue;
            }
            // Rows 0/1: token index and top-k slot where expert i was chosen.
            let idx_top = topk_ids.eq(i as f64)?.nonzero()?.t()?;
            let idx = &idx_top.i(0)?.contiguous()?;
            let top = &idx_top.i(1)?.contiguous()?;

            let expert = self.experts[i]
                .as_ref()
                .context("Expert is not present for this rank.")?;

            // Accumulate weight_i * expert_i(x) into the selected rows of y.
            y = y.index_add(
                idx,
                &expert.forward(&xs.index_select(idx, 0)?)?.broadcast_mul(
                    &topk_weight
                        .index_select(idx, 0)?
                        .gather(&top.unsqueeze(1)?, 1)?
                        .squeeze(1)?
                        .unsqueeze(D::Minus1)?
                        .to_dtype(xs.dtype())?,
                )?,
                0,
            )?;
        }

        // Each rank computed only its own experts; combine partial sums.
        if self.world_size > 1 {
            y = self.all_reduce.sum_all_reduce(&y)?;
        }

        Ok(y)
    }

    /// Full MoE forward: route tokens, run routed experts, then add the
    /// shared-expert path on the original (unflattened) input.
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let identity = xs.clone();
        let orig_shape = xs.shape();
        let (topk_idx, topk_weight) = self.gate.forward(xs)?;
        let xs = xs.reshape(((), xs.dim(D::Minus1)?))?;

        let mut y = self
            .moe_infer(&xs, &topk_idx, &topk_weight)?
            .reshape(orig_shape)?;
        if let Some(ref shared_experts) = self.shared_experts {
            y = (y + shared_experts.forward(&identity)?)?;
        }
        Ok(y)
    }
}
646
/// Feed-forward variant for a decoder layer: MoE where routing is configured
/// for that layer index, dense MLP otherwise (see `DecoderLayer::new`).
enum MoeOrMlp {
    Moe(Moe),
    Mlp(Mlp),
}
651
652impl MoeOrMlp {
653    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
654        match self {
655            Self::Mlp(mlp) => mlp.forward(xs),
656            Self::Moe(moe) => moe.forward(xs),
657        }
658    }
659}
660
/// One transformer decoder layer: pre-norm MLA attention and pre-norm
/// feed-forward (MoE or dense MLP), each with a residual connection.
struct DecoderLayer {
    input_layernorm: RmsNorm,
    post_attention_layernorm: RmsNorm,
    attn: Attention,
    moe_or_mlp: MoeOrMlp,
}
667
impl DecoderLayer {
    /// Build one decoder layer. The feed-forward is MoE only when routed
    /// experts are configured, `layer_idx >= first_k_dense_replace`, and the
    /// index matches `moe_layer_freq`; otherwise a dense MLP is used.
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<DeepSeekV2RotaryEmbedding>,
        cfg: &DeepSeekV2Config,
        vb: ShardedVarBuilder,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        paged_attn: Option<PagedAttention>,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let attn = Attention::new(
            rotary_emb,
            cfg,
            vb.pp("self_attn"),
            mapper,
            layer_idx,
            loading_isq,
            paged_attn,
            comm,
        )?;
        let input_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let post_attention_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        let moe_or_mlp = if cfg.n_routed_experts.is_some()
            && layer_idx >= cfg.first_k_dense_replace
            && layer_idx % cfg.moe_layer_freq == 0
        {
            MoeOrMlp::Moe(Moe::new(
                cfg,
                vb.pp("mlp"),
                mapper,
                layer_idx,
                loading_isq,
                cfg.n_shared_experts,
                cfg.n_routed_experts.unwrap(),
                comm,
            )?)
        } else {
            MoeOrMlp::Mlp(Mlp::new(
                mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq),
                cfg.hidden_size,
                cfg.intermediate_size,
                &cfg.quantization_config,
                cfg.hidden_act,
                comm,
            )?)
        };

        Ok(Self {
            input_layernorm,
            post_attention_layernorm,
            attn,
            moe_or_mlp,
        })
    }

    /// Pre-norm residual forward: `x + attn(norm(x))`, then `x + ffn(norm(x))`.
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = self.input_layernorm.forward(xs)?;
        let xs = self.attn.forward(
            &xs,
            attention_mask,
            seqlen_offsets,
            kv_cache,
            metadata,
            flash_params,
        )?;
        let xs = (xs + residual)?;
        let residual = &xs;
        let xs = self
            .moe_or_mlp
            .forward(&xs.apply(&self.post_attention_layernorm)?)?;
        residual + xs
    }
}
760
/// DeepSeek-V2 decoder-only model: token embedding, stacked decoder layers,
/// final RmsNorm, and LM head (optionally tied to the embedding).
pub struct DeepSeekV2 {
    lm_head: Arc<dyn QuantMethod>,
    embed_tokens: Embedding,
    norm: RmsNorm,
    layers: Vec<DecoderLayer>,
    cache: EitherCache,
    device: Device,
    max_seq_len: usize,
    /// Metadata consumed by the pipeline (head counts, cache dims, etc.).
    cfg: ModelConfigMetadata,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
}
772
773impl DeepSeekV2 {
    /// Load the full model from `vb`, mapping layers across devices via
    /// `normal_loading_metadata.mapper` and building one rotary-embedding
    /// instance per distinct device.
    pub fn new(
        cfg: &DeepSeekV2Config,
        vb: ShardedVarBuilder,
        _is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        let vb_m = vb.pp("model");

        let mapper = normal_loading_metadata.mapper;

        let embed_tokens = embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
        )?;
        // Tied embeddings reuse the embedding matrix as the LM head weight.
        let lm_head = if !cfg.tie_word_embeddings {
            ReplicatedLayer::new(
                cfg.hidden_size,
                cfg.vocab_size,
                &None,
                false,
                mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
            )?
        } else {
            ReplicatedLayer::from_linear(candle_nn::Linear::new(
                mapper.cast_nm_device(
                    embed_tokens.embeddings(),
                    normal_loading_metadata.loading_isq,
                )?,
                None,
            ))?
        };
        let norm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb_m.pp("norm"), false),
        )?;

        // Precompute one RoPE table per device location so each layer can
        // share the table local to its device.
        let mut ropes = HashMap::new();
        let rope_cfg = DeepSeekV2RopeConfig {
            rope_scaling: cfg.rope_scaling.clone(),
            max_position_embeddings: cfg.max_position_embeddings,
            rope_theta: cfg.rope_theta,
            qk_rope_head_dim: cfg.qk_rope_head_dim,
        };
        for i in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(i, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(DeepSeekV2RotaryEmbedding::new(
                    &rope_cfg,
                    vb.dtype(),
                    device,
                )?),
            );
        }

        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        let vb_l = vb_m.pp("layers");
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        ) {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            let paged_attn = match &attention_mechanism {
                AttentionImplementation::Eager => None,
                AttentionImplementation::PagedAttention => Some(
                    PagedAttention::new(cfg.v_head_dim, device, None)
                        .expect("Failed to create PagedAttention"),
                ),
            };
            let comm = mapper.get_comm_for(layer_idx)?;
            let layer = DecoderLayer::new(
                rotary_emb.clone(),
                cfg,
                vb_l.pp(layer_idx),
                &*mapper,
                layer_idx,
                normal_loading_metadata.loading_isq,
                paged_attn,
                &comm,
            )?;
            layers.push(layer)
        }

        Ok(Self {
            lm_head,
            embed_tokens,
            norm,
            layers,
            cache: EitherCache::Normal(NormalCache::new(
                cfg.num_hidden_layers,
                cfg.max_position_embeddings,
            )),
            device: normal_loading_metadata.real_device.clone(),
            max_seq_len: cfg.max_position_embeddings,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                // Heads local to rank 0's communicator; MLA has no KV grouping
                // here, so KV heads mirror attention heads.
                num_kv_heads: (cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size())
                    .max(1),
                num_attn_heads: (cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size())
                    .max(1),
                sliding_window: None,
                k_head_dim: cfg.q_head_dim(),
                // Paged attention stores v zero-padded to q_head_dim (see
                // Attention::forward), so advertise that width for the cache.
                v_head_dim: if matches!(
                    attention_mechanism,
                    AttentionImplementation::PagedAttention
                ) {
                    cfg.q_head_dim()
                } else {
                    cfg.v_head_dim
                },
            },
            mapper,
        })
    }
902
    /// Full forward pass: token embedding -> decoder layer stack (with device
    /// mapping) -> final RMS norm -> LM head, returning the selected logits.
    ///
    /// * `input_ids` - token ids; embedded and also used to size the causal mask.
    /// * `seqlen_offsets` - per-sequence KV offsets; doubles as the past-KV-length
    ///   source for mask construction when paged-attention `metadata` is present.
    /// * `context_lens` - forwarded verbatim to `extract_logits` to pick which
    ///   positions' logits are returned.
    /// * `metadata` - per-layer paged-attention KV caches plus input metadata;
    ///   `None` selects the eager path backed by the internal `NormalCache`.
    /// * `flash_params` - flash-attention parameters forwarded to every layer.
    #[allow(clippy::too_many_arguments)]
    pub fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut xs = self.embed_tokens.forward(input_ids)?;
        // NOTE(review): mutable access despite `&self` — presumably
        // `EitherCache::normal` goes through interior locking; confirm there.
        let cache = &mut self.cache.normal().0;
        // With paged-attention metadata, the past length for the mask comes from
        // `seqlen_offsets`; otherwise it is derived from the local KV cache.
        let attention_mask = CausalMasker.make_causal_mask_matrix(
            input_ids,
            metadata
                .as_ref()
                .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
                .unwrap_or(cache as &dyn PastKvLenCache),
            xs.dtype(),
            self.cfg.num_attn_heads,
        )?;
        // PagedAttention prompt chunking: keep the mask only for the first chunk
        // of a chunked prompt (or on the eager path); later chunks drop it.
        let attention_mask = attention_mask.filter(|_| {
            metadata
                .as_ref()
                .map(|(_, meta)| meta.is_first_prompt_chunk)
                .unwrap_or(true)
        });
        for (i, layer) in self.layers.iter().enumerate() {
            // Move activations to the device this layer is mapped to.
            xs = self.mapper.map(xs, i)?;
            xs = layer.forward(
                &xs,
                // Mask follows the activations' device; transfer treated as infallible.
                attention_mask
                    .as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                &mut cache[i],
                metadata
                    .as_ref()
                    .map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
                flash_params,
            )?;
        }
        // Bring activations back to the primary device for the norm + head.
        let xs = xs.to_device(&self.device)?;
        let xs = xs.apply(&self.norm)?;
        extract_logits(&self.lm_head.forward_autocast(&xs)?, context_lens)
    }
950}
951
952impl IsqModel for DeepSeekV2 {
953    fn get_layers(
954        &mut self,
955    ) -> (
956        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
957        &dyn DeviceMapper,
958    ) {
959        let mut tensors = Vec::new();
960        tensors.push((&mut self.lm_head, None));
961        for (i, layer) in self.layers.iter_mut().enumerate() {
962            match &mut layer.attn.q {
963                QProj::Plain(q) => {
964                    tensors.push((q, Some(i)));
965                }
966                QProj::Lora { a, norm: _, b } => {
967                    tensors.push((a, Some(i)));
968                    tensors.push((b, Some(i)));
969                }
970            }
971            tensors.push((&mut layer.attn.kv_a_proj_with_mqa, Some(i)));
972            tensors.push((&mut layer.attn.kv_b_proj, Some(i)));
973            tensors.push((&mut layer.attn.o_proj, Some(i)));
974            match &mut layer.moe_or_mlp {
975                MoeOrMlp::Mlp(mlp) => {
976                    tensors.push((&mut mlp.gate, Some(i)));
977                    tensors.push((&mut mlp.up, Some(i)));
978                    tensors.push((&mut mlp.down, Some(i)));
979                }
980                MoeOrMlp::Moe(moe) => {
981                    for mlp in moe.experts.iter_mut().filter_map(|e| e.as_mut()) {
982                        tensors.push((&mut mlp.gate, Some(i)));
983                        tensors.push((&mut mlp.up, Some(i)));
984                        tensors.push((&mut mlp.down, Some(i)));
985                    }
986                    if let Some(mlp) = &mut moe.shared_experts {
987                        tensors.push((&mut mlp.gate, Some(i)));
988                        tensors.push((&mut mlp.up, Some(i)));
989                        tensors.push((&mut mlp.down, Some(i)));
990                    }
991                }
992            }
993        }
994        (tensors, &*self.mapper)
995    }
996
997    fn get_layers_moe_experts_only(
998        &mut self,
999    ) -> (
1000        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
1001        &dyn DeviceMapper,
1002    ) {
1003        let mut tensors = Vec::new();
1004        tensors.push((&mut self.lm_head, None));
1005        for (i, layer) in self.layers.iter_mut().enumerate() {
1006            match &mut layer.moe_or_mlp {
1007                MoeOrMlp::Mlp(mlp) => {
1008                    tensors.push((&mut mlp.gate, Some(i)));
1009                    tensors.push((&mut mlp.up, Some(i)));
1010                    tensors.push((&mut mlp.down, Some(i)));
1011                }
1012                MoeOrMlp::Moe(moe) => {
1013                    for mlp in moe.experts.iter_mut().filter_map(|e| e.as_mut()) {
1014                        tensors.push((&mut mlp.gate, Some(i)));
1015                        tensors.push((&mut mlp.up, Some(i)));
1016                        tensors.push((&mut mlp.down, Some(i)));
1017                    }
1018                    if let Some(mlp) = &mut moe.shared_experts {
1019                        tensors.push((&mut mlp.gate, Some(i)));
1020                        tensors.push((&mut mlp.up, Some(i)));
1021                        tensors.push((&mut mlp.down, Some(i)));
1022                    }
1023                }
1024            }
1025        }
1026        (tensors, &*self.mapper)
1027    }
1028
1029    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
1030        let uvb = UnVarBuilder::new();
1031
1032        let uvb_m = uvb.pp("model");
1033        uvb_m.pp("embed_tokens").add(&self.embed_tokens);
1034        uvb_m.pp("norm").add(&self.norm);
1035
1036        for (layer_idx, layer) in self.layers.iter().enumerate() {
1037            let uvb_l = uvb_m.pp("layers").pp(layer_idx);
1038            uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
1039            uvb_l
1040                .pp("post_attention_layernorm")
1041                .add(&layer.post_attention_layernorm);
1042
1043            uvb_l
1044                .pp("self_attn")
1045                .pp("kv_a_layernorm")
1046                .add(&layer.attn.kv_a_layernorm);
1047
1048            match &layer.moe_or_mlp {
1049                MoeOrMlp::Moe(moe) => {
1050                    uvb_l
1051                        .pp("mlp")
1052                        .pp("gate")
1053                        .add_tensor("weight", moe.gate.weight.clone());
1054                }
1055                MoeOrMlp::Mlp(_) => (),
1056            }
1057
1058            match &layer.attn.q {
1059                QProj::Plain(_) => (),
1060                QProj::Lora { a: _, norm, b: _ } => {
1061                    uvb_l.pp("self_attn").pp("q_a_layernorm").add(norm);
1062                }
1063            }
1064        }
1065
1066        uvb.to_safetensors()
1067    }
1068
1069    fn residual_tensors_moe_experts_only(&self) -> Option<Vec<(String, Tensor)>> {
1070        let uvb = UnVarBuilder::new();
1071
1072        let uvb_m = uvb.pp("model");
1073        uvb_m.pp("embed_tokens").add(&self.embed_tokens);
1074        uvb_m.pp("norm").add(&self.norm);
1075
1076        for (layer_idx, layer) in self.layers.iter().enumerate() {
1077            let uvb_l = uvb_m.pp("layers").pp(layer_idx);
1078            uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
1079            uvb_l
1080                .pp("post_attention_layernorm")
1081                .add(&layer.post_attention_layernorm);
1082
1083            uvb_l
1084                .pp("self_attn")
1085                .pp("kv_a_layernorm")
1086                .add(&layer.attn.kv_a_layernorm);
1087
1088            match &layer.moe_or_mlp {
1089                MoeOrMlp::Moe(moe) => {
1090                    uvb_l
1091                        .pp("mlp")
1092                        .pp("gate")
1093                        .add_tensor("weight", moe.gate.weight.clone());
1094                }
1095                MoeOrMlp::Mlp(_) => (),
1096            }
1097
1098            match &layer.attn.q {
1099                QProj::Plain(q) => {
1100                    uvb_l.pp("self_attn").pp("q_proj").add(q);
1101                }
1102                QProj::Lora { a, norm, b } => {
1103                    uvb_l.pp("self_attn").pp("q_a_proj").add(a);
1104                    uvb_l.pp("self_attn").pp("q_a_layernorm").add(norm);
1105                    uvb_l.pp("self_attn").pp("q_b_proj").add(b);
1106                }
1107            }
1108            uvb_l
1109                .pp("self_attn")
1110                .pp("kv_a_proj_with_mqa")
1111                .add(&layer.attn.kv_a_proj_with_mqa);
1112            uvb_l
1113                .pp("self_attn")
1114                .pp("kv_b_proj")
1115                .add(&layer.attn.kv_b_proj);
1116            uvb_l.pp("self_attn").pp("o_proj").add(&layer.attn.o_proj);
1117        }
1118
1119        Some(uvb.to_safetensors())
1120    }
1121}
1122
1123impl NormalModel for DeepSeekV2 {
1124    fn forward(
1125        &self,
1126        input_ids: &Tensor,
1127        seqlen_offsets: &[usize],
1128        context_lens: Vec<(usize, usize)>,
1129        _position_ids: Vec<usize>,
1130        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
1131        flash_params: &FlashParams,
1132    ) -> Result<Tensor> {
1133        self.forward(
1134            input_ids,
1135            seqlen_offsets,
1136            context_lens,
1137            metadata,
1138            flash_params,
1139        )
1140    }
1141    fn xlora_forward(
1142        &self,
1143        _input_ids: &Tensor,
1144        _input_ids_full: &Tensor,
1145        _seqlen_offsets: &[usize],
1146        _seqlen_offsets_full: &[usize],
1147        _no_kv_cache: bool,
1148        _non_granular_state: &Option<crate::xlora_models::NonGranularState>,
1149        _context_lens: Vec<(usize, usize)>,
1150        _position_ids: Vec<usize>,
1151        _flash_params: &FlashParams,
1152        _flash_params_full: &FlashParams,
1153    ) -> Result<Tensor> {
1154        unimplemented!()
1155    }
1156    fn cache(&self) -> &EitherCache {
1157        &self.cache
1158    }
1159    fn cache_mut(&mut self) -> &mut EitherCache {
1160        &mut self.cache
1161    }
1162    fn device(&self) -> &Device {
1163        &self.device
1164    }
1165    fn is_xlora(&self) -> bool {
1166        false
1167    }
1168    fn max_seq_len(&self) -> usize {
1169        self.max_seq_len
1170    }
1171    fn config(&self) -> &ModelConfigMetadata {
1172        &self.cfg
1173    }
1174}
1175
// Marker impl: relies entirely on the trait's default implementations.
impl AnyMoeBaseModelMixin for DeepSeekV2 {}