// mistralrs_core/models/gemma.rs

1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use std::{collections::HashMap, sync::Arc};
4
5use candle_core::{Device, Module, Result, Tensor};
6use candle_nn::Linear;
7use mistralrs_quant::{
8    ColumnParallelLayer, QuantMethod, QuantMethodConfig, QuantizedConfig, RowParallelLayer,
9    ShardedVarBuilder, UnquantLinear,
10};
11
12use crate::{
13    amoe::{AnyMoeBaseModelMixin, AnyMoeConfig, AnyMoeExpertType, MlpLayer, MoeMlp},
14    attention::SdpaParams,
15    device_map::DeviceMapper,
16    get_delta_from_lora_ab,
17    layers::{embedding, Activation, CausalMasker, MatMul, Mlp, RmsNorm, RotaryEmbedding, Sdpa},
18    layers_masker::PastKvLenCache,
19    paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
20    pipeline::{
21        extract_logits,
22        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
23        EitherCache, IsqModel, KvCache, NormalCache, NormalLoadingMetadata, NormalModel,
24    },
25    serde_default_fn,
26    utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
27};
28
/// Serde default for `Config::max_position_embeddings`, used when the field is
/// absent from the serialized config.
fn default_max_position_embeddings() -> usize {
    const FALLBACK_CONTEXT_LEN: usize = 4096;
    FALLBACK_CONTEXT_LEN
}
32
// Defines `word_emb_default`, the serde default (yielding `false`) for
// `Config::tie_word_embeddings` when that key is absent.
serde_default_fn!(bool, word_emb_default, false);
34
/// Gemma model hyperparameters, deserialized with serde (presumably from the
/// checkpoint's `config.json`).
#[derive(serde::Deserialize, serde::Serialize, Debug, Clone, Default)]
pub struct Config {
    /// Whether the q/k/v/o attention projections carry a bias term.
    pub attention_bias: bool,
    /// Width of each attention head. Stored explicitly rather than derived from
    /// `hidden_size / num_attention_heads`; all projection widths use it.
    pub head_dim: usize,
    // The MLP activation may be given under either `hidden_act` or
    // `hidden_activation` (the CodeGemma configs use both key names);
    // `Config::hidden_act()` requires exactly one of the two to be set.
    pub hidden_act: Option<Activation>,
    pub hidden_activation: Option<Activation>,
    /// Model (residual stream) width.
    pub hidden_size: usize,
    /// Inner width of the gate/up/down MLP.
    pub intermediate_size: usize,
    /// Total number of query heads (before tensor-parallel sharding).
    pub num_attention_heads: usize,
    /// Number of decoder layers.
    pub num_hidden_layers: usize,
    /// Total number of key/value heads (before tensor-parallel sharding).
    pub num_key_value_heads: usize,
    /// Epsilon used by every RMSNorm in the model.
    pub rms_norm_eps: f64,
    /// RoPE base frequency, forwarded to `RotaryEmbedding::new`.
    pub rope_theta: f64,
    /// Vocabulary size of the embedding table (which also serves as the LM head).
    pub vocab_size: usize,

