mistralrs_core/models/
qwen2.rs

1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use candle_core::{DType, Device, Module, Result, Tensor};
4use mistralrs_quant::{
5    ColumnParallelLayer, QuantMethod, QuantizedConfig, ReplicatedLayer, RowParallelLayer,
6    ShardedVarBuilder,
7};
8use std::{collections::HashMap, sync::Arc};
9
10use crate::{
11    amoe::{AnyMoeBaseModelMixin, AnyMoeConfig, AnyMoeExpertType, MlpLayer, MoeMlp},
12    attention::SdpaParams,
13    device_map::DeviceMapper,
14    get_delta_from_lora_ab,
15    layers::{embedding, Activation, CausalMasker, MatMul, Mlp, RmsNorm, RotaryEmbedding, Sdpa},
16    layers_masker::PastKvLenCache,
17    paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
18    pipeline::{
19        extract_logits,
20        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
21        EitherCache, IsqModel, KvCache, NormalCache, NormalLoadingMetadata, NormalModel,
22    },
23    serde_default_fn,
24    utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
25};
26
27serde_default_fn!(bool, word_emb_default, false);
28serde_default_fn!(bool, use_flash_attn, false);
29
/// Qwen2 model hyper-parameters (serde-deserializable).
///
/// Field order is preserved as-is because it is also the serialization order.
#[derive(Debug, Clone, serde::Deserialize, Default, serde::Serialize)]
pub struct Config {
    /// Vocabulary size: rows of `embed_tokens` and output width of `lm_head`.
    pub vocab_size: usize,
    /// Model width; the per-head dimension is derived as `hidden_size / num_attention_heads`.
    pub hidden_size: usize,
    /// Inner width of each decoder layer's MLP.
    pub intermediate_size: usize,
    /// Number of decoder layers.
    pub num_hidden_layers: usize,
    /// Number of query heads (before tensor-parallel division).
    pub num_attention_heads: usize,
    /// Number of key/value heads; query heads are grouped over these via `compute_n_kv_groups`.
    pub num_key_value_heads: usize,
    /// Length of the RoPE tables; also used for the `NormalCache` size and `max_seq_len`.
    pub max_position_embeddings: usize,
    /// Optional sliding-window size passed to attention and to causal-mask construction.
    /// NOTE(review): applied whenever present; this struct has no separate on/off flag —
    /// confirm that matches the checkpoint configs being loaded.
    pub sliding_window: Option<usize>,
    /// RoPE base frequency (narrowed to `f32` when the tables are built).
    pub rope_theta: f64,
    /// Epsilon used by every `RmsNorm` in the model.
    pub rms_norm_eps: f64,
    /// Activation used inside the MLP.
    pub hidden_act: Activation,
    /// Forwarded to `SdpaParams::use_flash_attn`; defaults to `false` when absent.
    #[serde(default = "use_flash_attn")]
    pub use_flash_attn: bool,
    /// Optional quantization config; logged at load time and passed to the embedding and
    /// attention/MLP projection layers.
    pub quantization_config: Option<QuantizedConfig>,
    /// When `true`, `lm_head` reuses the `embed_tokens` matrix; defaults to `false`.
    #[serde(default = "word_emb_default")]
    pub tie_word_embeddings: bool,
}
49
50struct Attention {
51    q_proj: Arc<dyn QuantMethod>,
52    k_proj: Arc<dyn QuantMethod>,
53    v_proj: Arc<dyn QuantMethod>,
54    o_proj: Arc<dyn QuantMethod>,
55    num_heads: usize,
56    num_kv_heads: usize,
57    head_dim: usize,
58    rotary_emb: Arc<RotaryEmbedding>,
59    paged_attn: Option<PagedAttention>,
60    sdpa_params: SdpaParams,
61}
62
63impl Attention {
64    fn new(
65        rotary_emb: Arc<RotaryEmbedding>,
66        cfg: &Config,
67        vb: ShardedVarBuilder,
68        paged_attn: Option<PagedAttention>,
69        comm: &Arc<mistralrs_quant::Comm>,
70    ) -> Result<Self> {
71        let hidden_sz = cfg.hidden_size;
72        let num_heads = cfg.num_attention_heads;
73        let num_kv_heads = cfg.num_key_value_heads;
74        let head_dim = hidden_sz / num_heads;
75        let q_proj = ColumnParallelLayer::new(
76            hidden_sz,
77            num_heads * head_dim,
78            &cfg.quantization_config,
79            true,
80            comm,
81            vb.pp("q_proj"),
82        )?;
83        let kv_shard = mistralrs_quant::compute_kv_shard(
84            cfg.num_key_value_heads,
85            cfg.hidden_size / cfg.num_attention_heads,
86            comm,
87        );
88        let k_proj = ColumnParallelLayer::new_with_shard(
89            hidden_sz,
90            num_kv_heads * head_dim,
91            &cfg.quantization_config,
92            true,
93            comm,
94            kv_shard,
95            vb.pp("k_proj"),
96        )?;
97        let v_proj = ColumnParallelLayer::new_with_shard(
98            hidden_sz,
99            num_kv_heads * head_dim,
100            &cfg.quantization_config,
101            true,
102            comm,
103            kv_shard,
104            vb.pp("v_proj"),
105        )?;
106        let o_proj = RowParallelLayer::new(
107            num_heads * head_dim,
108            hidden_sz,
109            &cfg.quantization_config,
110            false,
111            comm,
112            vb.pp("o_proj"),
113        )?;
114        Ok(Self {
115            q_proj,
116            k_proj,
117            v_proj,
118            o_proj,
119            num_heads: num_heads / comm.world_size(),
120            num_kv_heads: (num_kv_heads / comm.world_size()).max(1),
121            head_dim,
122            rotary_emb,
123            paged_attn,
124            sdpa_params: SdpaParams {
125                n_kv_groups: mistralrs_quant::compute_n_kv_groups(
126                    cfg.num_key_value_heads,
127                    cfg.num_attention_heads,
128                    comm,
129                ),
130                use_flash_attn: cfg.use_flash_attn,
131                softcap: None,
132                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
133                sliding_window: cfg.sliding_window,
134            },
135        })
136    }
137
138    #[allow(clippy::too_many_arguments)]
139    fn forward(
140        &self,
141        xs: &Tensor,
142        attention_mask: Option<&Tensor>,
143        seqlen_offsets: &[usize],
144        kv_cache: &mut KvCache,
145        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
146        flash_params: &FlashParams,
147    ) -> Result<Tensor> {
148        let (b_sz, q_len, _) = xs.dims3()?;
149
150        let original_dtype = xs.dtype();
151        let mut xs = xs.clone();
152        if let Some(t) = self.q_proj.quantized_act_type() {
153            xs = xs.to_dtype(t)?;
154        }
155        let mut q = MatMul.qmethod_matmul(&xs, &*self.q_proj)?;
156        let mut k = MatMul.qmethod_matmul(&xs, &*self.k_proj)?;
157        let mut v = MatMul.qmethod_matmul(&xs, &*self.v_proj)?;
158        if self.q_proj.quantized_act_type().is_some() {
159            q = q.to_dtype(original_dtype)?;
160            k = k.to_dtype(original_dtype)?;
161            v = v.to_dtype(original_dtype)?;
162        }
163
164        let (q, k, v) = if q_len != 1 {
165            let q = q
166                .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
167                .transpose(1, 2)?;
168            let k = k
169                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
170                .transpose(1, 2)?;
171            let v = v
172                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
173                .transpose(1, 2)?;
174            (q, k, v)
175        } else {
176            let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
177            let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
178            let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
179            (q, k, v)
180        };
181
182        let (q, k) = self.rotary_emb.forward(&q, &k, seqlen_offsets)?;
183
184        let mut attn_output = match &self.paged_attn {
185            Some(paged_attn) => match metadata {
186                Some(((key_cache, value_cache), input_metadata)) => paged_attn.forward(
187                    &q,
188                    &k,
189                    &v,
190                    attention_mask,
191                    Some(key_cache),
192                    Some(value_cache),
193                    input_metadata,
194                    &self.sdpa_params,
195                    Some(flash_params),
196                )?,
197                None => {
198                    // If we don't have metadata, we are most likely generating an imatrix so we don't want to populate that.
199                    // Generating the dummy metadata with the assumption that we are not generating text (only processing prompts).
200                    let input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
201                    // Sanity check.
202                    assert!(attention_mask.is_some());
203                    paged_attn.forward(
204                        &q,
205                        &k,
206                        &v,
207                        attention_mask,
208                        None,
209                        None,
210                        &input_metadata,
211                        &self.sdpa_params,
212                        Some(flash_params),
213                    )?
214                }
215            },
216            None => {
217                let (k, v) = kv_cache.append(&k, &v)?;
218
219                Sdpa.run_attention(
220                    &q,
221                    &k,
222                    &v,
223                    attention_mask,
224                    Some(flash_params),
225                    &self.sdpa_params,
226                )?
227            }
228        };
229
230        if let Some(t) = self.q_proj.quantized_act_type() {
231            attn_output = attn_output.to_dtype(t)?;
232        }
233        attn_output = if attention_mask.is_some() {
234            attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?
235        } else {
236            attn_output.reshape((b_sz, q_len, ()))?
237        };
238        let mut res = MatMul.qmethod_matmul(&attn_output, &*self.o_proj)?;
239        if self.q_proj.quantized_act_type().is_some() {
240            res = res.to_dtype(original_dtype)?;
241        }
242        Ok(res)
243    }
244}
245
246struct DecoderLayer {
247    self_attn: Attention,
248    mlp: Box<dyn MlpLayer>,
249    input_layernorm: RmsNorm,
250    post_attention_layernorm: RmsNorm,
251}
252
253impl DecoderLayer {
254    #[allow(clippy::too_many_arguments)]
255    fn new(
256        rotary_emb: Arc<RotaryEmbedding>,
257        cfg: &Config,
258        vb: ShardedVarBuilder,
259        mapper: &dyn DeviceMapper,
260        layer_idx: usize,
261        loading_isq: bool,
262        paged_attn: Option<PagedAttention>,
263        comm: &Arc<mistralrs_quant::Comm>,
264    ) -> Result<Self> {
265        let self_attn = Attention::new(
266            rotary_emb,
267            cfg,
268            mapper.set_device(layer_idx, vb.pp("self_attn"), loading_isq),
269            paged_attn,
270            comm,
271        )?;
272        let mlp = Mlp::new(
273            mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq),
274            cfg.hidden_size,
275            cfg.intermediate_size,
276            &cfg.quantization_config,
277            cfg.hidden_act,
278            comm,
279        )?;
280        let input_layernorm = RmsNorm::new(
281            cfg.hidden_size,
282            cfg.rms_norm_eps,
283            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
284        )?;
285        let post_attention_layernorm = RmsNorm::new(
286            cfg.hidden_size,
287            cfg.rms_norm_eps,
288            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
289        )?;
290        Ok(Self {
291            self_attn,
292            mlp: Box::new(mlp),
293            input_layernorm,
294            post_attention_layernorm,
295        })
296    }
297
298    #[allow(clippy::too_many_arguments)]
299    fn forward(
300        &self,
301        xs: &Tensor,
302        attention_mask: Option<&Tensor>,
303        seqlen_offsets: &[usize],
304        kv_cache: &mut KvCache,
305        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
306        flash_params: &FlashParams,
307    ) -> Result<Tensor> {
308        let residual = xs;
309        let xs = self.input_layernorm.forward(xs)?;
310        let xs = self.self_attn.forward(
311            &xs,
312            attention_mask,
313            seqlen_offsets,
314            kv_cache,
315            metadata,
316            flash_params,
317        )?;
318        let xs = (xs + residual)?;
319        let residual = &xs;
320        let xs = self
321            .mlp
322            .forward(&xs.apply(&self.post_attention_layernorm)?)?;
323        residual + xs
324    }
325}
326
/// Qwen2 causal language model: token embedding, decoder stack, final norm and LM head.
///
/// Field order is kept as-is (it also fixes drop order).
pub struct Model {
    embed_tokens: candle_nn::Embedding,
    layers: Vec<DecoderLayer>,
    /// Final RMS norm applied before `lm_head`.
    norm: RmsNorm,
    /// Output projection: either its own `lm_head` weights or the tied embedding matrix.
    lm_head: Arc<dyn QuantMethod>,
    /// Sliding window used when building the causal mask (copied from `Config`).
    sliding_window: Option<usize>,
    /// Device the final norm and LM head run on; hidden states are moved here after the
    /// last decoder layer.
    device: Device,
    /// Per-layer KV caches, used by the eager attention path and for past-length lookups.
    cache: EitherCache,
    max_seq_len: usize,
    /// Maps layers to devices and supplies per-layer communicators.
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    /// Metadata reported through `NormalModel::config`; head counts are per rank.
    cfg: ModelConfigMetadata,
}
339
340impl Model {
341    pub fn new(
342        cfg: &Config,
343        vb: ShardedVarBuilder,
344        is_gptx: bool,
345        normal_loading_metadata: NormalLoadingMetadata,
346        attention_mechanism: AttentionImplementation,
347    ) -> Result<Self> {
348        if let Some(ref quant_cfg) = &cfg.quantization_config {
349            tracing::info!(
350                "Using {} quantization: {}.",
351                quant_cfg.name(),
352                quant_cfg.get_bits_name(&vb)
353            );
354        }
355        let mapper = normal_loading_metadata.mapper;
356        let vb_m = vb.pp("model");
357
358        let embed_tokens = embedding(
359            cfg.vocab_size,
360            cfg.hidden_size,
361            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
362            &cfg.quantization_config,
363        )?;
364        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
365        let head_dim = cfg.hidden_size / cfg.num_attention_heads;
366
367        let mut ropes = HashMap::new();
368        for layer_idx in 0..cfg.num_hidden_layers {
369            let device = mapper
370                .device_for(layer_idx, false)
371                .unwrap_or(&normal_loading_metadata.real_device);
372            ropes.insert(
373                device.location(),
374                Arc::new(RotaryEmbedding::new(
375                    cfg.rope_theta as f32,
376                    head_dim,
377                    cfg.max_position_embeddings,
378                    device,
379                    is_gptx,
380                    vb_m.dtype(),
381                )?),
382            );
383        }
384
385        let vb_l = vb_m.pp("layers");
386        for layer_idx in NiceProgressBar::<_, 'b'>(
387            0..cfg.num_hidden_layers,
388            "Loading repeating layers",
389            &normal_loading_metadata.multi_progress,
390        ) {
391            let device = mapper
392                .device_for(layer_idx, false)
393                .unwrap_or(&normal_loading_metadata.real_device);
394            let rotary_emb = ropes
395                .get(&device.location())
396                .expect("No RoPE for device location!")
397                .clone();
398            let paged_attn = match &attention_mechanism {
399                AttentionImplementation::Eager => None,
400                AttentionImplementation::PagedAttention => {
401                    Some(PagedAttention::new(head_dim, device, None)?)
402                }
403            };
404            let comm = mapper.get_comm_for(layer_idx)?;
405            let layer = DecoderLayer::new(
406                rotary_emb.clone(),
407                cfg,
408                vb_l.pp(layer_idx),
409                &*mapper,
410                layer_idx,
411                normal_loading_metadata.loading_isq,
412                paged_attn,
413                &comm,
414            )?;
415            layers.push(layer)
416        }
417        let norm = RmsNorm::new(
418            cfg.hidden_size,
419            cfg.rms_norm_eps,
420            mapper.set_nm_device(vb_m.pp("norm"), false),
421        )?;
422        let lm_head = if !cfg.tie_word_embeddings {
423            ReplicatedLayer::new(
424                cfg.hidden_size,
425                cfg.vocab_size,
426                &None,
427                false,
428                mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
429            )?
430        } else {
431            ReplicatedLayer::from_linear(candle_nn::Linear::new(
432                mapper.cast_nm_device(
433                    embed_tokens.embeddings(),
434                    normal_loading_metadata.loading_isq,
435                )?,
436                None,
437            ))?
438        };
439        Ok(Self {
440            embed_tokens,
441            layers,
442            norm,
443            lm_head,
444            sliding_window: cfg.sliding_window,
445            device: normal_loading_metadata.real_device,
446            cache: EitherCache::Normal(NormalCache::new(
447                cfg.num_hidden_layers,
448                cfg.max_position_embeddings,
449            )),
450            max_seq_len: cfg.max_position_embeddings,
451            cfg: ModelConfigMetadata {
452                max_seq_len: cfg.max_position_embeddings,
453                num_layers: cfg.num_hidden_layers,
454                hidden_size: cfg.hidden_size,
455                num_attn_heads: cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size(),
456                num_kv_heads: (cfg.num_key_value_heads / mapper.get_comm_for(0)?.world_size())
457                    .max(1),
458                sliding_window: cfg.sliding_window,
459                k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
460                v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
461            },
462            mapper,
463        })
464    }
465
466    pub fn get_input_embeddings(&self, input_ids: &Tensor) -> Result<Tensor> {
467        self.embed_tokens.forward(input_ids)
468    }
469
470    pub fn forward(
471        &self,
472        input_ids: &Tensor,
473        seqlen_offsets: &[usize],
474        context_lens: Vec<(usize, usize)>,
475        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
476        flash_params: &FlashParams,
477    ) -> Result<Tensor> {
478        let xs = self.embed_tokens.forward(input_ids)?;
479        self.forward_embed(
480            input_ids,
481            xs,
482            seqlen_offsets,
483            context_lens,
484            metadata,
485            flash_params,
486        )
487    }
488
489    #[allow(clippy::too_many_arguments)]
490    pub fn forward_embed(
491        &self,
492        input_ids: &Tensor,
493        mut xs: Tensor,
494        seqlen_offsets: &[usize],
495        context_lens: Vec<(usize, usize)>,
496        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
497        flash_params: &FlashParams,
498    ) -> Result<Tensor> {
499        let cache = &mut self.cache.normal().0;
500        let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
501            input_ids,
502            metadata
503                .as_ref()
504                .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
505                .unwrap_or(cache as &dyn PastKvLenCache),
506            self.sliding_window,
507            xs.dtype(),
508            self.cfg.num_attn_heads,
509        )?;
510        let attention_mask = attention_mask.filter(|_| {
511            metadata
512                .as_ref()
513                .map(|(_, meta)| meta.is_first_prompt_chunk)
514                .unwrap_or(true)
515        });
516        for (i, layer) in self.layers.iter().enumerate() {
517            xs = self.mapper.map(xs, i)?;
518            xs = layer.forward(
519                &xs,
520                attention_mask
521                    .as_ref()
522                    .map(|m| m.to_device(xs.device()).unwrap())
523                    .as_ref(),
524                seqlen_offsets,
525                &mut cache[i],
526                metadata
527                    .as_ref()
528                    .map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
529                flash_params,
530            )?
531        }
532        let xs = xs.to_device(&self.device)?;
533        let mut xs = xs.apply(&self.norm)?;
534        if let Some(t) = self.lm_head.quantized_act_type() {
535            xs = xs.to_dtype(t)?;
536        }
537        extract_logits(&MatMul.qmethod_matmul(&xs, &*self.lm_head)?, context_lens)
538    }
539
540    pub fn embed_dtype(&self) -> DType {
541        self.embed_tokens.embeddings().dtype()
542    }
543}
544
545impl IsqModel for Model {
546    fn get_layers(
547        &mut self,
548    ) -> (
549        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
550        &dyn DeviceMapper,
551    ) {
552        let mut tensors = Vec::new();
553        tensors.push((&mut self.lm_head, None));
554        for (i, layer) in self.layers.iter_mut().enumerate() {
555            tensors.push((&mut layer.self_attn.q_proj, Some(i)));
556            tensors.push((&mut layer.self_attn.k_proj, Some(i)));
557            tensors.push((&mut layer.self_attn.v_proj, Some(i)));
558            tensors.push((&mut layer.self_attn.o_proj, Some(i)));
559            tensors.extend(
560                layer
561                    .mlp
562                    .get_isq_layers()
563                    .into_iter()
564                    .map(|m| (m, Some(i)))
565                    .collect::<Vec<_>>(),
566            );
567        }
568        (tensors, &*self.mapper)
569    }
570
571    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
572        let uvb = UnVarBuilder::new();
573
574        let uvb_m = uvb.pp("model");
575        uvb_m.pp("embed_tokens").add(&self.embed_tokens);
576        uvb_m.pp("norm").add(&self.norm);
577
578        for (layer_idx, layer) in self.layers.iter().enumerate() {
579            let uvb_l = uvb_m.pp("layers").pp(layer_idx);
580            uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
581            uvb_l
582                .pp("post_attention_layernorm")
583                .add(&layer.post_attention_layernorm);
584        }
585
586        uvb.to_safetensors()
587    }
588
589    fn imatrix_names(&self) -> candle_core::Result<Vec<Option<String>>> {
590        // NOTE: dependant on the exact implementation in get_layers!
591        let mut names = Vec::new();
592        // lm_head
593        names.push(None);
594        for i in 0..self.layers.len() {
595            names.push(Some(format!("blk.{i}.attn_q.weight")));
596            names.push(Some(format!("blk.{i}.attn_k.weight")));
597            names.push(Some(format!("blk.{i}.attn_v.weight")));
598            names.push(Some(format!("blk.{i}.attn_output.weight")));
599            names.push(Some(format!("blk.{i}.ffn_gate.weight")));
600            names.push(Some(format!("blk.{i}.ffn_up.weight")));
601            names.push(Some(format!("blk.{i}.ffn_down.weight")));
602        }
603        Ok(names)
604    }
605}
606
607impl NormalModel for Model {
608    fn forward(
609        &self,
610        input_ids: &Tensor,
611        seqlen_offsets: &[usize],
612        context_lens: Vec<(usize, usize)>,
613        _position_ids: Vec<usize>,
614        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
615        flash_params: &FlashParams,
616    ) -> Result<Tensor> {
617        self.forward(
618            input_ids,
619            seqlen_offsets,
620            context_lens,
621            metadata,
622            flash_params,
623        )
624    }
625    fn xlora_forward(
626        &self,
627        _input_ids: &Tensor,
628        _input_ids_full: &Tensor,
629        _seqlen_offsets: &[usize],
630        _seqlen_offsets_full: &[usize],
631        _no_kv_cache: bool,
632        _non_granular_state: &Option<crate::xlora_models::NonGranularState>,
633        _context_lens: Vec<(usize, usize)>,
634        _position_ids: Vec<usize>,
635        _flash_params: &FlashParams,
636        _flash_params_full: &FlashParams,
637    ) -> Result<Tensor> {
638        unimplemented!()
639    }
640    fn cache(&self) -> &EitherCache {
641        &self.cache
642    }
643    fn cache_mut(&mut self) -> &mut EitherCache {
644        &mut self.cache
645    }
646    fn device(&self) -> &Device {
647        &self.device
648    }
649    fn is_xlora(&self) -> bool {
650        false
651    }
652    fn max_seq_len(&self) -> usize {
653        self.max_seq_len
654    }
655    fn config(&self) -> &ModelConfigMetadata {
656        &self.cfg
657    }
658}
659
660impl AnyMoeBaseModelMixin for Model {
661    fn get_mlps(&self) -> Vec<&dyn MlpLayer> {
662        let mut mlps = Vec::new();
663        for layer in &self.layers {
664            mlps.push(&*layer.mlp);
665        }
666        mlps
667    }
668    fn get_mlps_mut(&mut self) -> Vec<&mut Box<dyn MlpLayer>> {
669        let mut mlps = Vec::new();
670        for layer in &mut self.layers {
671            mlps.push(&mut layer.mlp);
672        }
673        mlps
674    }
675    fn create_anymoe_layers(
676        &mut self,
677        additional_vbs: Vec<ShardedVarBuilder>,
678        config: AnyMoeConfig,
679        (prefix, mlp): (String, String),
680        mut layers: Vec<usize>,
681        expert_type: AnyMoeExpertType,
682        gate_vb: Option<ShardedVarBuilder>,
683    ) -> Result<()> {
684        let mut experts: Vec<Vec<Box<dyn MlpLayer>>> = Vec::new();
685        if layers.is_empty() {
686            layers = (0..self.layers.len()).collect::<Vec<_>>();
687        }
688        for _ in 0..layers.len() {
689            experts.push(Vec::new());
690        }
691        for vb in additional_vbs {
692            let vb = vb.pp(&prefix);
693            for (layer, row) in experts.iter_mut().enumerate() {
694                if !layers.contains(&layer) {
695                    continue;
696                }
697
698                let intermediate_size = self.layers[layer].mlp.get_params()[1];
699                let hidden_size = self.layers[layer].mlp.get_params()[0];
700                match expert_type {
701                    AnyMoeExpertType::FineTuned => {
702                        let (dtype, device) = self.layers[layer].mlp.dtype_device();
703                        row.push(Box::new(Mlp::replicate(
704                            self.layers[layer].mlp.get_params(),
705                            vb.pp(layer).pp(&mlp).set_dtype(dtype).set_device(device),
706                            self.layers[layer].mlp.hidden_act(),
707                            &self.mapper.get_comm_for(layer)?,
708                        )?));
709                    }
710                    AnyMoeExpertType::LoraAdapter {
711                        rank,
712                        alpha,
713                        ref target_modules,
714                    } => {
715                        let vb_mlp = vb.pp(layer).pp(&mlp);
716
717                        let gate_proj_delta = if target_modules.contains(&"gate_proj".to_string()) {
718                            Some(get_delta_from_lora_ab!(
719                                vb_mlp,
720                                rank,
721                                alpha,
722                                (hidden_size, intermediate_size),
723                                "gate_proj"
724                            ))
725                        } else {
726                            None
727                        };
728                        let up_proj_delta = if target_modules.contains(&"up_proj".to_string()) {
729                            Some(get_delta_from_lora_ab!(
730                                vb_mlp,
731                                rank,
732                                alpha,
733                                (hidden_size, intermediate_size),
734                                "up_proj"
735                            ))
736                        } else {
737                            None
738                        };
739                        let down_proj_delta = if target_modules.contains(&"down_proj".to_string()) {
740                            Some(get_delta_from_lora_ab!(
741                                vb_mlp,
742                                rank,
743                                alpha,
744                                (intermediate_size, hidden_size),
745                                "down_proj"
746                            ))
747                        } else {
748                            None
749                        };
750
751                        row.push(self.layers[layer].mlp.new_added_delta(vec![
752                            gate_proj_delta,
753                            up_proj_delta,
754                            down_proj_delta,
755                        ])?);
756                    }
757                }
758            }
759        }
760        for (layer, expert) in layers.into_iter().zip(experts) {
761            let mut experts_all = vec![self.layers[layer].mlp.clone()];
762            experts_all.extend(expert);
763            let (dtype, device) = self.layers[layer].mlp.dtype_device();
764            self.layers[layer].mlp = Box::new(MoeMlp::new(
765                experts_all,
766                config.clone(),
767                dtype,
768                &device,
769                layer,
770                gate_vb.as_ref(),
771            )?);
772        }
773        Ok(())
774    }
775    fn amoe_supported(&self) -> bool {
776        true
777    }
778}