mistralrs_core/models/mistral.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

/// Mistral LLM, https://github.com/mistralai/mistral-src
use candle_core::{Device, Module, Result, Tensor};
use mistralrs_quant::{
    ColumnParallelLayer, QuantMethod, QuantizedConfig, ReplicatedLayer, RowParallelLayer,
    ShardedVarBuilder,
};
use serde::{Deserialize, Serialize};
use std::{collections::HashMap, sync::Arc};

use crate::{
    amoe::{AnyMoeBaseModelMixin, AnyMoeConfig, AnyMoeExpertType, MlpLayer, MoeMlp},
    attention::SdpaParams,
    device_map::DeviceMapper,
    get_delta_from_lora_ab,
    layers::{embedding, Activation, CausalMasker, MatMul, Mlp, RmsNorm, RotaryEmbedding, Sdpa},
    layers_masker::PastKvLenCache,
    paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
    pipeline::{
        extract_logits,
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel, KvCache, NormalCache, NormalLoadingMetadata, NormalModel,
    },
    serde_default_fn,
    utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
};

serde_default_fn!(bool, use_flash_attn, false);
serde_default_fn!(bool, tie_word_embeddings, false);

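/// Model configuration, deserialized from a Mistral-style Hugging Face `config.json`.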
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct Config {
    pub(crate) vocab_size: usize,
    pub(crate) hidden_size: usize,
    pub(crate) intermediate_size: usize,
    pub(crate) num_hidden_layers: usize,
    pub(crate) num_attention_heads: usize,
    pub(crate) num_key_value_heads: usize,
    pub(crate) hidden_act: Activation,
    pub(crate) max_position_embeddings: usize,
    pub(crate) rms_norm_eps: f64,
    pub(crate) rope_theta: f64,
    pub(crate) sliding_window: Option<usize>,
    #[serde(default = "use_flash_attn")]
    pub(crate) use_flash_attn: bool,
    pub(crate) head_dim: Option<usize>,
    pub(crate) quantization_config: Option<QuantizedConfig>,
    #[serde(default = "tie_word_embeddings")]
    pub(crate) tie_word_embeddings: bool,
}

impl Config {
    pub(crate) fn head_dim(&self) -> usize {
        self.head_dim
            .unwrap_or(self.hidden_size / self.num_attention_heads)
    }
}

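/// Multi-head attention with tensor-parallel projections (column-parallel q/k/v, row-parallel
/// output projection) and an optional PagedAttention backend.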
struct Attention {
    q_proj: Arc<dyn QuantMethod>,
    k_proj: Arc<dyn QuantMethod>,
    v_proj: Arc<dyn QuantMethod>,
    o_proj: Arc<dyn QuantMethod>,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    rotary_emb: Arc<RotaryEmbedding>,
    paged_attn: Option<PagedAttention>,
    sdpa_params: SdpaParams,
}

impl Attention {
    fn new(
        rotary_emb: Arc<RotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        paged_attn: Option<PagedAttention>,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        let head_dim = cfg.head_dim();
        let q_proj = ColumnParallelLayer::new(
            hidden_sz,
            num_heads * head_dim,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("q_proj"),
        )?;
        let kv_shard = mistralrs_quant::compute_kv_shard(
            cfg.num_key_value_heads,
            cfg.hidden_size / cfg.num_attention_heads,
            comm,
        );
        let k_proj = ColumnParallelLayer::new_with_shard(
            hidden_sz,
            num_kv_heads * head_dim,
            &cfg.quantization_config,
            false,
            comm,
            kv_shard,
            vb.pp("k_proj"),
        )?;
        let v_proj = ColumnParallelLayer::new_with_shard(
            hidden_sz,
            num_kv_heads * head_dim,
            &cfg.quantization_config,
            false,
            comm,
            kv_shard,
            vb.pp("v_proj"),
        )?;
        let o_proj = RowParallelLayer::new(
            num_heads * head_dim,
            hidden_sz,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("o_proj"),
        )?;
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            num_heads: num_heads / comm.world_size(),
            num_kv_heads: (num_kv_heads / comm.world_size()).max(1),
            head_dim,
            rotary_emb,
            paged_attn,
            sdpa_params: SdpaParams {
                n_kv_groups: mistralrs_quant::compute_n_kv_groups(
                    cfg.num_key_value_heads,
                    cfg.num_attention_heads,
                    comm,
                ),
                use_flash_attn: cfg.use_flash_attn,
                softcap: None,
                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                sliding_window: cfg.sliding_window,
            },
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.q_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut q = MatMul.qmethod_matmul(&xs, &*self.q_proj)?;
        let mut k = MatMul.qmethod_matmul(&xs, &*self.k_proj)?;
        let mut v = MatMul.qmethod_matmul(&xs, &*self.v_proj)?;
        if self.q_proj.quantized_act_type().is_some() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

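        // Prefill (q_len > 1): reshape to (b, heads, q_len, head_dim) via a transpose.
        // Single-token decode (q_len == 1): reshape directly, since the transpose would be a no-op.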
        let (q, k, v) = if q_len != 1 {
            let q = q
                .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, k) = self.rotary_emb.forward(&q, &k, seqlen_offsets)?;

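        // PagedAttention writes k/v into the paged KV blocks; the eager path appends them to
        // this layer's KvCache and runs scaled dot-product attention.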
        let mut attn_output = match &self.paged_attn {
            Some(paged_attn) => match metadata {
                Some(((key_cache, value_cache), input_metadata)) => paged_attn.forward(
                    &q,
                    &k,
                    &v,
                    attention_mask,
                    Some(key_cache),
                    Some(value_cache),
                    input_metadata,
                    &self.sdpa_params,
                    Some(flash_params),
                )?,
                None => {
                    // If we don't have metadata, we are most likely generating an imatrix so we don't want to populate that.
                    // Generating the dummy metadata with the assumption that we are not generating text (only processing prompts).
                    let input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
                    // Sanity check.
                    assert!(attention_mask.is_some());
                    paged_attn.forward(
                        &q,
                        &k,
                        &v,
                        attention_mask,
                        None,
                        None,
                        &input_metadata,
                        &self.sdpa_params,
                        Some(flash_params),
                    )?
                }
            },
            None => {
                let (k, v) = kv_cache.append(&k, &v)?;

                Sdpa.run_attention(
                    &q,
                    &k,
                    &v,
                    attention_mask,
                    Some(flash_params),
                    &self.sdpa_params,
                )?
            }
        };

        if let Some(t) = self.q_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        attn_output = if attention_mask.is_some() {
            attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?
        } else {
            attn_output.reshape((b_sz, q_len, ()))?
        };
        let mut res = MatMul.qmethod_matmul(&attn_output, &*self.o_proj)?;
        if self.q_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

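/// A single decoder block: pre-norm self-attention and MLP, each wrapped in a residual connection.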
struct DecoderLayer {
    self_attn: Attention,
    mlp: Box<dyn MlpLayer>,
    input_layernorm: RmsNorm,
    post_attention_layernorm: RmsNorm,
}

impl DecoderLayer {
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<RotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        paged_attn: Option<PagedAttention>,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let self_attn = Attention::new(
            rotary_emb,
            cfg,
            mapper.set_device(layer_idx, vb.pp("self_attn"), loading_isq),
            paged_attn,
            comm,
        )?;
        let mlp = Mlp::new(
            mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq),
            cfg.hidden_size,
            cfg.intermediate_size,
            &cfg.quantization_config,
            cfg.hidden_act,
            comm,
        )?;
        let input_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let post_attention_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        Ok(Self {
            self_attn,
            mlp: Box::new(mlp),
            input_layernorm,
            post_attention_layernorm,
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = self.input_layernorm.forward(xs)?;
        let xs = self.self_attn.forward(
            &xs,
            attention_mask,
            seqlen_offsets,
            kv_cache,
            metadata,
            flash_params,
        )?;
        let xs = (xs + residual)?;
        let residual = &xs;
        let xs = self
            .mlp
            .forward(&xs.apply(&self.post_attention_layernorm)?)?;
        residual + xs
    }
}

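/// The full causal LM: token embeddings, the decoder stack, a final RmsNorm, and the
/// (optionally weight-tied) lm_head.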
pub struct Model {
    embed_tokens: candle_nn::Embedding,
    layers: Vec<DecoderLayer>,
    norm: RmsNorm,
    lm_head: Arc<dyn QuantMethod>,
    sliding_window: Option<usize>,
    device: Device,
    pub(crate) cache: EitherCache,
    max_seq_len: usize,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    cfg: ModelConfigMetadata,
}

impl Model {
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        let vb_m = vb.pp("model");
        let vb_lm_head = vb.pp("lm_head");
        Self::new_inner(
            cfg,
            vb_m,
            vb_lm_head,
            is_gptx,
            normal_loading_metadata,
            attention_mechanism,
        )
    }

    pub fn new_inner(
        cfg: &Config,
        vb_m: ShardedVarBuilder,
        vb_lm_head: ShardedVarBuilder,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        if let Some(ref quant_cfg) = &cfg.quantization_config {
            tracing::info!(
                "Using {} quantization: {}.",
                quant_cfg.name(),
                quant_cfg.get_bits_name(&vb_m)
            );
        }
        let mapper = normal_loading_metadata.mapper;

        let embed_tokens = embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
            &cfg.quantization_config,
        )?;

        let head_dim = cfg.head_dim();
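        // Create one RotaryEmbedding per device location so all layers mapped to the same
        // device share a cached instance.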
        let mut ropes = HashMap::new();
        for layer_idx in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(RotaryEmbedding::new(
                    cfg.rope_theta as f32,
                    head_dim,
                    cfg.max_position_embeddings,
                    device,
                    is_gptx,
                    vb_m.dtype(),
                )?),
            );
        }

        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        let vb_l = vb_m.pp("layers");
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        ) {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            let paged_attn = match &attention_mechanism {
                AttentionImplementation::Eager => None,
                AttentionImplementation::PagedAttention => {
                    Some(PagedAttention::new(head_dim, device, None)?)
                }
            };
            let comm = mapper.get_comm_for(layer_idx)?;
            let layer = DecoderLayer::new(
                rotary_emb.clone(),
                cfg,
                vb_l.pp(layer_idx),
                &*mapper,
                layer_idx,
                normal_loading_metadata.loading_isq,
                paged_attn,
                &comm,
            )?;
            layers.push(layer)
        }
        let norm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb_m.pp("norm"), false),
        )?;
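        // With tied word embeddings the lm_head reuses the embedding matrix; otherwise a
        // separate replicated linear layer is loaded from `lm_head`.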
        let lm_head = if !cfg.tie_word_embeddings {
            ReplicatedLayer::new(
                cfg.hidden_size,
                cfg.vocab_size,
                &None,
                false,
                mapper.set_nm_device(vb_lm_head, normal_loading_metadata.loading_isq),
            )?
        } else {
            ReplicatedLayer::from_linear(candle_nn::Linear::new(
                mapper.cast_nm_device(
                    embed_tokens.embeddings(),
                    normal_loading_metadata.loading_isq,
                )?,
                None,
            ))?
        };
        Ok(Self {
            embed_tokens,
            layers,
            norm,
            lm_head,
            sliding_window: cfg.sliding_window,
            device: normal_loading_metadata.real_device,
            cache: EitherCache::Normal(NormalCache::new_sliding(
                cfg.num_hidden_layers,
                cfg.max_position_embeddings,
                cfg.sliding_window,
            )),
            max_seq_len: cfg.max_position_embeddings,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_kv_heads: (cfg.num_key_value_heads / mapper.get_comm_for(0)?.world_size())
                    .max(1),
                num_attn_heads: cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size(),
                sliding_window: cfg.sliding_window,
                k_head_dim: cfg.head_dim(),
                v_head_dim: cfg.head_dim(),
            },
            mapper,
        })
    }

    pub fn get_input_embeddings(&self, input_ids: &Tensor) -> Result<Tensor> {
        self.embed_tokens.forward(input_ids)
    }

    pub fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.forward_embeds(
            input_ids,
            self.embed_tokens.forward(input_ids)?,
            seqlen_offsets,
            context_lens,
            metadata,
            flash_params,
        )
    }

    #[allow(clippy::too_many_arguments)]
    pub fn forward_embeds(
        &self,
        input_ids: &Tensor,
        input_embeds: Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut xs = input_embeds;
        let cache = &mut self.cache.normal().0;
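        // Build the (sliding-window) causal mask. With PagedAttention metadata, the past KV
        // length is taken from `seqlen_offsets`; otherwise it is read from the normal cache.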
        let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
            input_ids,
            metadata
                .as_ref()
                .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
                .unwrap_or(cache as &dyn PastKvLenCache),
            self.sliding_window,
            xs.dtype(),
            self.cfg.num_attn_heads,
        )?;
        // PagedAttention prompt chunking
        let attention_mask = attention_mask.filter(|_| {
            metadata
                .as_ref()
                .map(|(_, meta)| meta.is_first_prompt_chunk)
                .unwrap_or(true)
        });
        for (i, layer) in self.layers.iter().enumerate() {
            xs = self.mapper.map(xs, i)?;
            xs = layer.forward(
                &xs,
                attention_mask
                    .as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                &mut cache[i],
                metadata
                    .as_ref()
                    .map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
                flash_params,
            )?;
        }
        let xs = xs.to_device(&self.device)?;
        let mut xs = xs.apply(&self.norm)?;
        if let Some(t) = self.lm_head.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        extract_logits(&MatMul.qmethod_matmul(&xs, &*self.lm_head)?, context_lens)
    }
}

impl IsqModel for Model {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((&mut self.lm_head, None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            tensors.push((&mut layer.self_attn.q_proj, Some(i)));
            tensors.push((&mut layer.self_attn.k_proj, Some(i)));
            tensors.push((&mut layer.self_attn.v_proj, Some(i)));
            tensors.push((&mut layer.self_attn.o_proj, Some(i)));
            tensors.extend(
                layer
                    .mlp
                    .get_isq_layers()
                    .into_iter()
                    .map(|m| (m, Some(i)))
                    .collect::<Vec<_>>(),
            );
        }
        (tensors, &*self.mapper)
    }

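    // Non-quantized tensors (embeddings and norms) that are kept alongside the ISQ-quantized
    // layers when the model is serialized.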
    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        let uvb_m = uvb.pp("model");
        uvb_m.pp("embed_tokens").add(&self.embed_tokens);
        uvb_m.pp("norm").add(&self.norm);

        for (layer_idx, layer) in self.layers.iter().enumerate() {
            let uvb_l = uvb_m.pp("layers").pp(layer_idx);
            uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
            uvb_l
                .pp("post_attention_layernorm")
                .add(&layer.post_attention_layernorm);
        }

        uvb.to_safetensors()
    }

    fn imatrix_names(&self) -> candle_core::Result<Vec<Option<String>>> {
        // NOTE: dependent on the exact implementation in get_layers!
        let mut names = Vec::new();
        // lm_head
        names.push(None);
        for i in 0..self.layers.len() {
            names.push(Some(format!("blk.{i}.attn_q.weight")));
            names.push(Some(format!("blk.{i}.attn_k.weight")));
            names.push(Some(format!("blk.{i}.attn_v.weight")));
            names.push(Some(format!("blk.{i}.attn_output.weight")));
            names.push(Some(format!("blk.{i}.ffn_gate.weight")));
            names.push(Some(format!("blk.{i}.ffn_up.weight")));
            names.push(Some(format!("blk.{i}.ffn_down.weight")));
        }
        Ok(names)
    }
}

impl NormalModel for Model {
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.forward(
            input_ids,
            seqlen_offsets,
            context_lens,
            metadata,
            flash_params,
        )
    }
    fn xlora_forward(
        &self,
        _input_ids: &Tensor,
        _input_ids_full: &Tensor,
        _seqlen_offsets: &[usize],
        _seqlen_offsets_full: &[usize],
        _no_kv_cache: bool,
        _non_granular_state: &Option<crate::xlora_models::NonGranularState>,
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _flash_params: &FlashParams,
        _flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        unimplemented!()
    }
    fn cache(&self) -> &EitherCache {
        &self.cache
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        &mut self.cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn is_xlora(&self) -> bool {
        false
    }
    fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
}

impl AnyMoeBaseModelMixin for Model {
    fn get_mlps(&self) -> Vec<&dyn MlpLayer> {
        let mut mlps = Vec::new();
        for layer in &self.layers {
            mlps.push(&*layer.mlp);
        }
        mlps
    }
    fn get_mlps_mut(&mut self) -> Vec<&mut Box<dyn MlpLayer>> {
        let mut mlps = Vec::new();
        for layer in &mut self.layers {
            mlps.push(&mut layer.mlp);
        }
        mlps
    }
    fn create_anymoe_layers(
        &mut self,
        additional_vbs: Vec<ShardedVarBuilder>,
        config: AnyMoeConfig,
        (prefix, mlp): (String, String),
        mut layers: Vec<usize>,
        expert_type: AnyMoeExpertType,
        gate_vb: Option<ShardedVarBuilder>,
    ) -> Result<()> {
        let mut experts: Vec<Vec<Box<dyn MlpLayer>>> = Vec::new();
        if layers.is_empty() {
            layers = (0..self.layers.len()).collect::<Vec<_>>();
        }
        for _ in 0..layers.len() {
            experts.push(Vec::new());
        }
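        // Build one expert per additional VarBuilder: either a fully fine-tuned replica of the
        // base MLP, or a LoRA delta (gate/up/down) applied on top of the base MLP weights.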
        for vb in additional_vbs {
            let vb = vb.pp(&prefix);
            for (layer, row) in experts.iter_mut().enumerate() {
                if !layers.contains(&layer) {
                    continue;
                }

                let intermediate_size = self.layers[layer].mlp.get_params()[1];
                let hidden_size = self.layers[layer].mlp.get_params()[0];
                match expert_type {
                    AnyMoeExpertType::FineTuned => {
                        let (dtype, device) = self.layers[layer].mlp.dtype_device();
                        row.push(Box::new(Mlp::replicate(
                            self.layers[layer].mlp.get_params(),
                            vb.pp(layer).pp(&mlp).set_dtype(dtype).set_device(device),
                            self.layers[layer].mlp.hidden_act(),
                            &self.mapper.get_comm_for(layer)?,
                        )?));
                    }
                    AnyMoeExpertType::LoraAdapter {
                        rank,
                        alpha,
                        ref target_modules,
                    } => {
                        let vb_mlp = vb.pp(layer).pp(&mlp);

                        let gate_proj_delta = if target_modules.contains(&"gate_proj".to_string()) {
                            Some(get_delta_from_lora_ab!(
                                vb_mlp,
                                rank,
                                alpha,
                                (hidden_size, intermediate_size),
                                "gate_proj"
                            ))
                        } else {
                            None
                        };
                        let up_proj_delta = if target_modules.contains(&"up_proj".to_string()) {
                            Some(get_delta_from_lora_ab!(
                                vb_mlp,
                                rank,
                                alpha,
                                (hidden_size, intermediate_size),
                                "up_proj"
                            ))
                        } else {
                            None
                        };
                        let down_proj_delta = if target_modules.contains(&"down_proj".to_string()) {
                            Some(get_delta_from_lora_ab!(
                                vb_mlp,
                                rank,
                                alpha,
                                (intermediate_size, hidden_size),
                                "down_proj"
                            ))
                        } else {
                            None
                        };

                        row.push(self.layers[layer].mlp.new_added_delta(vec![
                            gate_proj_delta,
                            up_proj_delta,
                            down_proj_delta,
                        ])?);
                    }
                }
            }
        }
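        // Replace each selected layer's MLP with a MoeMlp that gates between the original MLP
        // and the newly built experts.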
        for (layer, expert) in layers.into_iter().zip(experts) {
            let mut experts_all = vec![self.layers[layer].mlp.clone()];
            experts_all.extend(expert);
            let (dtype, device) = self.layers[layer].mlp.dtype_device();
            self.layers[layer].mlp = Box::new(MoeMlp::new(
                experts_all,
                config.clone(),
                dtype,
                &device,
                layer,
                gate_vb.as_ref(),
            )?);
        }
        Ok(())
    }
    fn amoe_supported(&self) -> bool {
        true
    }
}