mistralrs_core/models/llama.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use candle_core::{Device, Result, Tensor};
use candle_nn::{Embedding, Module};
use mistralrs_quant::{
    ColumnParallelLayer, QuantMethod, QuantizedConfig, ReplicatedLayer, RowParallelLayer,
    ShardedVarBuilder,
};
use serde::{Deserialize, Serialize};
use std::{collections::HashMap, sync::Arc};

use crate::{
    amoe::{AnyMoeBaseModelMixin, AnyMoeConfig, AnyMoeExpertType, MlpLayer, MoeMlp},
    attention::SdpaParams,
    device_map::DeviceMapper,
    get_delta_from_lora_ab,
    layers::{
        embedding, Activation, CausalMasker, Llama3RopeConfig, Llama3RotaryEmbedding, MatMul, Mlp,
        RmsNorm, Sdpa,
    },
    layers_masker::PastKvLenCache,
    paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
    pipeline::{
        extract_logits,
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel, KvCache, NormalCache, NormalLoadingMetadata, NormalModel,
    },
    serde_default_fn,
    utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
};

serde_default_fn!(bool, word_emb_default, false);
serde_default_fn!(bool, use_flash_attn_default, false);

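/// Model configuration, deserialized from a Hugging Face style `config.json`
/// for Llama-architecture checkpoints. Field names mirror the upstream JSON
/// keys so serde can map them directly.
///
/// Minimal sketch of constructing a config by hand for tests (values are
/// illustrative, not tied to any particular checkpoint):
///
/// ```ignore
/// let cfg = Config {
///     hidden_act: Activation::Silu,
///     hidden_size: 4096,
///     intermediate_size: 14336,
///     vocab_size: 128256,
///     num_hidden_layers: 32,
///     num_attention_heads: 32,
///     num_key_value_heads: 8,
///     use_flash_attn: false,
///     rms_norm_eps: 1e-5,
///     rope_theta: 500_000.0,
///     max_position_embeddings: 8192,
///     rope_scaling: None,
///     quantization_config: None,
///     tie_word_embeddings: false,
/// };
/// ```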
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub struct Config {
    pub hidden_act: Activation,
    pub hidden_size: usize,
    pub intermediate_size: usize,
    pub vocab_size: usize,
    pub num_hidden_layers: usize,
    pub num_attention_heads: usize,
    pub num_key_value_heads: usize,
    #[serde(default = "use_flash_attn_default")]
    pub use_flash_attn: bool,
    pub rms_norm_eps: f64,
    pub rope_theta: f32,
    pub max_position_embeddings: usize,
    pub rope_scaling: Option<Llama3RopeConfig>,
    pub quantization_config: Option<QuantizedConfig>,
    #[serde(default = "word_emb_default")]
    pub tie_word_embeddings: bool,
}

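/// Multi-head self-attention with grouped-query attention (GQA): there are
/// `num_attention_heads` query heads but only `num_key_value_heads` key/value
/// heads. The projections are `QuantMethod` trait objects, so they may be
/// plain, quantized, or tensor-parallel sharded layers. Rotary position
/// embeddings are applied to q/k, and attention runs either through
/// `PagedAttention` (when a paged KV cache is in use) or through the regular
/// SDPA path backed by a `KvCache`.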
struct CausalSelfAttention {
    q_proj: Arc<dyn QuantMethod>,
    k_proj: Arc<dyn QuantMethod>,
    v_proj: Arc<dyn QuantMethod>,
    o_proj: Arc<dyn QuantMethod>,
    num_attention_heads: usize,
    num_key_value_heads: usize,
    head_dim: usize,
    rotary_emb: Arc<Llama3RotaryEmbedding>,
    max_seq_len: usize,
    paged_attn: Option<PagedAttention>,
    sdpa_params: SdpaParams,
}

impl CausalSelfAttention {
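    // Rough shape walkthrough for `forward` (following the reshapes below):
    //   x:   (b_sz, seq_len, hidden_size)
    //   q:   (b_sz, num_attention_heads, seq_len, head_dim)
    //   k/v: (b_sz, num_key_value_heads, seq_len, head_dim)
    //   y:   reshaped back to (b_sz, seq_len, hidden_size) before `o_proj`.
    // For seq_len == 1 (decode) the transpose is skipped because the two
    // layouts are identical in memory for a single position.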
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        x: &Tensor,
        attention_mask: &Option<Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, seq_len, _) = x.dims3()?;

        let original_dtype = x.dtype();
        let mut x = x.clone();
        if let Some(t) = self.q_proj.quantized_act_type() {
            x = x.to_dtype(t)?;
        }
        let mut q = MatMul.qmethod_matmul(&x, &*self.q_proj)?;
        let mut k = MatMul.qmethod_matmul(&x, &*self.k_proj)?;
        let mut v = MatMul.qmethod_matmul(&x, &*self.v_proj)?;
        if self.q_proj.quantized_act_type().is_some() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

        let (q, k, v) = if seq_len != 1 {
            let q = q
                .reshape((b_sz, seq_len, self.num_attention_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.num_attention_heads, seq_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.num_key_value_heads, seq_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.num_key_value_heads, seq_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, k) = self.rotary_emb.forward(&q, &k, seqlen_offsets)?;

        let mut y = match &self.paged_attn {
            Some(paged_attn) => match metadata {
                Some(((key_cache, value_cache), input_metadata)) => paged_attn.forward(
                    &q,
                    &k,
                    &v,
                    attention_mask.clone().as_ref(),
                    Some(key_cache),
                    Some(value_cache),
                    input_metadata,
                    &self.sdpa_params,
                    Some(flash_params),
                )?,
                None => {
                    // Without metadata we are most likely generating an imatrix, so we don't want to populate the KV cache.
                    // Generate dummy metadata under the assumption that we are not generating text (only processing prompts).
                    let input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
                    // Sanity check.
                    assert!(attention_mask.is_some());
                    paged_attn.forward(
                        &q,
                        &k,
                        &v,
                        attention_mask.clone().as_ref(),
                        None,
                        None,
                        &input_metadata,
                        &self.sdpa_params,
                        Some(flash_params),
                    )?
                }
            },
            None => {
                let (k, v) = kv_cache.append(&k, &v)?;

                Sdpa.run_attention(
                    &q,
                    &k,
                    &v,
                    attention_mask.clone().as_ref(),
                    Some(flash_params),
                    &self.sdpa_params,
                )?
            }
        };

        if let Some(t) = self.q_proj.quantized_act_type() {
            y = y.to_dtype(t)?;
        }
        y = if attention_mask.is_some() {
            y.transpose(1, 2)?.reshape((b_sz, seq_len, ()))?
        } else {
            y.reshape((b_sz, seq_len, ()))?
        };
        let mut res = MatMul.qmethod_matmul(&y, &*self.o_proj)?;
        if self.q_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }

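    // Loads the attention projections. q/k/v are column-parallel and o_proj is
    // row-parallel across `comm`, so each rank holds `num_attention_heads /
    // world_size` query heads and at least one KV head; the softmax scale is
    // 1 / sqrt(head_dim).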
    fn load(
        vb: ShardedVarBuilder,
        cfg: &Config,
        rope: Arc<Llama3RotaryEmbedding>,
        paged_attn: Option<PagedAttention>,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let size_in = cfg.hidden_size;
        let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
        let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
        let q_proj = ColumnParallelLayer::new(
            size_in,
            size_q,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("q_proj"),
        )?;
        let kv_shard = mistralrs_quant::compute_kv_shard(
            cfg.num_key_value_heads,
            cfg.hidden_size / cfg.num_attention_heads,
            comm,
        );
        let k_proj = ColumnParallelLayer::new_with_shard(
            size_in,
            size_kv,
            &cfg.quantization_config,
            false,
            comm,
            kv_shard,
            vb.pp("k_proj"),
        )?;
        let v_proj = ColumnParallelLayer::new_with_shard(
            size_in,
            size_kv,
            &cfg.quantization_config,
            false,
            comm,
            kv_shard,
            vb.pp("v_proj"),
        )?;
        let o_proj = RowParallelLayer::new(
            size_q,
            size_in,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("o_proj"),
        )?;
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            num_attention_heads: cfg.num_attention_heads / comm.world_size(),
            num_key_value_heads: (cfg.num_key_value_heads / comm.world_size()).max(1),
            head_dim: cfg.hidden_size / cfg.num_attention_heads,
            rotary_emb: rope,
            max_seq_len: cfg.max_position_embeddings,
            paged_attn,
            sdpa_params: SdpaParams {
                n_kv_groups: mistralrs_quant::compute_n_kv_groups(
                    cfg.num_key_value_heads,
                    cfg.num_attention_heads,
                    comm,
                ),
                use_flash_attn: cfg.use_flash_attn,
                softcap: None,
                softmax_scale: 1.0 / ((cfg.hidden_size / cfg.num_attention_heads) as f32).sqrt(),
                sliding_window: None,
            },
        })
    }
}

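/// One decoder layer: pre-norm attention followed by a pre-norm MLP, each with
/// a residual connection (`x + attn(rms_1(x))`, then `x + mlp(rms_2(x))`).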
struct Block {
    rms_1: RmsNorm,
    attn: CausalSelfAttention,
    rms_2: RmsNorm,
    mlp: Box<dyn MlpLayer>,
}

impl Block {
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        x: &Tensor,
        attention_mask: &Option<Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = x;
        let x = self.rms_1.forward(x)?;
        let x = (self.attn.forward(
            &x,
            attention_mask,
            seqlen_offsets,
            kv_cache,
            metadata,
            flash_params,
        )? + residual)?;
        let residual = &x;
        let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?;
        Ok(x)
    }

    #[allow(clippy::too_many_arguments)]
    fn load(
        vb: ShardedVarBuilder,
        cfg: &Config,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        rope: Arc<Llama3RotaryEmbedding>,
        paged_attn: Option<PagedAttention>,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let attn = CausalSelfAttention::load(
            mapper.set_device(layer_idx, vb.pp("self_attn"), loading_isq),
            cfg,
            rope,
            paged_attn,
            comm,
        )?;
        let mlp = Mlp::new(
            mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq),
            cfg.hidden_size,
            cfg.intermediate_size,
            &cfg.quantization_config,
            cfg.hidden_act,
            comm,
        )?;
        let rms_1 = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let rms_2 = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        Ok(Self {
            rms_1,
            attn,
            rms_2,
            mlp: Box::new(mlp),
        })
    }
}

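/// The full Llama decoder stack: token embeddings, `num_hidden_layers` blocks,
/// a final RmsNorm, and the LM head (optionally tied to the embeddings).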
pub struct Llama {
    wte: Embedding,
    blocks: Vec<Block>,
    ln_f: RmsNorm,
    lm_head: Arc<dyn QuantMethod>,
    kv_cache: crate::pipeline::EitherCache,
    device: Device,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    cfg: ModelConfigMetadata,
}

impl Llama {
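    /// Builds the model from a var builder whose `model.` and `lm_head.`
    /// prefixes follow the standard Hugging Face checkpoint layout.
    ///
    /// Minimal sketch of a call site (the surrounding loader normally supplies
    /// these values; the variable names here are illustrative):
    ///
    /// ```ignore
    /// let model = Llama::new(
    ///     &cfg,
    ///     vb,                      // ShardedVarBuilder over the checkpoint
    ///     false,                   // is_gptx: RoPE convention flag
    ///     normal_loading_metadata, // device mapping, ISQ flags, progress bars
    ///     AttentionImplementation::Eager,
    /// )?;
    /// ```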
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        let vb_m = vb.pp("model");
        let vb_lm_head = vb.pp("lm_head");
        Self::new_inner(
            cfg,
            vb_m,
            vb_lm_head,
            is_gptx,
            normal_loading_metadata,
            attention_mechanism,
        )
    }

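    /// Like [`Self::new`], but takes the `model` and `lm_head` var builders
    /// separately so callers (e.g. multimodal wrappers around this backbone)
    /// can supply their own prefixes.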
    pub fn new_inner(
        cfg: &Config,
        vb_m: ShardedVarBuilder,
        vb_lm_head: ShardedVarBuilder,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        if let Some(ref quant_cfg) = &cfg.quantization_config {
            tracing::info!(
                "Using {} quantization: {}.",
                quant_cfg.name(),
                quant_cfg.get_bits_name(&vb_m)
            );
        }
        let mapper = normal_loading_metadata.mapper;

        let wte = embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
            &cfg.quantization_config,
        )?;
        let lm_head = if !cfg.tie_word_embeddings {
            ReplicatedLayer::new(
                cfg.hidden_size,
                cfg.vocab_size,
                &None,
                false,
                mapper.set_nm_device(vb_lm_head, normal_loading_metadata.loading_isq),
            )?
        } else {
            ReplicatedLayer::from_linear(candle_nn::Linear::new(
                mapper.cast_nm_device(wte.embeddings(), normal_loading_metadata.loading_isq)?,
                None,
            ))?
        };
        let ln_f = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb_m.pp("norm"), false),
        )?;
        let head_dim = cfg.hidden_size / cfg.num_attention_heads;
        let mut ropes = HashMap::new();
        for i in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(i, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(Llama3RotaryEmbedding::new_llama3(
                    vb_m.dtype(),
                    cfg,
                    device,
                    is_gptx,
                )?),
            );
        }
        let blocks: Vec<_> = NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        )
        .into_iter()
        .map(|i| {
            let device = mapper
                .device_for(i, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            let paged_attn = match &attention_mechanism {
                AttentionImplementation::Eager => None,
                AttentionImplementation::PagedAttention => Some(
                    PagedAttention::new(head_dim, device, None)
                        .expect("Failed to create PagedAttention"),
                ),
            };
            let comm = mapper.get_comm_for(i).unwrap();
            Block::load(
                vb_m.pp(format!("layers.{i}")),
                cfg,
                &*mapper,
                i,
                normal_loading_metadata.loading_isq,
                rotary_emb,
                paged_attn,
                &comm,
            )
            .expect("Failed to load block.")
        })
        .collect();

        Ok(Self {
            wte,
            blocks,
            ln_f,
            lm_head,
            kv_cache: EitherCache::Normal(NormalCache::new(
                cfg.num_hidden_layers,
                cfg.max_position_embeddings,
            )),
            device: normal_loading_metadata.real_device,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_kv_heads: (cfg.num_key_value_heads / mapper.get_comm_for(0)?.world_size())
                    .max(1),
                num_attn_heads: cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size(),
                sliding_window: None,
                k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
                v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            },
            mapper,
        })
    }

    pub fn get_input_embeddings(&self, input_ids: &Tensor) -> Result<Tensor> {
        self.wte.forward(input_ids)
    }

    pub fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.forward_embeds(
            input_ids,
            self.wte.forward(input_ids)?,
            seqlen_offsets,
            context_lens,
            metadata,
            flash_params,
        )
    }

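    // Core forward pass over precomputed embeddings: build the causal mask once
    // (skipped for non-first PagedAttention prompt chunks), run every block with
    // its per-layer KV cache and device mapping, then apply the final norm and
    // LM head, extracting logits only for the requested context positions.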
    #[allow(clippy::too_many_arguments)]
    pub fn forward_embeds(
        &self,
        input_ids: &Tensor,
        input_embeds: Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut x = input_embeds;
        let cache = &mut self.kv_cache.normal().0;
        let mask = CausalMasker.make_causal_mask_matrix(
            input_ids,
            metadata
                .as_ref()
                .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
                .unwrap_or(cache as &dyn PastKvLenCache),
            x.dtype(),
            self.blocks[0].attn.num_attention_heads,
        )?;
        // PagedAttention prompt chunking
        let mask = mask.filter(|_| {
            metadata
                .as_ref()
                .map(|(_, meta)| meta.is_first_prompt_chunk)
                .unwrap_or(true)
        });
        for (block_idx, block) in self.blocks.iter().enumerate() {
            x = self.mapper.map(x, block_idx)?;
            x = block.forward(
                &x,
                &mask.clone().map(|m| m.to_device(x.device()).unwrap()),
                seqlen_offsets,
                &mut cache[block_idx],
                metadata
                    .as_ref()
                    .map(|(kv_cache, metadata)| (kv_cache[block_idx].clone(), *metadata)),
                flash_params,
            )?;
        }
        let x = x.to_device(&self.device)?;
        let mut x = self.ln_f.forward(&x)?;
        if let Some(t) = self.lm_head.quantized_act_type() {
            x = x.to_dtype(t)?;
        }
        let xs = MatMul.qmethod_matmul(&x, &*self.lm_head)?;
        extract_logits(&xs, context_lens)
    }

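    /// Collects the non-quantized residual tensors (embeddings and the norms)
    /// under their checkpoint names via [`UnVarBuilder`], e.g. so they can be
    /// re-serialized alongside quantized layers.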
    pub fn residual_tensors_m(&self, uvb_m: UnVarBuilder) -> Vec<(String, Tensor)> {
        uvb_m.pp("embed_tokens").add(&self.wte);
        uvb_m.pp("norm").add(&self.ln_f);

        for (layer_idx, layer) in self.blocks.iter().enumerate() {
            let uvb_l = uvb_m.pp("layers").pp(layer_idx);
            uvb_l.pp("input_layernorm").add(&layer.rms_1);
            uvb_l.pp("post_attention_layernorm").add(&layer.rms_2);
        }

        uvb_m.to_safetensors()
    }
}

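// ISQ (in-situ quantization) support: `get_layers` exposes every quantizable
// linear (lm_head plus the attention and MLP projections of each block), and
// `imatrix_names` must list GGUF-style tensor names in exactly the same order.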
impl IsqModel for Llama {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((&mut self.lm_head, None));
        for (i, layer) in self.blocks.iter_mut().enumerate() {
            tensors.push((&mut layer.attn.q_proj, Some(i)));
            tensors.push((&mut layer.attn.k_proj, Some(i)));
            tensors.push((&mut layer.attn.v_proj, Some(i)));
            tensors.push((&mut layer.attn.o_proj, Some(i)));
            tensors.extend(
                layer
                    .mlp
                    .get_isq_layers()
                    .into_iter()
                    .map(|m| (m, Some(i)))
                    .collect::<Vec<_>>(),
            );
        }
        (tensors, &*self.mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();
        self.residual_tensors_m(uvb.pp("model"))
    }

    fn imatrix_names(&self) -> candle_core::Result<Vec<Option<String>>> {
        // NOTE: dependent on the exact implementation in get_layers!
        let mut names = Vec::new();
        // lm_head
        names.push(None);
        for i in 0..self.blocks.len() {
            names.push(Some(format!("blk.{i}.attn_q.weight")));
            names.push(Some(format!("blk.{i}.attn_k.weight")));
            names.push(Some(format!("blk.{i}.attn_v.weight")));
            names.push(Some(format!("blk.{i}.attn_output.weight")));
            names.push(Some(format!("blk.{i}.ffn_gate.weight")));
            names.push(Some(format!("blk.{i}.ffn_up.weight")));
            names.push(Some(format!("blk.{i}.ffn_down.weight")));
        }
        Ok(names)
    }
}

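// `NormalModel` is the interface the text-generation pipeline drives; X-LoRA is
// not supported for this model, so `xlora_forward` is left unimplemented.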
impl NormalModel for Llama {
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.forward(
            input_ids,
            seqlen_offsets,
            context_lens,
            metadata,
            flash_params,
        )
    }
    fn xlora_forward(
        &self,
        _input_ids: &Tensor,
        _input_ids_full: &Tensor,
        _seqlen_offsets: &[usize],
        _seqlen_offsets_full: &[usize],
        _no_kv_cache: bool,
        _non_granular_state: &Option<crate::xlora_models::NonGranularState>,
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _flash_params: &FlashParams,
        _flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        unimplemented!()
    }
    fn cache(&self) -> &crate::pipeline::EitherCache {
        &self.kv_cache
    }
    fn cache_mut(&mut self) -> &mut crate::pipeline::EitherCache {
        &mut self.kv_cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn is_xlora(&self) -> bool {
        false
    }
    fn max_seq_len(&self) -> usize {
        self.blocks[0].attn.max_seq_len
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
}

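// AnyMoE support: `create_anymoe_layers` swaps each selected block's dense MLP
// for a `MoeMlp` whose experts are the original MLP plus, per additional var
// builder, either a fully fine-tuned replicated MLP or the base MLP with LoRA
// deltas applied to `c_fc1`/`c_fc2`/`c_proj`.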
impl AnyMoeBaseModelMixin for Llama {
    fn get_mlps(&self) -> Vec<&dyn MlpLayer> {
        let mut mlps = Vec::new();
        for layer in &self.blocks {
            mlps.push(&*layer.mlp);
        }
        mlps
    }
    fn get_mlps_mut(&mut self) -> Vec<&mut Box<dyn MlpLayer>> {
        let mut mlps = Vec::new();
        for layer in &mut self.blocks {
            mlps.push(&mut layer.mlp);
        }
        mlps
    }
    fn create_anymoe_layers(
        &mut self,
        additional_vbs: Vec<ShardedVarBuilder>,
        config: AnyMoeConfig,
        (prefix, mlp): (String, String),
        mut layers: Vec<usize>,
        expert_type: AnyMoeExpertType,
        gate_vb: Option<ShardedVarBuilder>,
    ) -> Result<()> {
        let mut experts: Vec<Vec<Box<dyn MlpLayer>>> = Vec::new();
        if layers.is_empty() {
            layers = (0..self.blocks.len()).collect::<Vec<_>>();
        }
        for _ in 0..layers.len() {
            experts.push(Vec::new());
        }
        for vb in additional_vbs {
            let vb = vb.pp(&prefix);
            for (layer, row) in experts.iter_mut().enumerate() {
                if !layers.contains(&layer) {
                    continue;
                }

                let intermediate_size = self.blocks[layer].mlp.get_params()[1];
                let hidden_size = self.blocks[layer].mlp.get_params()[0];
                match expert_type {
                    AnyMoeExpertType::FineTuned => {
                        let (dtype, device) = self.blocks[layer].mlp.dtype_device();
                        row.push(Box::new(Mlp::replicate(
                            self.blocks[layer].mlp.get_params(),
                            vb.pp(layer).pp(&mlp).set_dtype(dtype).set_device(device),
                            self.blocks[layer].mlp.hidden_act(),
                            &self.mapper.get_comm_for(layer)?,
                        )?));
                    }
                    AnyMoeExpertType::LoraAdapter {
                        rank,
                        alpha,
                        ref target_modules,
                    } => {
                        let vb_mlp = vb.pp(layer).pp(&mlp);

                        let c_fc1_delta = if target_modules.contains(&"c_fc1".to_string()) {
                            Some(get_delta_from_lora_ab!(
                                vb_mlp,
                                rank,
                                alpha,
                                (hidden_size, intermediate_size),
                                "c_fc1"
                            ))
                        } else {
                            None
                        };
                        let c_fc2_delta = if target_modules.contains(&"c_fc2".to_string()) {
                            Some(get_delta_from_lora_ab!(
                                vb_mlp,
                                rank,
                                alpha,
                                (hidden_size, intermediate_size),
                                "c_fc2"
                            ))
                        } else {
                            None
                        };
                        let c_proj_delta = if target_modules.contains(&"c_proj".to_string()) {
                            Some(get_delta_from_lora_ab!(
                                vb_mlp,
                                rank,
                                alpha,
                                (intermediate_size, hidden_size),
                                "c_proj"
                            ))
                        } else {
                            None
                        };

                        row.push(self.blocks[layer].mlp.new_added_delta(vec![
                            c_fc1_delta,
                            c_fc2_delta,
                            c_proj_delta,
                        ])?);
                    }
                }
            }
        }
        for (layer, expert) in layers.into_iter().zip(experts) {
            let mut experts_all = vec![self.blocks[layer].mlp.clone()];
            experts_all.extend(expert);
            let (dtype, device) = self.blocks[layer].mlp.dtype_device();
            self.blocks[layer].mlp = Box::new(MoeMlp::new(
                experts_all,
                config.clone(),
                dtype,
                &device,
                layer,
                gate_vb.as_ref(),
            )?);
        }
        Ok(())
    }
    fn amoe_supported(&self) -> bool {
        true
    }
}