    /// Context length used for the RoPE tables and KV cache; falls back to
    /// `default_max_position_embeddings()` when absent.
    #[serde(default = "default_max_position_embeddings")]
    pub max_position_embeddings: usize,
    /// Forwarded to `SdpaParams::use_flash_attn`.
    pub use_flash_attn: bool,
    /// Optional quantization settings passed to the projection and embedding loaders.
    pub quantization_config: Option<QuantizedConfig>,
    /// Parsed for completeness only: the LM head always reuses the embedding
    /// matrix regardless of this flag. Defaults to `false`.
    #[serde(default = "word_emb_default")]
    #[allow(dead_code)]
    pub tie_word_embeddings: bool,
}
59
60impl Config {
61    pub fn hidden_act(&self) -> Result<Activation> {
62        match (self.hidden_act, self.hidden_activation) {
63            (None, Some(act)) | (Some(act), None) => Ok(act),
64            (Some(_), Some(_)) => {
65                candle_core::bail!("both hidden_act and hidden_activation are set")
66            }
67            (None, None) => candle_core::bail!("none of hidden_act and hidden_activation are set"),
68        }
69    }
70}
71
/// Self-attention for one decoder layer: q/k/v/o projections, grouped KV heads
/// and rotary position embeddings, run either eagerly (SDPA + `KvCache`) or via
/// PagedAttention.
struct Attention {
    /// Column-parallel query projection.
    q_proj: Arc<dyn QuantMethod>,
    /// Column-parallel key projection (explicitly sharded over KV heads).
    k_proj: Arc<dyn QuantMethod>,
    /// Column-parallel value projection (explicitly sharded over KV heads).
    v_proj: Arc<dyn QuantMethod>,
    /// Row-parallel output projection.
    o_proj: Arc<dyn QuantMethod>,
    /// Query heads held by this rank (total / world size).
    num_heads: usize,
    /// KV heads held by this rank (total / world size, at least 1).
    num_kv_heads: usize,
    /// Per-head width, taken from `Config::head_dim`.
    head_dim: usize,
    /// RoPE tables shared by all layers on the same device.
    rotary_emb: Arc<RotaryEmbedding>,
    /// `Some` when PagedAttention is enabled; `None` selects eager SDPA.
    paged_attn: Option<PagedAttention>,
    /// Softmax scale, KV grouping and flash-attention settings for SDPA.
    sdpa_params: SdpaParams,
}
84
85impl Attention {
86    fn new(
87        rotary_emb: Arc<RotaryEmbedding>,
88        cfg: &Config,
89        vb: ShardedVarBuilder,
90        paged_attn: Option<PagedAttention>,
91        comm: &Arc<mistralrs_quant::Comm>,
92    ) -> Result<Self> {
93        let hidden_sz = cfg.hidden_size;
94        let num_heads = cfg.num_attention_heads;
95        let num_kv_heads = cfg.num_key_value_heads;
96        let head_dim = cfg.head_dim;
97        let bias = cfg.attention_bias;
98        let q_proj = ColumnParallelLayer::new(
99            hidden_sz,
100            num_heads * head_dim,
101            &cfg.quantization_config,
102            bias,
103            comm,
104            vb.pp("q_proj"),
105        )?;
106        let kv_shard = mistralrs_quant::compute_kv_shard(
107            cfg.num_key_value_heads,
108            cfg.hidden_size / cfg.num_attention_heads,
109            comm,
110        );
111        let k_proj = ColumnParallelLayer::new_with_shard(
112            hidden_sz,
113            num_kv_heads * head_dim,
114            &cfg.quantization_config,
115            bias,
116            comm,
117            kv_shard,
118            vb.pp("k_proj"),
119        )?;
120        let v_proj = ColumnParallelLayer::new_with_shard(
121            hidden_sz,
122            num_kv_heads * head_dim,
123            &cfg.quantization_config,
124            bias,
125            comm,
126            kv_shard,
127            vb.pp("v_proj"),
128        )?;
129        let o_proj = RowParallelLayer::new(
130            num_heads * head_dim,
131            hidden_sz,
132            &cfg.quantization_config,
133            bias,
134            comm,
135            vb.pp("o_proj"),
136        )?;
137        Ok(Self {
138            q_proj,
139            k_proj,
140            v_proj,
141            o_proj,
142            num_heads: num_heads / comm.world_size(),
143            num_kv_heads: (num_kv_heads / comm.world_size()).max(1),
144            head_dim,
145            rotary_emb,
146            paged_attn,
147            sdpa_params: SdpaParams {
148                n_kv_groups: mistralrs_quant::compute_n_kv_groups(
149                    cfg.num_key_value_heads,
150                    cfg.num_attention_heads,
151                    comm,
152                ),
153                use_flash_attn: cfg.use_flash_attn,
154                softcap: None,
155                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
156                sliding_window: None,
157            },
158        })
159    }
160
161    #[allow(clippy::too_many_arguments)]
162    fn forward(
163        &self,
164        xs: &Tensor,
165        attention_mask: Option<&Tensor>,
166        seqlen_offsets: &[usize],
167        kv_cache: &mut KvCache,
168        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
169        flash_params: &FlashParams,
170    ) -> Result<Tensor> {
171        let (b_sz, q_len, _) = xs.dims3()?;
172
173        let original_dtype = xs.dtype();
174        let mut xs = xs.clone();
175        if let Some(t) = self.q_proj.quantized_act_type() {
176            xs = xs.to_dtype(t)?;
177        }
178        let mut q = MatMul.qmethod_matmul(&xs, &*self.q_proj)?;
179        let mut k = MatMul.qmethod_matmul(&xs, &*self.k_proj)?;
180        let mut v = MatMul.qmethod_matmul(&xs, &*self.v_proj)?;
181        if self.q_proj.quantized_act_type().is_some() {
182            q = q.to_dtype(original_dtype)?;
183            k = k.to_dtype(original_dtype)?;
184            v = v.to_dtype(original_dtype)?;
185        }
186
187        let (q, k, v) = if q_len != 1 {
188            let q = q
189                .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
190                .transpose(1, 2)?;
191            let k = k
192                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
193                .transpose(1, 2)?;
194            let v = v
195                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
196                .transpose(1, 2)?;
197            (q, k, v)
198        } else {
199            let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
200            let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
201            let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
202            (q, k, v)
203        };
204
205        let (q, k) = self.rotary_emb.forward(&q, &k, seqlen_offsets)?;
206
207        let mut attn_output = match &self.paged_attn {
208            Some(paged_attn) => match metadata {
209                Some(((key_cache, value_cache), input_metadata)) => paged_attn.forward(
210                    &q,
211                    &k,
212                    &v,
213                    attention_mask,
214                    Some(key_cache),
215                    Some(value_cache),
216                    input_metadata,
217                    &self.sdpa_params,
218                    Some(flash_params),
219                )?,
220                None => {
221                    // If we don't have metadata, we are most likely generating an imatrix so we don't want to populate that.
222                    // Generating the dummy metadata with the assumption that we are not generating text (only processing prompts).
223                    let input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
224                    // Sanity check.
225                    assert!(attention_mask.is_some());
226                    paged_attn.forward(
227                        &q,
228                        &k,
229                        &v,
230                        attention_mask,
231                        None,
232                        None,
233                        &input_metadata,
234                        &self.sdpa_params,
235                        Some(flash_params),
236                    )?
237                }
238            },
239            None => {
240                let (k, v) = kv_cache.append(&k, &v)?;
241
242                Sdpa.run_attention(
243                    &q,
244                    &k,
245                    &v,
246                    attention_mask,
247                    Some(flash_params),
248                    &self.sdpa_params,
249                )?
250            }
251        };
252
253        if let Some(t) = self.q_proj.quantized_act_type() {
254            attn_output = attn_output.to_dtype(t)?;
255        }
256        attn_output = if attention_mask.is_some() {
257            attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?
258        } else {
259            attn_output.reshape((b_sz, q_len, ()))?
260        };
261        let mut res = MatMul.qmethod_matmul(&attn_output, &*self.o_proj)?;
262        if self.q_proj.quantized_act_type().is_some() {
263            res = res.to_dtype(original_dtype)?;
264        }
265        Ok(res)
266    }
267}
268
/// One pre-norm Gemma transformer block: attention and MLP, each wrapped in a
/// residual connection.
struct DecoderLayer {
    self_attn: Attention,
    /// Boxed trait object so AnyMoE can replace it with an `MoeMlp`.
    mlp: Box<dyn MlpLayer>,
    /// Gemma-style RMSNorm applied before attention.
    input_layernorm: RmsNorm,
    /// Gemma-style RMSNorm applied before the MLP.
    post_attention_layernorm: RmsNorm,
}
275
276impl DecoderLayer {
277    #[allow(clippy::too_many_arguments)]
278    fn new(
279        rotary_emb: Arc<RotaryEmbedding>,
280        cfg: &Config,
281        vb: ShardedVarBuilder,
282        mapper: &dyn DeviceMapper,
283        layer_idx: usize,
284        loading_isq: bool,
285        paged_attn: Option<PagedAttention>,
286        comm: &Arc<mistralrs_quant::Comm>,
287    ) -> Result<Self> {
288        let self_attn = Attention::new(
289            rotary_emb,
290            cfg,
291            mapper.set_device(layer_idx, vb.pp("self_attn"), loading_isq),
292            paged_attn,
293            comm,
294        )?;
295        let mlp = Mlp::new(
296            mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq),
297            cfg.hidden_size,
298            cfg.intermediate_size,
299            &cfg.quantization_config,
300            cfg.hidden_act()?,
301            comm,
302        )?;
303        let input_layernorm = RmsNorm::new_gemma(
304            cfg.hidden_size,
305            cfg.rms_norm_eps,
306            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
307        )?;
308        let post_attention_layernorm = RmsNorm::new_gemma(
309            cfg.hidden_size,
310            cfg.rms_norm_eps,
311            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
312        )?;
313        Ok(Self {
314            self_attn,
315            mlp: Box::new(mlp),
316            input_layernorm,
317            post_attention_layernorm,
318        })
319    }
320
321    #[allow(clippy::too_many_arguments)]
322    fn forward(
323        &self,
324        xs: &Tensor,
325        attention_mask: Option<&Tensor>,
326        seqlen_offsets: &[usize],
327        kv_cache: &mut KvCache,
328        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
329        flash_params: &FlashParams,
330    ) -> Result<Tensor> {
331        let residual = xs;
332        let xs = self.input_layernorm.forward(xs)?;
333        let xs = self.self_attn.forward(
334            &xs,
335            attention_mask,
336            seqlen_offsets,
337            kv_cache,
338            metadata,
339            flash_params,
340        )?;
341        let xs = (xs + residual)?;
342        let residual = &xs;
343        let xs = self
344            .mlp
345            .forward(&xs.apply(&self.post_attention_layernorm)?)?;
346        residual + xs
347    }
348}
349
/// Gemma causal language model: token embeddings, a stack of decoder layers, a
/// final RMSNorm and an LM head that reuses the embedding matrix.
pub struct Model {
    embed_tokens: candle_nn::Embedding,
    layers: Vec<DecoderLayer>,
    /// Final Gemma-style RMSNorm applied before the LM head.
    norm: RmsNorm,
    /// Output projection, built from the embedding matrix (tied weights).
    lm_head: Arc<dyn QuantMethod>,
    /// Model width; embeddings are scaled by `sqrt(hidden_size)` in `forward`.
    hidden_size: usize,
    /// Device on which the final norm and logits are computed.
    device: Device,
    cache: EitherCache,
    /// Value reported by `NormalModel::max_seq_len`.
    max_seq_len: usize,
    /// Assigns layers to devices and provides per-layer communicators.
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    /// Metadata exposed via `NormalModel::config` (per-rank head counts).
    cfg: ModelConfigMetadata,
}
362
363impl Model {
364    pub fn new(
365        cfg: &Config,
366        vb: ShardedVarBuilder,
367        is_gptx: bool,
368        normal_loading_metadata: NormalLoadingMetadata,
369        attention_mechanism: AttentionImplementation,
370    ) -> Result<Self> {
371        if let Some(ref quant_cfg) = &cfg.quantization_config {
372            tracing::info!(
373                "Using {} quantization: {}.",
374                quant_cfg.name(),
375                quant_cfg.get_bits_name(&vb)
376            );
377        }
378        let mapper = normal_loading_metadata.mapper;
379
380        let vb_m = vb.pp("model");
381        let embed_tokens = embedding(
382            cfg.vocab_size,
383            cfg.hidden_size,
384            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
385            &cfg.quantization_config,
386        )?;
387        let mut ropes = HashMap::new();
388        for layer_idx in 0..cfg.num_hidden_layers {
389            let device = mapper
390                .device_for(layer_idx, false)
391                .unwrap_or(&normal_loading_metadata.real_device);
392            ropes.insert(
393                device.location(),
394                Arc::new(RotaryEmbedding::new(
395                    cfg.rope_theta as f32,
396                    cfg.head_dim,
397                    cfg.max_position_embeddings,
398                    device,
399                    is_gptx,
400                    vb_m.dtype(),
401                )?),
402            );
403        }
404        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
405        let vb_l = vb_m.pp("layers");
406        for layer_idx in NiceProgressBar::<_, 'b'>(
407            0..cfg.num_hidden_layers,
408            "Loading repeating layers",
409            &normal_loading_metadata.multi_progress,
410        ) {
411            let device = mapper
412                .device_for(layer_idx, false)
413                .unwrap_or(&normal_loading_metadata.real_device);
414            let rotary_emb = ropes
415                .get(&device.location())
416                .expect("No RoPE for device location!")
417                .clone();
418            let paged_attn = match &attention_mechanism {
419                AttentionImplementation::Eager => None,
420                AttentionImplementation::PagedAttention => {
421                    Some(PagedAttention::new(cfg.head_dim, device, None)?)
422                }
423            };
424            let comm = mapper.get_comm_for(layer_idx)?;
425            let layer = DecoderLayer::new(
426                rotary_emb.clone(),
427                cfg,
428                vb_l.pp(layer_idx),
429                &*mapper,
430                layer_idx,
431                normal_loading_metadata.loading_isq,
432                paged_attn,
433                &comm,
434            )?;
435            layers.push(layer)
436        }
437        let norm = RmsNorm::new_gemma(
438            cfg.hidden_size,
439            cfg.rms_norm_eps,
440            mapper.set_nm_device(vb_m.pp("norm"), false),
441        )?;
442        let lm_head = mapper.cast_nm_device(
443            embed_tokens.embeddings(),
444            normal_loading_metadata.loading_isq,
445        )?;
446        Ok(Self {
447            embed_tokens,
448            layers,
449            norm,
450            lm_head: Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
451                Linear::new(lm_head, None),
452            ))?),
453            device: normal_loading_metadata.real_device,
454            hidden_size: cfg.hidden_size,
455            cache: EitherCache::Normal(NormalCache::new(
456                cfg.num_hidden_layers,
457                cfg.max_position_embeddings,
458            )),
459            max_seq_len: default_max_position_embeddings(),
460            cfg: ModelConfigMetadata {
461                max_seq_len: cfg.max_position_embeddings,
462                num_layers: cfg.num_hidden_layers,
463                hidden_size: cfg.hidden_size,
464                num_attn_heads: cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size(),
465                num_kv_heads: (cfg.num_key_value_heads / mapper.get_comm_for(0)?.world_size())
466                    .max(1),
467                sliding_window: None,
468                k_head_dim: cfg.head_dim,
469                v_head_dim: cfg.head_dim,
470            },
471            mapper,
472        })
473    }
474
475    pub fn forward(
476        &self,
477        input_ids: &Tensor,
478        seqlen_offsets: &[usize],
479        context_lens: Vec<(usize, usize)>,
480        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
481        flash_params: &FlashParams,
482    ) -> Result<Tensor> {
483        let xs = self.embed_tokens.forward(input_ids)?;
484        let mut xs = (xs * (self.hidden_size as f64).sqrt())?;
485        let cache = &mut self.cache.normal().0;
486        let attention_mask = CausalMasker.make_causal_mask_matrix(
487            input_ids,
488            metadata
489                .as_ref()
490                .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
491                .unwrap_or(cache as &dyn PastKvLenCache),
492            xs.dtype(),
493            self.cfg.num_attn_heads,
494        )?;
495        // PagedAttention prompt chunking
496        let attention_mask = attention_mask.filter(|_| {
497            metadata
498                .as_ref()
499                .map(|(_, meta)| meta.is_first_prompt_chunk)
500                .unwrap_or(true)
501        });
502        for (i, layer) in self.layers.iter().enumerate() {
503            xs = self.mapper.map(xs, i)?;
504            xs = layer.forward(
505                &xs,
506                attention_mask
507                    .as_ref()
508                    .map(|m| m.to_device(xs.device()).unwrap())
509                    .as_ref(),
510                seqlen_offsets,
511                &mut cache[i],
512                metadata
513                    .as_ref()
514                    .map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
515                flash_params,
516            )?;
517        }
518        let xs = xs.to_device(&self.device)?;
519        let mut xs = xs.apply(&self.norm)?;
520        if let Some(t) = self.lm_head.quantized_act_type() {
521            xs = xs.to_dtype(t)?;
522        }
523        extract_logits(&MatMul.qmethod_matmul(&xs, &*self.lm_head)?, context_lens)
524    }
525}
526
527impl IsqModel for Model {
528    fn get_layers(
529        &mut self,
530    ) -> (
531        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
532        &dyn DeviceMapper,
533    ) {
534        let mut tensors = Vec::new();
535        tensors.push((&mut self.lm_head, None));
536        for (i, layer) in self.layers.iter_mut().enumerate() {
537            tensors.push((&mut layer.self_attn.q_proj, Some(i)));
538            tensors.push((&mut layer.self_attn.k_proj, Some(i)));
539            tensors.push((&mut layer.self_attn.v_proj, Some(i)));
540            tensors.push((&mut layer.self_attn.o_proj, Some(i)));
541            tensors.extend(
542                layer
543                    .mlp
544                    .get_isq_layers()
545                    .into_iter()
546                    .map(|m| (m, Some(i)))
547                    .collect::<Vec<_>>(),
548            );
549        }
550        (tensors, &*self.mapper)
551    }
552
553    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
554        let uvb = UnVarBuilder::new();
555
556        let uvb_m = uvb.pp("model");
557        uvb_m.pp("embed_tokens").add(&self.embed_tokens);
558        uvb_m.pp("norm").add(&self.norm.undo_gemma().unwrap());
559
560        for (layer_idx, layer) in self.layers.iter().enumerate() {
561            let uvb_l = uvb_m.pp("layers").pp(layer_idx);
562            uvb_l
563                .pp("input_layernorm")
564                .add(&layer.input_layernorm.undo_gemma().unwrap());
565            uvb_l
566                .pp("post_attention_layernorm")
567                .add(&layer.post_attention_layernorm.undo_gemma().unwrap());
568        }
569
570        uvb.to_safetensors()
571    }
572
573    fn imatrix_names(&self) -> candle_core::Result<Vec<Option<String>>> {
574        // NOTE: dependant on the exact implementation in get_layers!
575        let mut names = Vec::new();
576        // lm_head
577        names.push(None);
578        for i in 0..self.layers.len() {
579            names.push(Some(format!("blk.{i}.attn_q.weight")));
580            names.push(Some(format!("blk.{i}.attn_k.weight")));
581            names.push(Some(format!("blk.{i}.attn_v.weight")));
582            names.push(Some(format!("blk.{i}.attn_output.weight")));
583            names.push(Some(format!("blk.{i}.ffn_gate.weight")));
584            names.push(Some(format!("blk.{i}.ffn_up.weight")));
585            names.push(Some(format!("blk.{i}.ffn_down.weight")));
586        }
587        Ok(names)
588    }
589}
590
591impl NormalModel for Model {
592    fn forward(
593        &self,
594        input_ids: &Tensor,
595        seqlen_offsets: &[usize],
596        context_lens: Vec<(usize, usize)>,
597        _position_ids: Vec<usize>,
598        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
599        flash_params: &FlashParams,
600    ) -> Result<Tensor> {
601        self.forward(
602            input_ids,
603            seqlen_offsets,
604            context_lens,
605            metadata,
606            flash_params,
607        )
608    }
609    fn xlora_forward(
610        &self,
611        _input_ids: &Tensor,
612        _input_ids_full: &Tensor,
613        _seqlen_offsets: &[usize],
614        _seqlen_offsets_full: &[usize],
615        _no_kv_cache: bool,
616        _non_granular_state: &Option<crate::xlora_models::NonGranularState>,
617        _context_lens: Vec<(usize, usize)>,
618        _position_ids: Vec<usize>,
619        _flash_params: &FlashParams,
620        _flash_params_full: &FlashParams,
621    ) -> Result<Tensor> {
622        unimplemented!()
623    }
624    fn cache(&self) -> &EitherCache {
625        &self.cache
626    }
627    fn cache_mut(&mut self) -> &mut EitherCache {
628        &mut self.cache
629    }
630    fn device(&self) -> &Device {
631        &self.device
632    }
633    fn is_xlora(&self) -> bool {
634        false
635    }
636    fn max_seq_len(&self) -> usize {
637        self.max_seq_len
638    }
639    fn config(&self) -> &ModelConfigMetadata {
640        &self.cfg
641    }
642}
643
644impl AnyMoeBaseModelMixin for Model {
645    fn get_mlps(&self) -> Vec<&dyn MlpLayer> {
646        let mut mlps = Vec::new();
647        for layer in &self.layers {
648            mlps.push(&*layer.mlp);
649        }
650        mlps
651    }
652    fn get_mlps_mut(&mut self) -> Vec<&mut Box<dyn MlpLayer>> {
653        let mut mlps = Vec::new();
654        for layer in &mut self.layers {
655            mlps.push(&mut layer.mlp);
656        }
657        mlps
658    }
659    fn create_anymoe_layers(
660        &mut self,
661        additional_vbs: Vec<ShardedVarBuilder>,
662        config: AnyMoeConfig,
663        (prefix, mlp): (String, String),
664        mut layers: Vec<usize>,
665        expert_type: AnyMoeExpertType,
666        gate_vb: Option<ShardedVarBuilder>,
667    ) -> Result<()> {
668        let mut experts: Vec<Vec<Box<dyn MlpLayer>>> = Vec::new();
669        if layers.is_empty() {
670            layers = (0..self.layers.len()).collect::<Vec<_>>();
671        }
672        for _ in 0..layers.len() {
673            experts.push(Vec::new());
674        }
675        for vb in additional_vbs {
676            let vb = vb.pp(&prefix);
677            for (layer, row) in experts.iter_mut().enumerate() {
678                if !layers.contains(&layer) {
679                    continue;
680                }
681
682                let intermediate_size = self.layers[layer].mlp.get_params()[1];
683                let hidden_size = self.layers[layer].mlp.get_params()[0];
684                match expert_type {
685                    AnyMoeExpertType::FineTuned => {
686                        let (dtype, device) = self.layers[layer].mlp.dtype_device();
687                        row.push(Box::new(Mlp::replicate(
688                            self.layers[layer].mlp.get_params(),
689                            vb.pp(layer).pp(&mlp).set_dtype(dtype).set_device(device),
690                            self.layers[layer].mlp.hidden_act(),
691                            &self.mapper.get_comm_for(layer)?,
692                        )?));
693                    }
694                    AnyMoeExpertType::LoraAdapter {
695                        rank,
696                        alpha,
697                        ref target_modules,
698                    } => {
699                        let vb_mlp = vb.pp(layer).pp(&mlp);
700
701                        let gate_proj_delta = if target_modules.contains(&"gate_proj".to_string()) {
702                            Some(get_delta_from_lora_ab!(
703                                vb_mlp,
704                                rank,
705                                alpha,
706                                (hidden_size, intermediate_size),
707                                "gate_proj"
708                            ))
709                        } else {
710                            None
711                        };
712                        let up_proj_delta = if target_modules.contains(&"up_proj".to_string()) {
713                            Some(get_delta_from_lora_ab!(
714                                vb_mlp,
715                                rank,
716                                alpha,
717                                (hidden_size, intermediate_size),
718                                "up_proj"
719                            ))
720                        } else {
721                            None
722                        };
723                        let down_proj_delta = if target_modules.contains(&"down_proj".to_string()) {
724                            Some(get_delta_from_lora_ab!(
725                                vb_mlp,
726                                rank,
727                                alpha,
728                                (intermediate_size, hidden_size),
729                                "down_proj"
730                            ))
731                        } else {
732                            None
733                        };
734
735                        row.push(self.layers[layer].mlp.new_added_delta(vec![
736                            gate_proj_delta,
737                            up_proj_delta,
738                            down_proj_delta,
739                        ])?);
740                    }
741                }
742            }
743        }
744        for (layer, expert) in layers.into_iter().zip(experts) {
745            let mut experts_all = vec![self.layers[layer].mlp.clone()];
746            experts_all.extend(expert);
747            let (dtype, device) = self.layers[layer].mlp.dtype_device();
748            self.layers[layer].mlp = Box::new(MoeMlp::new(
749                experts_all,
750                config.clone(),
751                dtype,
752                &device,
753                layer,
754                gate_vb.as_ref(),
755            )?);
756        }
757        Ok(())
758    }
759    fn amoe_supported(&self) -> bool {
760        true
761    }
762}