mistralrs_core/vision_models/llava/llava_llm/llama.rs

// LLaMA without fused RoPE
#![allow(
    clippy::cast_possible_truncation,
    clippy::cast_precision_loss,
    clippy::too_many_arguments
)]

use std::sync::Arc;

use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{Embedding, Module};
use mistralrs_quant::{
    ColumnParallelLayer, QuantMethod, QuantMethodConfig, RowParallelLayer, ShardedVarBuilder,
    UnquantLinear,
};

use crate::{
    amoe::{AnyMoeBaseModelMixin, AnyMoeTrainableLayer, MlpLayer, MoeMlp},
    attention::SdpaParams,
    device_map::DeviceMapper,
    get_delta_from_lora_ab,
    layers::{
        embedding, linear_no_bias as linear, Activation, CausalMasker, MatMul, RmsNorm, Sdpa,
    },
    layers_masker::PastKvLenCache,
    models::llama::Config,
    paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
    pipeline::{
        extract_logits,
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        IsqModel, NormalLoadingMetadata, NormalModel,
    },
    utils::progress::NiceProgressBar,
    AnyMoeConfig, AnyMoeExpertType,
};

use super::{LLaVALLM, OrdinaryRoPE};

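/// Causal self-attention for the LLaVA LLaMA backbone. Rotary position
/// embeddings are applied through the shared `OrdinaryRoPE` helper rather than
/// a fused RoPE kernel.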
struct CausalSelfAttention {
    q_proj: Arc<dyn QuantMethod>,
    k_proj: Arc<dyn QuantMethod>,
    v_proj: Arc<dyn QuantMethod>,
    o_proj: Arc<dyn QuantMethod>,
    num_attention_heads: usize,
    num_key_value_heads: usize,
    head_dim: usize,
    max_seq_len: usize,
    paged_attn: Option<PagedAttention>,
    sdpa_params: SdpaParams,
}

impl CausalSelfAttention {
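    /// Attention forward pass: project to Q/K/V, apply RoPE to Q and K, run
    /// either paged attention or the standard SDPA path against the KV cache,
    /// then apply the output projection.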
    fn forward(
        &self,
        x: &Tensor,
        attention_mask: &Option<Tensor>,
        seqlen_offsets: &[usize],
        block_idx: usize,
        kv_cache: &mut crate::pipeline::LayerCaches,
        rope_parameter: (&Tensor, &Tensor),
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, seq_len, _) = x.dims3()?;

        let original_dtype = x.dtype();
        let mut x = x.clone();
        if let Some(t) = self.q_proj.quantized_act_type() {
            x = x.to_dtype(t)?;
        }
        let mut q = MatMul.qmethod_matmul(&x, &*self.q_proj)?;
        let mut k = MatMul.qmethod_matmul(&x, &*self.k_proj)?;
        let mut v = MatMul.qmethod_matmul(&x, &*self.v_proj)?;
        if self.q_proj.quantized_act_type().is_some() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

        let mut q = q
            .reshape((b_sz, seq_len, self.num_attention_heads, self.head_dim))?
            .transpose(1, 2)?
            .contiguous()?;
        let mut k = k
            .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
            .transpose(1, 2)?
            .contiguous()?;
        q = OrdinaryRoPE::forward(&q, seqlen_offsets[0], rope_parameter.0, rope_parameter.1)?;
        k = OrdinaryRoPE::forward(&k, seqlen_offsets[0], rope_parameter.0, rope_parameter.1)?;
        let v = v
            .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
            .transpose(1, 2)?;

        let mut y = match &self.paged_attn {
            Some(paged_attn) => match metadata {
                Some(((key_cache, value_cache), input_metadata)) => paged_attn.forward(
                    &q,
                    &k,
                    &v,
                    attention_mask.clone().as_ref(),
                    Some(key_cache),
                    Some(value_cache),
                    input_metadata,
                    &self.sdpa_params,
                    Some(flash_params),
                )?,
                None => {
                    // If we don't have metadata, we are most likely generating an imatrix so we don't want to populate that.
                    // Generating the dummy metadata with the assumption that we are not generating text (only processing prompts).
                    let input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
                    // Sanity check.
                    assert!(attention_mask.is_some());
                    paged_attn.forward(
                        &q,
                        &k,
                        &v,
                        attention_mask.clone().as_ref(),
                        None,
                        None,
                        &input_metadata,
                        &self.sdpa_params,
                        Some(flash_params),
                    )?
                }
            },
            None => {
                let (k, v) =
                    crate::pipeline::Cache::update_kv_cache(&mut kv_cache[block_idx], k, v, false)?;

                Sdpa.run_attention(
                    &q,
                    &k,
                    &v,
                    attention_mask.clone().as_ref(),
                    Some(flash_params),
                    &self.sdpa_params,
                )?
            }
        };

        if let Some(t) = self.q_proj.quantized_act_type() {
            y = y.to_dtype(t)?;
        }
        y = if attention_mask.is_some() {
            y.transpose(1, 2)?.reshape((b_sz, seq_len, ()))?
        } else {
            y.reshape((b_sz, seq_len, ()))?
        };
        let mut res = MatMul.qmethod_matmul(&y, &*self.o_proj)?;
        if self.q_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }

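    /// Loads the attention projections as tensor-parallel layers: Q/K/V are
    /// column-parallel (K/V sharded over the KV heads) and the output
    /// projection is row-parallel over `comm`.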
    fn load(
        vb: ShardedVarBuilder,
        cfg: &Config,
        paged_attn: Option<PagedAttention>,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let size_in = cfg.hidden_size;
        let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
        let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
        let q_proj = ColumnParallelLayer::new(
            size_in,
            size_q,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("q_proj"),
        )?;
        let kv_shard = mistralrs_quant::compute_kv_shard(
            cfg.num_key_value_heads,
            cfg.hidden_size / cfg.num_attention_heads,
            comm,
        );
        let k_proj = ColumnParallelLayer::new_with_shard(
            size_in,
            size_kv,
            &cfg.quantization_config,
            false,
            comm,
            kv_shard,
            vb.pp("k_proj"),
        )?;
        let v_proj = ColumnParallelLayer::new_with_shard(
            size_in,
            size_kv,
            &cfg.quantization_config,
            false,
            comm,
            kv_shard,
            vb.pp("v_proj"),
        )?;
        let o_proj = RowParallelLayer::new(
            size_q,
            size_in,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("o_proj"),
        )?;
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            num_attention_heads: cfg.num_attention_heads / comm.world_size(),
            num_key_value_heads: (cfg.num_key_value_heads / comm.world_size()).max(1),
            head_dim: cfg.hidden_size / cfg.num_attention_heads,
            max_seq_len: cfg.max_position_embeddings,
            paged_attn,
            sdpa_params: SdpaParams {
                n_kv_groups: mistralrs_quant::compute_n_kv_groups(
                    cfg.num_key_value_heads,
                    cfg.num_attention_heads,
                    comm,
                ),
                use_flash_attn: cfg.use_flash_attn,
                softcap: None,
                softmax_scale: 1.0 / ((cfg.hidden_size / cfg.num_attention_heads) as f32).sqrt(),
                sliding_window: None,
            },
        })
    }
}
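
/// SwiGLU feed-forward block: `down_proj(silu(gate_proj(x)) * up_proj(x))`,
/// with gate/up loaded as column-parallel layers and down as row-parallel.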
#[derive(Clone)]
struct Mlp {
    c_fc1: Arc<dyn QuantMethod>,
    c_fc2: Arc<dyn QuantMethod>,
    c_proj: Arc<dyn QuantMethod>,
    params: Vec<usize>,
}

impl Mlp {
    fn load(
        vb: ShardedVarBuilder,
        cfg: &Config,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let h_size = cfg.hidden_size;
        let i_size = cfg.intermediate_size;
        let c_fc1 = ColumnParallelLayer::new(
            h_size,
            i_size,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("gate_proj"),
        )?;
        let c_fc2 = ColumnParallelLayer::new(
            h_size,
            i_size,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("up_proj"),
        )?;
        let c_proj = RowParallelLayer::new(
            i_size,
            h_size,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("down_proj"),
        )?;
        Ok(Self {
            c_fc1,
            c_fc2,
            c_proj,
            params: vec![h_size, i_size],
        })
    }
}

impl AnyMoeTrainableLayer for Mlp {}

impl MlpLayer for Mlp {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let original_dtype = x.dtype();
        let mut x = x.clone();
        if let Some(t) = self.c_fc1.quantized_act_type() {
            x = x.to_dtype(t)?;
        }
        let x = (candle_nn::ops::silu(&MatMul.qmethod_matmul(&x, &*self.c_fc1)?)?
            * MatMul.qmethod_matmul(&x, &*self.c_fc2)?)?;
        let mut res = MatMul.qmethod_matmul(&x, &*self.c_proj)?;
        if self.c_fc1.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
    fn get_isq_layers(&mut self) -> Vec<&mut Arc<dyn QuantMethod>> {
        vec![&mut self.c_fc1, &mut self.c_fc2, &mut self.c_proj]
    }
    fn clone(&self) -> Box<dyn MlpLayer> {
        Box::new(Clone::clone(self))
    }
    fn get_params(&self) -> &[usize] {
        &self.params
    }
    fn hidden_act(&self) -> Activation {
        Activation::Silu
    }
    // c_fc1, c_fc2, c_proj
    fn new_added_delta(&self, deltas: Vec<Option<Tensor>>) -> Result<Box<dyn MlpLayer>> {
        let new_c_fc1 = if let Some(ref delta) = deltas[0] {
            self.c_fc1.add_delta_w(delta)?
        } else {
            self.c_fc1.clone()
        };
        let new_c_fc2 = if let Some(ref delta) = deltas[1] {
            self.c_fc2.add_delta_w(delta)?
        } else {
            self.c_fc2.clone()
        };
        let new_c_proj = if let Some(ref delta) = deltas[2] {
            self.c_proj.add_delta_w(delta)?
        } else {
            self.c_proj.clone()
        };

        Ok(Box::new(Self {
            c_fc1: new_c_fc1,
            c_fc2: new_c_fc2,
            c_proj: new_c_proj,
            params: self.params.clone(),
        }))
    }

    fn dtype_device(&self) -> (DType, Device) {
        self.c_fc1.dtype_and_device()
    }
}

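/// A single decoder layer: pre-norm attention and pre-norm MLP, each wrapped
/// in a residual connection, plus the RoPE parameter tensors handed to the
/// attention layer.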
struct Block {
    rms_1: RmsNorm,
    attn: CausalSelfAttention,
    rms_2: RmsNorm,
    mlp: Box<dyn MlpLayer>,
    rope_parameter: (Tensor, Tensor),
}

impl Block {
    fn forward(
        &self,
        x: &Tensor,
        attention_mask: &Option<Tensor>,
        seqlen_offsets: &[usize],
        block_idx: usize,
        kv_cache: &mut crate::pipeline::LayerCaches,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = x;
        let mut x = self.rms_1.forward(x)?;
        x = (self.attn.forward(
            &x,
            attention_mask,
            seqlen_offsets,
            block_idx,
            kv_cache,
            (&self.rope_parameter.0, &self.rope_parameter.1),
            metadata,
            flash_params,
        )? + residual)?;
        let residual = &x;
        x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?;
        Ok(x)
    }

    fn load(
        vb: ShardedVarBuilder,
        cfg: &Config,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        paged_attn: Option<PagedAttention>,
        rope_parameter: (Tensor, Tensor),
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let attn = CausalSelfAttention::load(
            mapper.set_device(layer_idx, vb.pp("self_attn"), loading_isq),
            cfg,
            paged_attn,
            comm,
        )?;
        let mlp = Mlp::load(
            mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq),
            cfg,
            comm,
        )?;
        let rms_1 = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let rms_2 = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        Ok(Self {
            rms_1,
            attn,
            rms_2,
            mlp: Box::new(mlp),
            rope_parameter,
        })
    }
}

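/// LLaMA text backbone used by the LLaVA vision models: token embedding, a
/// stack of decoder blocks, final RMSNorm, and the LM head.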
pub struct Llama {
    wte: Embedding,
    blocks: Vec<Block>,
    ln_f: RmsNorm,
    lm_head: Arc<dyn QuantMethod>,
    kv_cache: crate::pipeline::EitherCache,
    device: Device,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    cfg: ModelConfigMetadata,
}

impl Llama {
    pub fn forward_input(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let x = self.wte.forward(input_ids)?;
        self.forward_input_embed(
            input_ids,
            x,
            seqlen_offsets,
            context_lens,
            metadata,
            flash_params,
        )
    }

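    /// Builds the model from a sharded var builder, creating per-layer RoPE
    /// parameters and, when requested, a `PagedAttention` instance per layer.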
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        _is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        if let Some(ref quant_cfg) = &cfg.quantization_config {
            tracing::info!(
                "Using {} quantization: {}.",
                quant_cfg.name(),
                quant_cfg.get_bits_name(&vb)
            );
        }
        let mapper = normal_loading_metadata.mapper;
        let wte = embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb.pp("model.embed_tokens"), false),
            &cfg.quantization_config,
        )?;
        let lm_head = linear(
            cfg.hidden_size,
            cfg.vocab_size,
            mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
        )?;
        let ln_f = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb.pp("model.norm"), false),
        )?;
        let head_dim = cfg.hidden_size / cfg.num_attention_heads;

        let blocks: Vec<_> = NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        )
        .into_iter()
        .map(|i| {
            let vb_m = vb.pp(format!("model.layers.{i}"));
            let device = mapper
                .device_for(i, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rope_parameters = OrdinaryRoPE::create_parameters(
                head_dim,
                cfg.max_position_embeddings,
                cfg.rope_theta,
                vb_m.dtype(),
                device,
            )
            .unwrap();
            let paged_attn = match &attention_mechanism {
                AttentionImplementation::Eager => None,
                AttentionImplementation::PagedAttention => Some(
                    PagedAttention::new(head_dim, device, None)
                        .expect("Failed to create PagedAttention"),
                ),
            };
            let comm = mapper.get_comm_for(i).unwrap();
            Block::load(
                vb_m,
                cfg,
                &*mapper,
                i,
                normal_loading_metadata.loading_isq,
                paged_attn,
                rope_parameters,
                &comm,
            )
            .expect("Failed to load block.")
        })
        .collect();

        Ok(Self {
            wte,
            blocks,
            ln_f,
            lm_head: Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(lm_head))?),
            kv_cache: crate::pipeline::EitherCache::Full(crate::pipeline::Cache::new(
                cfg.num_hidden_layers,
                false,
            )),
            device: normal_loading_metadata.real_device,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_attn_heads: cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size(),
                num_kv_heads: (cfg.num_key_value_heads / mapper.get_comm_for(0)?.world_size())
                    .max(1),
                sliding_window: None,
                k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
                v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            },
            mapper,
        })
    }
}

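// ISQ support: expose the LM head and the per-layer attention/MLP projections
// for in-situ quantization.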
impl IsqModel for Llama {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((&mut self.lm_head, None));
        for (i, layer) in self.blocks.iter_mut().enumerate() {
            tensors.push((&mut layer.attn.q_proj, Some(i)));
            tensors.push((&mut layer.attn.k_proj, Some(i)));
            tensors.push((&mut layer.attn.v_proj, Some(i)));
            tensors.push((&mut layer.attn.o_proj, Some(i)));
            tensors.extend(
                layer
                    .mlp
                    .get_isq_layers()
                    .into_iter()
                    .map(|m| (m, Some(i)))
                    .collect::<Vec<_>>(),
            );
        }
        (tensors, &*self.mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        Vec::new()
    }
}

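// LLaVA-specific entry points: `embed` exposes the token embedding so image
// features can be spliced into the sequence, and `forward_input_embed` runs
// the decoder stack from precomputed embeddings.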
impl LLaVALLM for Llama {
    fn embed(&self, input_ids: &Tensor) -> Result<Tensor> {
        self.wte.forward(input_ids)
    }
    fn forward_input_embed(
        &self,
        input_ids: &Tensor,
        input_embed: Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut x = input_embed;
        let mut cache = self.kv_cache.full().lock();
        let mask = CausalMasker.make_causal_mask_matrix(
            input_ids,
            metadata
                .as_ref()
                .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
                .unwrap_or(&*cache as &dyn PastKvLenCache),
            x.dtype(),
            self.blocks[0].attn.num_attention_heads,
        )?;
        let mask = mask.filter(|_| {
            metadata
                .as_ref()
                .map(|(_, meta)| meta.is_first_prompt_chunk)
                .unwrap_or(true)
        });
        for (block_idx, block) in self.blocks.iter().enumerate() {
            x = self.mapper.map(x, block_idx)?;
            x = block.forward(
                &x,
                &mask.clone().map(|m| m.to_device(x.device()).unwrap()),
                seqlen_offsets,
                block_idx,
                &mut cache,
                metadata
                    .as_ref()
                    .map(|(kv_cache, metadata)| (kv_cache[block_idx].clone(), *metadata)),
                flash_params,
            )?;
        }
        x = x.to_device(&self.device)?;
        x = self.ln_f.forward(&x)?;
        if let Some(t) = self.lm_head.quantized_act_type() {
            x = x.to_dtype(t)?;
        }
        let xs = MatMul.qmethod_matmul(&x, &*self.lm_head)?;
        extract_logits(&xs, context_lens)
    }
}

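// Standard text-model interface; X-LoRA is not supported for this backbone.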
impl NormalModel for Llama {
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.forward_input(
            input_ids,
            seqlen_offsets,
            context_lens,
            metadata,
            flash_params,
        )
    }
    fn xlora_forward(
        &self,
        _input_ids: &Tensor,
        _input_ids_full: &Tensor,
        _seqlen_offsets: &[usize],
        _seqlen_offsets_full: &[usize],
        _no_kv_cache: bool,
        _non_granular_state: &Option<crate::xlora_models::NonGranularState>,
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _flash_params: &FlashParams,
        _flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        unimplemented!()
    }
    fn cache(&self) -> &crate::pipeline::EitherCache {
        &self.kv_cache
    }
    fn cache_mut(&mut self) -> &mut crate::pipeline::EitherCache {
        &mut self.kv_cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn is_xlora(&self) -> bool {
        false
    }
    fn max_seq_len(&self) -> usize {
        self.blocks[0].attn.max_seq_len
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
}

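// AnyMoE support: each selected layer's MLP is replaced with a `MoeMlp` whose
// experts are either freshly loaded fine-tuned MLPs or LoRA deltas applied to
// the existing weights.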
impl AnyMoeBaseModelMixin for Llama {
    fn get_mlps(&self) -> Vec<&dyn MlpLayer> {
        let mut mlps = Vec::new();
        for layer in &self.blocks {
            mlps.push(&*layer.mlp);
        }
        mlps
    }
    fn get_mlps_mut(&mut self) -> Vec<&mut Box<dyn MlpLayer>> {
        let mut mlps = Vec::new();
        for layer in &mut self.blocks {
            mlps.push(&mut layer.mlp);
        }
        mlps
    }
    fn create_anymoe_layers(
        &mut self,
        additional_vbs: Vec<ShardedVarBuilder>,
        config: AnyMoeConfig,
        (prefix, mlp): (String, String),
        mut layers: Vec<usize>,
        expert_type: AnyMoeExpertType,
        gate_vb: Option<ShardedVarBuilder>,
    ) -> Result<()> {
        let mut experts: Vec<Vec<Box<dyn MlpLayer>>> = Vec::new();
        if layers.is_empty() {
            layers = (0..self.blocks.len()).collect::<Vec<_>>();
        }
        for _ in 0..layers.len() {
            experts.push(Vec::new());
        }
        for vb in additional_vbs {
            let vb = vb.pp(&prefix);
            for (layer, row) in experts.iter_mut().enumerate() {
                if !layers.contains(&layer) {
                    continue;
                }

                let intermediate_size = self.blocks[layer].mlp.get_params()[1];
                let hidden_size = self.blocks[layer].mlp.get_params()[0];
                match expert_type {
                    AnyMoeExpertType::FineTuned => {
                        let (dtype, device) = self.blocks[layer].mlp.dtype_device();
                        row.push(Box::new(Mlp::load(
                            vb.pp(layer).pp(&mlp).set_dtype(dtype).set_device(device),
                            &Config {
                                intermediate_size: self.blocks[layer].mlp.get_params()[1],
                                hidden_size: self.blocks[layer].mlp.get_params()[0],
                                ..Default::default()
                            },
                            &self.mapper.get_comm_for(layer)?,
                        )?));
                    }
                    AnyMoeExpertType::LoraAdapter {
                        rank,
                        alpha,
                        ref target_modules,
                    } => {
                        let vb_mlp = vb.pp(layer).pp(&mlp);

                        let c_fc1_delta = if target_modules.contains(&"c_fc1".to_string()) {
                            Some(get_delta_from_lora_ab!(
                                vb_mlp,
                                rank,
                                alpha,
                                (hidden_size, intermediate_size),
                                "c_fc1"
                            ))
                        } else {
                            None
                        };
                        let c_fc2_delta = if target_modules.contains(&"c_fc2".to_string()) {
                            Some(get_delta_from_lora_ab!(
                                vb_mlp,
                                rank,
                                alpha,
                                (hidden_size, intermediate_size),
                                "c_fc2"
                            ))
                        } else {
                            None
                        };
                        let c_proj_delta = if target_modules.contains(&"c_proj".to_string()) {
                            Some(get_delta_from_lora_ab!(
                                vb_mlp,
                                rank,
                                alpha,
                                (intermediate_size, hidden_size),
                                "c_proj"
                            ))
                        } else {
                            None
                        };

                        row.push(self.blocks[layer].mlp.new_added_delta(vec![
                            c_fc1_delta,
                            c_fc2_delta,
                            c_proj_delta,
                        ])?);
                    }
                }
            }
        }
        for (layer, expert) in layers.into_iter().zip(experts) {
            let mut experts_all = vec![self.blocks[layer].mlp.clone()];
            experts_all.extend(expert);
            let (dtype, device) = self.blocks[layer].mlp.dtype_device();
            self.blocks[layer].mlp = Box::new(MoeMlp::new(
                experts_all,
                config.clone(),
                dtype,
                &device,
                layer,
                gate_vb.as_ref(),
            )?);
        }
        Ok(())
    }
    fn amoe_supported(&self) -> bool {
        true
    }
}