// mistralrs_core/models/qwen2.rs

1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use candle_core::{DType, Device, Module, Result, Tensor};
4use mistralrs_quant::{
5    ColumnParallelLayer, QuantMethod, QuantizedConfig, ReplicatedLayer, RowParallelLayer,
6    ShardedVarBuilder,
7};
8use std::{collections::HashMap, sync::Arc};
9
10use crate::{
11    amoe::{AnyMoeBaseModelMixin, AnyMoeConfig, AnyMoeExpertType, MlpLayer, MoeMlp},
12    attention::SdpaParams,
13    device_map::DeviceMapper,
14    get_delta_from_lora_ab,
15    layers::{embedding, Activation, CausalMasker, MatMul, Mlp, RmsNorm, RotaryEmbedding, Sdpa},
16    layers_masker::PastKvLenCache,
17    paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
18    pipeline::{
19        extract_logits,
20        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
21        EitherCache, IsqModel, KvCache, NormalCache, NormalLoadingMetadata, NormalModel,
22    },
23    serde_default_fn,
24    utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
25};
26
27serde_default_fn!(bool, word_emb_default, false);
28serde_default_fn!(bool, use_flash_attn, false);
29
30#[derive(Debug, Clone, serde::Deserialize, Default, serde::Serialize)]
31pub struct Config {
32    pub vocab_size: usize,
33    pub hidden_size: usize,
34    pub intermediate_size: usize,
35    pub num_hidden_layers: usize,
36    pub num_attention_heads: usize,
37    pub num_key_value_heads: usize,
38    pub max_position_embeddings: usize,
39    pub sliding_window: usize,
40    pub rope_theta: f64,
41    pub rms_norm_eps: f64,
42    pub hidden_act: Activation,
43    #[serde(default = "use_flash_attn")]
44    pub use_flash_attn: bool,
45    pub quantization_config: Option<QuantizedConfig>,
46    #[serde(default = "word_emb_default")]
47    pub tie_word_embeddings: bool,
48}
49
50struct Attention {
51    q_proj: Arc<dyn QuantMethod>,
52    k_proj: Arc<dyn QuantMethod>,
53    v_proj: Arc<dyn QuantMethod>,
54    o_proj: Arc<dyn QuantMethod>,
55    num_heads: usize,
56    num_kv_heads: usize,
57    head_dim: usize,
58    rotary_emb: Arc<RotaryEmbedding>,
59    paged_attn: Option<PagedAttention>,
60    sdpa_params: SdpaParams,
61}
62
63impl Attention {
64    fn new(
65        rotary_emb: Arc<RotaryEmbedding>,
66        cfg: &Config,
67        vb: ShardedVarBuilder,
68        paged_attn: Option<PagedAttention>,
69        comm: &Arc<mistralrs_quant::Comm>,
70    ) -> Result<Self> {
71        let hidden_sz = cfg.hidden_size;
72        let num_heads = cfg.num_attention_heads;
73        let num_kv_heads = cfg.num_key_value_heads;
74        let head_dim = hidden_sz / num_heads;
75        let q_proj = ColumnParallelLayer::new(
76            hidden_sz,
77            num_heads * head_dim,
78            &cfg.quantization_config,
79            true,
80            comm,
81            vb.pp("q_proj"),
82        )?;
83        let kv_shard = mistralrs_quant::compute_kv_shard(
84            cfg.num_key_value_heads,
85            cfg.hidden_size / cfg.num_attention_heads,
86            comm,
87        );
88        let k_proj = ColumnParallelLayer::new_with_shard(
89            hidden_sz,
90            num_kv_heads * head_dim,
91            &cfg.quantization_config,
92            true,
93            comm,
94            kv_shard,
95            vb.pp("k_proj"),
96        )?;
97        let v_proj = ColumnParallelLayer::new_with_shard(
98            hidden_sz,
99            num_kv_heads * head_dim,
100            &cfg.quantization_config,
101            true,
102            comm,
103            kv_shard,
104            vb.pp("v_proj"),
105        )?;
106        let o_proj = RowParallelLayer::new(
107            num_heads * head_dim,
108            hidden_sz,
109            &cfg.quantization_config,
110            false,
111            comm,
112            vb.pp("o_proj"),
113        )?;
114        Ok(Self {
115            q_proj,
116            k_proj,
117            v_proj,
118            o_proj,
119            num_heads: num_heads / comm.world_size(),
120            num_kv_heads: (num_kv_heads / comm.world_size()).max(1),
121            head_dim,
122            rotary_emb,
123            paged_attn,
124            sdpa_params: SdpaParams {
125                n_kv_groups: mistralrs_quant::compute_n_kv_groups(
126                    cfg.num_key_value_heads,
127                    cfg.num_attention_heads,
128                    comm,
129                ),
130                use_flash_attn: cfg.use_flash_attn,
131                softcap: None,
132                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
133                sliding_window: None,
134            },
135        })
136    }
137
138    #[allow(clippy::too_many_arguments)]
139    fn forward(
140        &self,
141        xs: &Tensor,
142        attention_mask: Option<&Tensor>,
143        seqlen_offsets: &[usize],
144        kv_cache: &mut KvCache,
145        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
146        flash_params: &FlashParams,
147    ) -> Result<Tensor> {
148        let (b_sz, q_len, _) = xs.dims3()?;
149
150        let original_dtype = xs.dtype();
151        let mut xs = xs.clone();
152        if let Some(t) = self.q_proj.quantized_act_type() {
153            xs = xs.to_dtype(t)?;
154        }
155        let mut q = MatMul.qmethod_matmul(&xs, &*self.q_proj)?;
156        let mut k = MatMul.qmethod_matmul(&xs, &*self.k_proj)?;
157        let mut v = MatMul.qmethod_matmul(&xs, &*self.v_proj)?;
158        if self.q_proj.quantized_act_type().is_some() {
159            q = q.to_dtype(original_dtype)?;
160            k = k.to_dtype(original_dtype)?;
161            v = v.to_dtype(original_dtype)?;
162        }
163
164        let (q, k, v) = if q_len != 1 {
165            let q = q
166                .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
167                .transpose(1, 2)?;
168            let k = k
169                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
170                .transpose(1, 2)?;
171            let v = v
172                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
173                .transpose(1, 2)?;
174            (q, k, v)
175        } else {
176            let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
177            let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
178            let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
179            (q, k, v)
180        };
181
182        let (q, k) = self.rotary_emb.forward(&q, &k, seqlen_offsets)?;
183
184        let mut attn_output = match &self.paged_attn {
185            Some(paged_attn) => match metadata {
186                Some(((key_cache, value_cache), input_metadata)) => paged_attn.forward(
187                    &q,
188                    &k,
189                    &v,
190                    attention_mask,
191                    Some(key_cache),
192                    Some(value_cache),
193                    input_metadata,
194                    &self.sdpa_params,
195                    Some(flash_params),
196                )?,
197                None => {
198                    // If we don't have metadata, we are most likely generating an imatrix so we don't want to populate that.
199                    // Generating the dummy metadata with the assumption that we are not generating text (only processing prompts).
200                    let input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
201                    // Sanity check.
202                    assert!(attention_mask.is_some());
203                    paged_attn.forward(
204                        &q,
205                        &k,
206                        &v,
207                        attention_mask,
208                        None,
209                        None,
210                        &input_metadata,
211                        &self.sdpa_params,
212                        Some(flash_params),
213                    )?
214                }
215            },
216            None => {
217                let (k, v) = kv_cache.append(&k, &v)?;
218
219                Sdpa.run_attention(
220                    &q,
221                    &k,
222                    &v,
223                    attention_mask,
224                    Some(flash_params),
225                    &self.sdpa_params,
226                )?
227            }
228        };
229
230        if let Some(t) = self.q_proj.quantized_act_type() {
231            attn_output = attn_output.to_dtype(t)?;
232        }
233        attn_output = if attention_mask.is_some() {
234            attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?
235        } else {
236            attn_output.reshape((b_sz, q_len, ()))?
237        };
238        let mut res = MatMul.qmethod_matmul(&attn_output, &*self.o_proj)?;
239        if self.q_proj.quantized_act_type().is_some() {
240            res = res.to_dtype(original_dtype)?;
241        }
242        Ok(res)
243    }
244}
245
246struct DecoderLayer {
247    self_attn: Attention,
248    mlp: Box<dyn MlpLayer>,
249    input_layernorm: RmsNorm,
250    post_attention_layernorm: RmsNorm,
251}
252
253impl DecoderLayer {
254    #[allow(clippy::too_many_arguments)]
255    fn new(
256        rotary_emb: Arc<RotaryEmbedding>,
257        cfg: &Config,
258        vb: ShardedVarBuilder,
259        mapper: &dyn DeviceMapper,
260        layer_idx: usize,
261        loading_isq: bool,
262        paged_attn: Option<PagedAttention>,
263        comm: &Arc<mistralrs_quant::Comm>,
264    ) -> Result<Self> {
265        let self_attn = Attention::new(
266            rotary_emb,
267            cfg,
268            mapper.set_device(layer_idx, vb.pp("self_attn"), loading_isq),
269            paged_attn,
270            comm,
271        )?;
272        let mlp = Mlp::new(
273            mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq),
274            cfg.hidden_size,
275            cfg.intermediate_size,
276            &cfg.quantization_config,
277            cfg.hidden_act,
278            comm,
279        )?;
280        let input_layernorm = RmsNorm::new(
281            cfg.hidden_size,
282            cfg.rms_norm_eps,
283            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
284        )?;
285        let post_attention_layernorm = RmsNorm::new(
286            cfg.hidden_size,
287            cfg.rms_norm_eps,
288            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
289        )?;
290        Ok(Self {
291            self_attn,
292            mlp: Box::new(mlp),
293            input_layernorm,
294            post_attention_layernorm,
295        })
296    }
297
298    #[allow(clippy::too_many_arguments)]
299    fn forward(
300        &self,
301        xs: &Tensor,
302        attention_mask: Option<&Tensor>,
303        seqlen_offsets: &[usize],
304        kv_cache: &mut KvCache,
305        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
306        flash_params: &FlashParams,
307    ) -> Result<Tensor> {
308        let residual = xs;
309        let xs = self.input_layernorm.forward(xs)?;
310        let xs = self.self_attn.forward(
311            &xs,
312            attention_mask,
313            seqlen_offsets,
314            kv_cache,
315            metadata,
316            flash_params,
317        )?;
318        let xs = (xs + residual)?;
319        let residual = &xs;
320        let xs = self
321            .mlp
322            .forward(&xs.apply(&self.post_attention_layernorm)?)?;
323        residual + xs
324    }
325}
326
327pub struct Model {
328    embed_tokens: candle_nn::Embedding,
329    layers: Vec<DecoderLayer>,
330    norm: RmsNorm,
331    lm_head: Arc<dyn QuantMethod>,
332    sliding_window: usize,
333    device: Device,
334    cache: EitherCache,
335    max_seq_len: usize,
336    mapper: Box<dyn DeviceMapper + Send + Sync>,
337    cfg: ModelConfigMetadata,
338}
339
340impl Model {
341    pub fn new(
342        cfg: &Config,
343        vb: ShardedVarBuilder,
344        is_gptx: bool,
345        normal_loading_metadata: NormalLoadingMetadata,
346        attention_mechanism: AttentionImplementation,
347    ) -> Result<Self> {
348        if let Some(ref quant_cfg) = &cfg.quantization_config {
349            tracing::info!(
350                "Using {} quantization: {}.",
351                quant_cfg.quant_method.to_string(),
352                quant_cfg.get_bits_name(&vb)
353            );
354        }
355        let mapper = normal_loading_metadata.mapper;
356        let vb_m = vb.pp("model");
357
358        let embed_tokens = embedding(
359            cfg.vocab_size,
360            cfg.hidden_size,
361            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
362        )?;
363        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
364        let head_dim = cfg.hidden_size / cfg.num_attention_heads;
365
366        let mut ropes = HashMap::new();
367        for layer_idx in 0..cfg.num_hidden_layers {
368            let device = mapper
369                .device_for(layer_idx, false)
370                .unwrap_or(&normal_loading_metadata.real_device);
371            ropes.insert(
372                device.location(),
373                Arc::new(RotaryEmbedding::new(
374                    cfg.rope_theta as f32,
375                    head_dim,
376                    cfg.max_position_embeddings,
377                    device,
378                    is_gptx,
379                    vb_m.dtype(),
380                )?),
381            );
382        }
383
384        let vb_l = vb_m.pp("layers");
385        for layer_idx in NiceProgressBar::<_, 'b'>(
386            0..cfg.num_hidden_layers,
387            "Loading repeating layers",
388            &normal_loading_metadata.multi_progress,
389        ) {
390            let device = mapper
391                .device_for(layer_idx, false)
392                .unwrap_or(&normal_loading_metadata.real_device);
393            let rotary_emb = ropes
394                .get(&device.location())
395                .expect("No RoPE for device location!")
396                .clone();
397            let paged_attn = match &attention_mechanism {
398                AttentionImplementation::Eager => None,
399                AttentionImplementation::PagedAttention => {
400                    Some(PagedAttention::new(head_dim, device, None)?)
401                }
402            };
403            let comm = mapper.get_comm_for(layer_idx)?;
404            let layer = DecoderLayer::new(
405                rotary_emb.clone(),
406                cfg,
407                vb_l.pp(layer_idx),
408                &*mapper,
409                layer_idx,
410                normal_loading_metadata.loading_isq,
411                paged_attn,
412                &comm,
413            )?;
414            layers.push(layer)
415        }
416        let norm = RmsNorm::new(
417            cfg.hidden_size,
418            cfg.rms_norm_eps,
419            mapper.set_nm_device(vb_m.pp("norm"), false),
420        )?;
421        let lm_head = if !cfg.tie_word_embeddings {
422            ReplicatedLayer::new(
423                cfg.hidden_size,
424                cfg.vocab_size,
425                &None,
426                false,
427                mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
428            )?
429        } else {
430            ReplicatedLayer::from_linear(candle_nn::Linear::new(
431                mapper.cast_nm_device(
432                    embed_tokens.embeddings(),
433                    normal_loading_metadata.loading_isq,
434                )?,
435                None,
436            ))?
437        };
438        Ok(Self {
439            embed_tokens,
440            layers,
441            norm,
442            lm_head,
443            sliding_window: cfg.sliding_window,
444            device: normal_loading_metadata.real_device,
445            cache: EitherCache::Normal(NormalCache::new(
446                cfg.num_hidden_layers,
447                cfg.max_position_embeddings,
448            )),
449            max_seq_len: cfg.max_position_embeddings,
450            cfg: ModelConfigMetadata {
451                max_seq_len: cfg.max_position_embeddings,
452                num_layers: cfg.num_hidden_layers,
453                hidden_size: cfg.hidden_size,
454                num_attn_heads: cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size(),
455                num_kv_heads: (cfg.num_key_value_heads / mapper.get_comm_for(0)?.world_size())
456                    .max(1),
457                sliding_window: Some(cfg.sliding_window),
458                k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
459                v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
460            },
461            mapper,
462        })
463    }
464
465    pub fn get_input_embeddings(&self, input_ids: &Tensor) -> Result<Tensor> {
466        self.embed_tokens.forward(input_ids)
467    }
468
469    pub fn forward(
470        &self,
471        input_ids: &Tensor,
472        seqlen_offsets: &[usize],
473        context_lens: Vec<(usize, usize)>,
474        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
475        flash_params: &FlashParams,
476    ) -> Result<Tensor> {
477        let xs = self.embed_tokens.forward(input_ids)?;
478        self.forward_embed(
479            input_ids,
480            xs,
481            seqlen_offsets,
482            context_lens,
483            metadata,
484            flash_params,
485        )
486    }
487
488    #[allow(clippy::too_many_arguments)]
489    pub fn forward_embed(
490        &self,
491        input_ids: &Tensor,
492        mut xs: Tensor,
493        seqlen_offsets: &[usize],
494        context_lens: Vec<(usize, usize)>,
495        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
496        flash_params: &FlashParams,
497    ) -> Result<Tensor> {
498        let cache = &mut self.cache.normal().0;
499        let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
500            input_ids,
501            metadata
502                .as_ref()
503                .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
504                .unwrap_or(cache as &dyn PastKvLenCache),
505            Some(self.sliding_window),
506            xs.dtype(),
507            self.cfg.num_attn_heads,
508        )?;
509        let attention_mask = attention_mask.filter(|_| {
510            metadata
511                .as_ref()
512                .map(|(_, meta)| meta.is_first_prompt_chunk)
513                .unwrap_or(true)
514        });
515        for (i, layer) in self.layers.iter().enumerate() {
516            xs = self.mapper.map(xs, i)?;
517            xs = layer.forward(
518                &xs,
519                attention_mask
520                    .as_ref()
521                    .map(|m| m.to_device(xs.device()).unwrap())
522                    .as_ref(),
523                seqlen_offsets,
524                &mut cache[i],
525                metadata
526                    .as_ref()
527                    .map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
528                flash_params,
529            )?
530        }
531        let xs = xs.to_device(&self.device)?;
532        let mut xs = xs.apply(&self.norm)?;
533        if let Some(t) = self.lm_head.quantized_act_type() {
534            xs = xs.to_dtype(t)?;
535        }
536        extract_logits(&MatMul.qmethod_matmul(&xs, &*self.lm_head)?, context_lens)
537    }
538
539    pub fn embed_dtype(&self) -> DType {
540        self.embed_tokens.embeddings().dtype()
541    }
542}
543
544impl IsqModel for Model {
545    fn get_layers(
546        &mut self,
547    ) -> (
548        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
549        &dyn DeviceMapper,
550    ) {
551        let mut tensors = Vec::new();
552        tensors.push((&mut self.lm_head, None));
553        for (i, layer) in self.layers.iter_mut().enumerate() {
554            tensors.push((&mut layer.self_attn.q_proj, Some(i)));
555            tensors.push((&mut layer.self_attn.k_proj, Some(i)));
556            tensors.push((&mut layer.self_attn.v_proj, Some(i)));
557            tensors.push((&mut layer.self_attn.o_proj, Some(i)));
558            tensors.extend(
559                layer
560                    .mlp
561                    .get_isq_layers()
562                    .into_iter()
563                    .map(|m| (m, Some(i)))
564                    .collect::<Vec<_>>(),
565            );
566        }
567        (tensors, &*self.mapper)
568    }
569
570    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
571        let uvb = UnVarBuilder::new();
572
573        let uvb_m = uvb.pp("model");
574        uvb_m.pp("embed_tokens").add(&self.embed_tokens);
575        uvb_m.pp("norm").add(&self.norm);
576
577        for (layer_idx, layer) in self.layers.iter().enumerate() {
578            let uvb_l = uvb_m.pp("layers").pp(layer_idx);
579            uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
580            uvb_l
581                .pp("post_attention_layernorm")
582                .add(&layer.post_attention_layernorm);
583        }
584
585        uvb.to_safetensors()
586    }
587
588    fn imatrix_names(&self) -> candle_core::Result<Vec<Option<String>>> {
589        // NOTE: dependant on the exact implementation in get_layers!
590        let mut names = Vec::new();
591        // lm_head
592        names.push(None);
593        for i in 0..self.layers.len() {
594            names.push(Some(format!("blk.{i}.attn_q.weight")));
595            names.push(Some(format!("blk.{i}.attn_k.weight")));
596            names.push(Some(format!("blk.{i}.attn_v.weight")));
597            names.push(Some(format!("blk.{i}.attn_output.weight")));
598            names.push(Some(format!("blk.{i}.ffn_gate.weight")));
599            names.push(Some(format!("blk.{i}.ffn_up.weight")));
600            names.push(Some(format!("blk.{i}.ffn_down.weight")));
601        }
602        Ok(names)
603    }
604}
605
606impl NormalModel for Model {
607    fn forward(
608        &self,
609        input_ids: &Tensor,
610        seqlen_offsets: &[usize],
611        context_lens: Vec<(usize, usize)>,
612        _position_ids: Vec<usize>,
613        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
614        flash_params: &FlashParams,
615    ) -> Result<Tensor> {
616        self.forward(
617            input_ids,
618            seqlen_offsets,
619            context_lens,
620            metadata,
621            flash_params,
622        )
623    }
624    fn xlora_forward(
625        &self,
626        _input_ids: &Tensor,
627        _input_ids_full: &Tensor,
628        _seqlen_offsets: &[usize],
629        _seqlen_offsets_full: &[usize],
630        _no_kv_cache: bool,
631        _non_granular_state: &Option<crate::xlora_models::NonGranularState>,
632        _context_lens: Vec<(usize, usize)>,
633        _position_ids: Vec<usize>,
634        _flash_params: &FlashParams,
635        _flash_params_full: &FlashParams,
636    ) -> Result<Tensor> {
637        unimplemented!()
638    }
639    fn cache(&self) -> &EitherCache {
640        &self.cache
641    }
642    fn cache_mut(&mut self) -> &mut EitherCache {
643        &mut self.cache
644    }
645    fn device(&self) -> &Device {
646        &self.device
647    }
648    fn is_xlora(&self) -> bool {
649        false
650    }
651    fn max_seq_len(&self) -> usize {
652        self.max_seq_len
653    }
654    fn config(&self) -> &ModelConfigMetadata {
655        &self.cfg
656    }
657}
658
659impl AnyMoeBaseModelMixin for Model {
660    fn get_mlps(&self) -> Vec<&dyn MlpLayer> {
661        let mut mlps = Vec::new();
662        for layer in &self.layers {
663            mlps.push(&*layer.mlp);
664        }
665        mlps
666    }
667    fn get_mlps_mut(&mut self) -> Vec<&mut Box<dyn MlpLayer>> {
668        let mut mlps = Vec::new();
669        for layer in &mut self.layers {
670            mlps.push(&mut layer.mlp);
671        }
672        mlps
673    }
674    fn create_anymoe_layers(
675        &mut self,
676        additional_vbs: Vec<ShardedVarBuilder>,
677        config: AnyMoeConfig,
678        (prefix, mlp): (String, String),
679        mut layers: Vec<usize>,
680        expert_type: AnyMoeExpertType,
681        gate_vb: Option<ShardedVarBuilder>,
682    ) -> Result<()> {
683        let mut experts: Vec<Vec<Box<dyn MlpLayer>>> = Vec::new();
684        if layers.is_empty() {
685            layers = (0..self.layers.len()).collect::<Vec<_>>();
686        }
687        for _ in 0..layers.len() {
688            experts.push(Vec::new());
689        }
690        for vb in additional_vbs {
691            let vb = vb.pp(&prefix);
692            for (layer, row) in experts.iter_mut().enumerate() {
693                if !layers.contains(&layer) {
694                    continue;
695                }
696
697                let intermediate_size = self.layers[layer].mlp.get_params()[1];
698                let hidden_size = self.layers[layer].mlp.get_params()[0];
699                match expert_type {
700                    AnyMoeExpertType::FineTuned => {
701                        let (dtype, device) = self.layers[layer].mlp.dtype_device();
702                        row.push(Box::new(Mlp::replicate(
703                            self.layers[layer].mlp.get_params(),
704                            vb.pp(layer).pp(&mlp).set_dtype(dtype).set_device(device),
705                            self.layers[layer].mlp.hidden_act(),
706                            &self.mapper.get_comm_for(layer)?,
707                        )?));
708                    }
709                    AnyMoeExpertType::LoraAdapter {
710                        rank,
711                        alpha,
712                        ref target_modules,
713                    } => {
714                        let vb_mlp = vb.pp(layer).pp(&mlp);
715
716                        let gate_proj_delta = if target_modules.contains(&"gate_proj".to_string()) {
717                            Some(get_delta_from_lora_ab!(
718                                vb_mlp,
719                                rank,
720                                alpha,
721                                (hidden_size, intermediate_size),
722                                "gate_proj"
723                            ))
724                        } else {
725                            None
726                        };
727                        let up_proj_delta = if target_modules.contains(&"up_proj".to_string()) {
728                            Some(get_delta_from_lora_ab!(
729                                vb_mlp,
730                                rank,
731                                alpha,
732                                (hidden_size, intermediate_size),
733                                "up_proj"
734                            ))
735                        } else {
736                            None
737                        };
738                        let down_proj_delta = if target_modules.contains(&"down_proj".to_string()) {
739                            Some(get_delta_from_lora_ab!(
740                                vb_mlp,
741                                rank,
742                                alpha,
743                                (intermediate_size, hidden_size),
744                                "down_proj"
745                            ))
746                        } else {
747                            None
748                        };
749
750                        row.push(self.layers[layer].mlp.new_added_delta(vec![
751                            gate_proj_delta,
752                            up_proj_delta,
753                            down_proj_delta,
754                        ])?);
755                    }
756                }
757            }
758        }
759        for (layer, expert) in layers.into_iter().zip(experts) {
760            let mut experts_all = vec![self.layers[layer].mlp.clone()];
761            experts_all.extend(expert);
762            let (dtype, device) = self.layers[layer].mlp.dtype_device();
763            self.layers[layer].mlp = Box::new(MoeMlp::new(
764                experts_all,
765                config.clone(),
766                dtype,
767                &device,
768                layer,
769                gate_vb.as_ref(),
770            )?);
771        }
772        Ok(())
773    }
774    fn amoe_supported(&self) -> bool {
775        true
776    }
777}