mistralrs_core/models/llama.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
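//! Llama text model.
//!
//! Wires the standard Llama decoder stack (token embedding, repeated
//! attention + MLP blocks, final `RmsNorm`, LM head) into mistral.rs using the
//! tensor-parallel layers from `mistralrs_quant`, with optional paged
//! attention, ISQ (in-situ quantization) support, and AnyMoE expert hooks.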

use candle_core::{Device, Result, Tensor};
use candle_nn::{Embedding, Module};
use mistralrs_quant::{
    ColumnParallelLayer, QuantMethod, QuantizedConfig, ReplicatedLayer, RowParallelLayer,
    ShardedVarBuilder,
};
use serde::{Deserialize, Serialize};
use std::{collections::HashMap, sync::Arc};

use crate::{
    amoe::{AnyMoeBaseModelMixin, AnyMoeConfig, AnyMoeExpertType, MlpLayer, MoeMlp},
    attention::SdpaParams,
    device_map::DeviceMapper,
    get_delta_from_lora_ab,
    layers::{
        embedding, Activation, CausalMasker, Llama3RopeConfig, Llama3RotaryEmbedding, MatMul, Mlp,
        RmsNorm, Sdpa,
    },
    layers_masker::PastKvLenCache,
    paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
    pipeline::{
        extract_logits,
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel, KvCache, NormalCache, NormalLoadingMetadata, NormalModel,
    },
    serde_default_fn,
    utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
};

serde_default_fn!(bool, word_emb_default, false);

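/// Llama model configuration; the fields mirror the Hugging Face `config.json` layout.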
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub struct Config {
    pub hidden_act: Activation,
    pub hidden_size: usize,
    pub intermediate_size: usize,
    pub vocab_size: usize,
    pub num_hidden_layers: usize,
    pub num_attention_heads: usize,
    pub num_key_value_heads: usize,
    pub rms_norm_eps: f64,
    pub rope_theta: f32,
    pub max_position_embeddings: usize,
    pub rope_scaling: Option<Llama3RopeConfig>,
    pub quantization_config: Option<QuantizedConfig>,
    #[serde(default = "word_emb_default")]
    pub tie_word_embeddings: bool,
}

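/// Multi-head causal self-attention with grouped-query KV heads and rotary position
/// embeddings, dispatching to either paged attention or the standard SDPA path.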
struct CausalSelfAttention {
    q_proj: Arc<dyn QuantMethod>,
    k_proj: Arc<dyn QuantMethod>,
    v_proj: Arc<dyn QuantMethod>,
    o_proj: Arc<dyn QuantMethod>,
    num_attention_heads: usize,
    num_key_value_heads: usize,
    head_dim: usize,
    rotary_emb: Arc<Llama3RotaryEmbedding>,
    max_seq_len: usize,
    paged_attn: Option<PagedAttention>,
    sdpa_params: SdpaParams,
}

impl CausalSelfAttention {
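    /// Projects `x` to Q/K/V, applies RoPE, runs attention (paged attention when
    /// available, otherwise SDPA over the KV cache), and applies the output projection.
    /// Activations are cast to the quantized activation dtype around quantized matmuls.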
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        x: &Tensor,
        attention_mask: &Option<Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, seq_len, _) = x.dims3()?;

        let original_dtype = x.dtype();
        let mut x = x.clone();
        if let Some(t) = self.q_proj.quantized_act_type() {
            x = x.to_dtype(t)?;
        }
        let mut q = MatMul.qmethod_matmul(&x, &*self.q_proj)?;
        let mut k = MatMul.qmethod_matmul(&x, &*self.k_proj)?;
        let mut v = MatMul.qmethod_matmul(&x, &*self.v_proj)?;
        if self.q_proj.quantized_act_type().is_some() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

        let (q, k, v) = if seq_len != 1 {
            let q = q
                .reshape((b_sz, seq_len, self.num_attention_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.num_attention_heads, seq_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.num_key_value_heads, seq_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.num_key_value_heads, seq_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, k) = self.rotary_emb.forward(&q, &k, seqlen_offsets)?;

        let mut y = match &self.paged_attn {
            Some(paged_attn) => match metadata {
                Some(((key_cache, value_cache), input_metadata)) => paged_attn.forward(
                    &q,
                    &k,
                    &v,
                    attention_mask.clone().as_ref(),
                    Some(key_cache),
                    Some(value_cache),
                    input_metadata,
                    &self.sdpa_params,
                    Some(flash_params),
                )?,
                None => {
                    // If we don't have metadata, we are most likely generating an imatrix, so we don't want to populate the paged KV cache.
                    // Generate dummy metadata under the assumption that we are not generating text (only processing prompts).
                    let input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
                    // Sanity check.
                    assert!(attention_mask.is_some());
                    paged_attn.forward(
                        &q,
                        &k,
                        &v,
                        attention_mask.clone().as_ref(),
                        None,
                        None,
                        &input_metadata,
                        &self.sdpa_params,
                        Some(flash_params),
                    )?
                }
            },
            None => {
                let (k, v) = kv_cache.append(&k, &v)?;

                Sdpa.run_attention(
                    &q,
                    &k,
                    &v,
                    attention_mask.clone().as_ref(),
                    Some(flash_params),
                    &self.sdpa_params,
                )?
            }
        };

        if let Some(t) = self.q_proj.quantized_act_type() {
            y = y.to_dtype(t)?;
        }
        y = if attention_mask.is_some() {
            y.transpose(1, 2)?.reshape((b_sz, seq_len, ()))?
        } else {
            y.reshape((b_sz, seq_len, ()))?
        };
        let mut res = MatMul.qmethod_matmul(&y, &*self.o_proj)?;
        if self.q_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }

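    /// Loads the attention projections from `vb`: Q/K/V as column-parallel layers and
    /// the output projection as a row-parallel layer, sharded over the tensor-parallel `comm`.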
    fn load(
        vb: ShardedVarBuilder,
        cfg: &Config,
        rope: Arc<Llama3RotaryEmbedding>,
        paged_attn: Option<PagedAttention>,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let size_in = cfg.hidden_size;
        let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
        let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
        let q_proj = ColumnParallelLayer::new(
            size_in,
            size_q,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("q_proj"),
        )?;
        let kv_shard = mistralrs_quant::compute_kv_shard(
            cfg.num_key_value_heads,
            cfg.hidden_size / cfg.num_attention_heads,
            comm,
        );
        let k_proj = ColumnParallelLayer::new_with_shard(
            size_in,
            size_kv,
            &cfg.quantization_config,
            false,
            comm,
            kv_shard,
            vb.pp("k_proj"),
        )?;
        let v_proj = ColumnParallelLayer::new_with_shard(
            size_in,
            size_kv,
            &cfg.quantization_config,
            false,
            comm,
            kv_shard,
            vb.pp("v_proj"),
        )?;
        let o_proj = RowParallelLayer::new(
            size_q,
            size_in,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("o_proj"),
        )?;
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            num_attention_heads: cfg.num_attention_heads / comm.world_size(),
            num_key_value_heads: (cfg.num_key_value_heads / comm.world_size()).max(1),
            head_dim: cfg.hidden_size / cfg.num_attention_heads,
            rotary_emb: rope,
            max_seq_len: cfg.max_position_embeddings,
            paged_attn,
            sdpa_params: SdpaParams {
                n_kv_groups: mistralrs_quant::compute_n_kv_groups(
                    cfg.num_key_value_heads,
                    cfg.num_attention_heads,
                    comm,
                ),
                softcap: None,
                softmax_scale: 1.0 / ((cfg.hidden_size / cfg.num_attention_heads) as f32).sqrt(),
                sliding_window: None,
            },
        })
    }
}

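/// A single decoder layer: pre-norm self-attention followed by a pre-norm MLP, each
/// wrapped in a residual connection.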
struct Block {
    rms_1: RmsNorm,
    attn: CausalSelfAttention,
    rms_2: RmsNorm,
    mlp: Box<dyn MlpLayer>,
}

impl Block {
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        x: &Tensor,
        attention_mask: &Option<Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = x;
        let x = self.rms_1.forward(x)?;
        let x = (self.attn.forward(
            &x,
            attention_mask,
            seqlen_offsets,
            kv_cache,
            metadata,
            flash_params,
        )? + residual)?;
        let residual = &x;
        let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?;
        Ok(x)
    }

    #[allow(clippy::too_many_arguments)]
    fn load(
        vb: ShardedVarBuilder,
        cfg: &Config,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        rope: Arc<Llama3RotaryEmbedding>,
        paged_attn: Option<PagedAttention>,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let attn = CausalSelfAttention::load(
            mapper.set_device(layer_idx, vb.pp("self_attn"), loading_isq),
            cfg,
            rope,
            paged_attn,
            comm,
        )?;
        let mlp = Mlp::new(
            mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq),
            cfg.hidden_size,
            cfg.intermediate_size,
            &cfg.quantization_config,
            cfg.hidden_act,
            comm,
        )?;
        let rms_1 = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let rms_2 = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        Ok(Self {
            rms_1,
            attn,
            rms_2,
            mlp: Box::new(mlp),
        })
    }
}

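/// The full Llama model: token embeddings, decoder blocks, final norm, and LM head.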
pub struct Llama {
    wte: Embedding,
    blocks: Vec<Block>,
    ln_f: RmsNorm,
    lm_head: Arc<dyn QuantMethod>,
    kv_cache: crate::pipeline::EitherCache,
    device: Device,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    cfg: ModelConfigMetadata,
}

impl Llama {
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        let vb_m = vb.pp("model");
        let vb_lm_head = vb.pp("lm_head");
        Self::new_inner(
            cfg,
            vb_m,
            vb_lm_head,
            is_gptx,
            normal_loading_metadata,
            attention_mechanism,
        )
    }

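    /// Builds the model from separate `model` and `lm_head` var builders, allowing callers
    /// that nest the language model under another prefix to reuse it; `new` forwards here
    /// with the standard prefixes.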
    pub fn new_inner(
        cfg: &Config,
        vb_m: ShardedVarBuilder,
        vb_lm_head: ShardedVarBuilder,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        if let Some(ref quant_cfg) = &cfg.quantization_config {
            tracing::info!(
                "Using {} quantization: {}.",
                quant_cfg.name(),
                quant_cfg.get_bits_name(&vb_m)
            );
        }
        let mapper = normal_loading_metadata.mapper;

        let wte = embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
            &cfg.quantization_config,
        )?;
        let lm_head = if !cfg.tie_word_embeddings {
            ReplicatedLayer::new(
                cfg.hidden_size,
                cfg.vocab_size,
                &cfg.quantization_config,
                false,
                mapper.set_nm_device(vb_lm_head, normal_loading_metadata.loading_isq),
            )?
        } else {
            ReplicatedLayer::from_linear(candle_nn::Linear::new(
                mapper.cast_nm_device(wte.embeddings(), normal_loading_metadata.loading_isq)?,
                None,
            ))?
        };
        let ln_f = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb_m.pp("norm"), false),
        )?;
        let head_dim = cfg.hidden_size / cfg.num_attention_heads;
        let mut ropes = HashMap::new();
        for i in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(i, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(Llama3RotaryEmbedding::new_llama3(
                    vb_m.dtype(),
                    cfg,
                    device,
                    is_gptx,
                )?),
            );
        }
        let blocks: Vec<_> = NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        )
        .par_iter_if_isq(|i| {
            let device = mapper
                .device_for(i, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            let paged_attn = match &attention_mechanism {
                AttentionImplementation::Eager => None,
                AttentionImplementation::PagedAttention => {
                    Some(PagedAttention::new(head_dim, device, None)?)
                }
            };
            let comm = mapper.get_comm_for(i)?;
            Block::load(
                vb_m.pp(format!("layers.{i}")),
                cfg,
                &*mapper,
                i,
                normal_loading_metadata.loading_isq,
                rotary_emb,
                paged_attn,
                &comm,
            )
        })?;

        Ok(Self {
            wte,
            blocks,
            ln_f,
            lm_head,
            kv_cache: EitherCache::Normal(NormalCache::new(
                cfg.num_hidden_layers,
                cfg.max_position_embeddings,
            )),
            device: normal_loading_metadata.real_device,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_kv_heads: (cfg.num_key_value_heads / mapper.get_comm_for(0)?.world_size())
                    .max(1),
                num_attn_heads: cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size(),
                sliding_window: None,
                k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
                v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            },
            mapper,
        })
    }

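    /// Looks up the token embeddings for `input_ids`.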
    pub fn get_input_embeddings(&self, input_ids: &Tensor) -> Result<Tensor> {
        self.wte.forward(input_ids)
    }

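    /// Embeds `input_ids` and runs the full decoder stack; see `forward_embeds`.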
    pub fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.forward_embeds(
            input_ids,
            self.wte.forward(input_ids)?,
            seqlen_offsets,
            context_lens,
            metadata,
            flash_params,
        )
    }

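    /// Runs the decoder over precomputed `input_embeds`: builds the causal mask (skipped
    /// for non-first PagedAttention prompt chunks), applies each block while moving
    /// activations between devices per the device mapper, then applies the final norm and
    /// LM head and extracts logits for the requested context lengths.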
    #[allow(clippy::too_many_arguments)]
    pub fn forward_embeds(
        &self,
        input_ids: &Tensor,
        input_embeds: Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut x = input_embeds;
        let cache = &mut self.kv_cache.normal().0;
        let mask = CausalMasker.make_causal_mask_matrix(
            input_ids,
            metadata
                .as_ref()
                .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
                .unwrap_or(cache as &dyn PastKvLenCache),
            x.dtype(),
            self.blocks[0].attn.num_attention_heads,
        )?;
        // PagedAttention prompt chunking
        let mask = mask.filter(|_| {
            metadata
                .as_ref()
                .map(|(_, meta)| meta.is_first_prompt_chunk)
                .unwrap_or(true)
        });
        for (block_idx, block) in self.blocks.iter().enumerate() {
            x = self.mapper.map(x, block_idx)?;
            x = block.forward(
                &x,
                &mask.clone().map(|m| m.to_device(x.device()).unwrap()),
                seqlen_offsets,
                &mut cache[block_idx],
                metadata
                    .as_ref()
                    .map(|(kv_cache, metadata)| (kv_cache[block_idx].clone(), *metadata)),
                flash_params,
            )?;
        }
        let x = x.to_device(&self.device)?;
        let mut x = self.ln_f.forward(&x)?;
        if let Some(t) = self.lm_head.quantized_act_type() {
            x = x.to_dtype(t)?;
        }
        let xs = MatMul.qmethod_matmul(&x, &*self.lm_head)?;
        extract_logits(&xs, context_lens)
    }

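    /// Collects the tensors that are not subject to ISQ (embeddings, final norm, and the
    /// per-layer layer norms) under the given `model`-level builder.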
    pub fn residual_tensors_m(&self, uvb_m: UnVarBuilder) -> Vec<(String, Tensor)> {
        uvb_m.pp("embed_tokens").add(&self.wte);
        uvb_m.pp("norm").add(&self.ln_f);

        for (layer_idx, layer) in self.blocks.iter().enumerate() {
            let uvb_l = uvb_m.pp("layers").pp(layer_idx);
            uvb_l.pp("input_layernorm").add(&layer.rms_1);
            uvb_l.pp("post_attention_layernorm").add(&layer.rms_2);
        }

        uvb_m.to_safetensors()
    }
}

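// ISQ support: expose the quantizable linear layers (LM head plus the per-layer
// attention and MLP projections) and the GGUF-style names used to match imatrix data.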
impl IsqModel for Llama {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((&mut self.lm_head, None));
        for (i, layer) in self.blocks.iter_mut().enumerate() {
            tensors.push((&mut layer.attn.q_proj, Some(i)));
            tensors.push((&mut layer.attn.k_proj, Some(i)));
            tensors.push((&mut layer.attn.v_proj, Some(i)));
            tensors.push((&mut layer.attn.o_proj, Some(i)));
            tensors.extend(
                layer
                    .mlp
                    .get_isq_layers()
                    .into_iter()
                    .map(|m| (m, Some(i)))
                    .collect::<Vec<_>>(),
            );
        }
        (tensors, &*self.mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();
        self.residual_tensors_m(uvb.pp("model"))
    }

    fn imatrix_names(&self) -> candle_core::Result<Vec<Option<String>>> {
        // NOTE: dependent on the exact implementation in get_layers!
        let mut names = Vec::new();
        // lm_head
        names.push(None);
        for i in 0..self.blocks.len() {
            names.push(Some(format!("blk.{i}.attn_q.weight")));
            names.push(Some(format!("blk.{i}.attn_k.weight")));
            names.push(Some(format!("blk.{i}.attn_v.weight")));
            names.push(Some(format!("blk.{i}.attn_output.weight")));
            names.push(Some(format!("blk.{i}.ffn_gate.weight")));
            names.push(Some(format!("blk.{i}.ffn_up.weight")));
            names.push(Some(format!("blk.{i}.ffn_down.weight")));
        }
        Ok(names)
    }
}

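// NormalModel plumbing: delegate the forward pass and expose the cache, device, and
// config metadata. X-LoRA is not supported for this model.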
impl NormalModel for Llama {
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.forward(
            input_ids,
            seqlen_offsets,
            context_lens,
            metadata,
            flash_params,
        )
    }
    fn xlora_forward(
        &self,
        _input_ids: &Tensor,
        _input_ids_full: &Tensor,
        _seqlen_offsets: &[usize],
        _seqlen_offsets_full: &[usize],
        _no_kv_cache: bool,
        _non_granular_state: &Option<crate::xlora_models::NonGranularState>,
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _flash_params: &FlashParams,
        _flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        unimplemented!()
    }
    fn cache(&self) -> &crate::pipeline::EitherCache {
        &self.kv_cache
    }
    fn cache_mut(&mut self) -> &mut crate::pipeline::EitherCache {
        &mut self.kv_cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn is_xlora(&self) -> bool {
        false
    }
    fn max_seq_len(&self) -> usize {
        self.blocks[0].attn.max_seq_len
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
}

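// AnyMoE support: each block's MLP can be replaced by a `MoeMlp` that gates over the
// original MLP plus additional experts.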
impl AnyMoeBaseModelMixin for Llama {
    fn get_mlps(&self) -> Vec<&dyn MlpLayer> {
        let mut mlps = Vec::new();
        for layer in &self.blocks {
            mlps.push(&*layer.mlp);
        }
        mlps
    }
    fn get_mlps_mut(&mut self) -> Vec<&mut Box<dyn MlpLayer>> {
        let mut mlps = Vec::new();
        for layer in &mut self.blocks {
            mlps.push(&mut layer.mlp);
        }
        mlps
    }
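    /// Replaces the MLP of each selected layer with a `MoeMlp` whose experts are the
    /// original MLP plus one expert per additional var builder, loaded either as a
    /// fine-tuned MLP or as LoRA adapter deltas applied to the base MLP.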
    fn create_anymoe_layers(
        &mut self,
        additional_vbs: Vec<ShardedVarBuilder>,
        config: AnyMoeConfig,
        (prefix, mlp): (String, String),
        mut layers: Vec<usize>,
        expert_type: AnyMoeExpertType,
        gate_vb: Option<ShardedVarBuilder>,
    ) -> Result<()> {
        let mut experts: Vec<Vec<Box<dyn MlpLayer>>> = Vec::new();
        if layers.is_empty() {
            layers = (0..self.blocks.len()).collect::<Vec<_>>();
        }
        for _ in 0..layers.len() {
            experts.push(Vec::new());
        }
        for vb in additional_vbs {
            let vb = vb.pp(&prefix);
            for (layer, row) in experts.iter_mut().enumerate() {
                if !layers.contains(&layer) {
                    continue;
                }

                let intermediate_size = self.blocks[layer].mlp.get_params()[1];
                let hidden_size = self.blocks[layer].mlp.get_params()[0];
                match expert_type {
                    AnyMoeExpertType::FineTuned => {
                        let (dtype, device) = self.blocks[layer].mlp.dtype_device();
                        row.push(Box::new(Mlp::replicate(
                            self.blocks[layer].mlp.get_params(),
                            vb.pp(layer).pp(&mlp).set_dtype(dtype).set_device(device),
                            self.blocks[layer].mlp.hidden_act(),
                            &self.mapper.get_comm_for(layer)?,
                        )?));
                    }
                    AnyMoeExpertType::LoraAdapter {
                        rank,
                        alpha,
                        ref target_modules,
                    } => {
                        let vb_mlp = vb.pp(layer).pp(&mlp);

                        let c_fc1_delta = if target_modules.contains(&"c_fc1".to_string()) {
                            Some(get_delta_from_lora_ab!(
                                vb_mlp,
                                rank,
                                alpha,
                                (hidden_size, intermediate_size),
                                "c_fc1"
                            ))
                        } else {
                            None
                        };
                        let c_fc2_delta = if target_modules.contains(&"c_fc2".to_string()) {
                            Some(get_delta_from_lora_ab!(
                                vb_mlp,
                                rank,
                                alpha,
                                (hidden_size, intermediate_size),
                                "c_fc2"
                            ))
                        } else {
                            None
                        };
                        let c_proj_delta = if target_modules.contains(&"c_proj".to_string()) {
                            Some(get_delta_from_lora_ab!(
                                vb_mlp,
                                rank,
                                alpha,
                                (intermediate_size, hidden_size),
                                "c_proj"
                            ))
                        } else {
                            None
                        };

                        row.push(self.blocks[layer].mlp.new_added_delta(vec![
                            c_fc1_delta,
                            c_fc2_delta,
                            c_proj_delta,
                        ])?);
                    }
                }
            }
        }
        for (layer, expert) in layers.into_iter().zip(experts) {
            let mut experts_all = vec![self.blocks[layer].mlp.clone()];
            experts_all.extend(expert);
            let (dtype, device) = self.blocks[layer].mlp.dtype_device();
            self.blocks[layer].mlp = Box::new(MoeMlp::new(
                experts_all,
                config.clone(),
                dtype,
                &device,
                layer,
                gate_vb.as_ref(),
            )?);
        }
        Ok(())
    }
    fn amoe_supported(&self) -> bool {
        true
    }
}