mistralrs_core/models/starcoder2.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use candle_core::{DType, Device, Module, Result, Tensor};
use candle_nn::{LayerNorm, Linear};
use mistralrs_quant::{
    ColumnParallelLayer, QuantMethod, QuantMethodConfig, QuantizedConfig, RowParallelLayer,
    ShardedVarBuilder, UnquantLinear,
};
use std::{collections::HashMap, sync::Arc};

use crate::{
    amoe::{AnyMoeBaseModelMixin, AnyMoeTrainableLayer, MlpLayer, MoeMlp},
    attention::SdpaParams,
    device_map::DeviceMapper,
    get_delta_from_lora_ab,
    layers::{embedding, layer_norm, Activation, CausalMasker, MatMul, RotaryEmbedding, Sdpa},
    layers_masker::PastKvLenCache,
    paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
    pipeline::{
        extract_logits,
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel, KvCache, NormalCache, NormalLoadingMetadata, NormalModel,
    },
    serde_default_fn,
    utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
    AnyMoeConfig, AnyMoeExpertType,
};

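// Defines the serde default (`false`) for `tie_word_embeddings` below.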
serde_default_fn!(bool, word_emb_default, false);

#[derive(Debug, Clone, serde::Deserialize, serde::Serialize, Default)]
pub struct Config {
    pub(crate) vocab_size: usize,
    pub(crate) hidden_size: usize,
    pub(crate) intermediate_size: usize,
    pub(crate) num_hidden_layers: usize,
    pub(crate) num_attention_heads: usize,
    pub(crate) num_key_value_heads: usize,
    pub(crate) hidden_act: Activation,
    pub(crate) max_position_embeddings: usize,
    pub(crate) norm_epsilon: f64,
    pub(crate) rope_theta: f64,
    pub(crate) use_bias: bool,
    pub(crate) sliding_window: Option<usize>,
    pub(crate) use_flash_attn: bool,
    pub(crate) quantization_config: Option<QuantizedConfig>,
    #[serde(default = "word_emb_default")]
    #[allow(dead_code)]
    pub(crate) tie_word_embeddings: bool,
}
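
// A sketch of the shape of `config.json` this struct deserializes from. The
// values below are illustrative only, not copied from any particular
// StarCoder2 checkpoint; `quantization_config` and `tie_word_embeddings` may
// be omitted:
//
// {
//   "vocab_size": 49152,
//   "hidden_size": 3072,
//   "intermediate_size": 12288,
//   "num_hidden_layers": 30,
//   "num_attention_heads": 24,
//   "num_key_value_heads": 2,
//   "hidden_act": "gelu_pytorch_tanh",
//   "max_position_embeddings": 16384,
//   "norm_epsilon": 1e-5,
//   "rope_theta": 100000.0,
//   "use_bias": true,
//   "sliding_window": 4096,
//   "use_flash_attn": false
// }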

#[derive(Clone)]
#[allow(clippy::upper_case_acronyms)]
struct MLP {
    c_fc: Arc<dyn QuantMethod>,
    c_proj: Arc<dyn QuantMethod>,
    act: Activation,
    params: Vec<usize>,
}

impl MLP {
    fn new(cfg: &Config, vb: ShardedVarBuilder, comm: &Arc<mistralrs_quant::Comm>) -> Result<Self> {
        let (h_size, i_size) = (cfg.hidden_size, cfg.intermediate_size);
        let c_fc = ColumnParallelLayer::new(
            h_size,
            i_size,
            &cfg.quantization_config,
            cfg.use_bias,
            comm,
            vb.pp("c_fc"),
        )?;
        let c_proj = RowParallelLayer::new(
            i_size,
            h_size,
            &cfg.quantization_config,
            cfg.use_bias,
            comm,
            vb.pp("c_proj"),
        )?;
        Ok(Self {
            c_fc,
            c_proj,
            act: cfg.hidden_act,
            params: vec![h_size, i_size],
        })
    }
}

impl AnyMoeTrainableLayer for MLP {}

impl MlpLayer for MLP {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.c_fc.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut res = MatMul.qmethod_matmul(
            &MatMul.qmethod_matmul(&xs, &*self.c_fc)?.apply(&self.act)?,
            &*self.c_proj,
        )?;
        if self.c_fc.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
    fn get_isq_layers(&mut self) -> Vec<&mut Arc<dyn QuantMethod>> {
        vec![&mut self.c_fc, &mut self.c_proj]
    }
    fn clone(&self) -> Box<dyn MlpLayer> {
        Box::new(Clone::clone(self))
    }
    fn get_params(&self) -> &[usize] {
        &self.params
    }
    fn hidden_act(&self) -> Activation {
        self.act
    }
    // c_fc, c_proj
    fn new_added_delta(&self, deltas: Vec<Option<Tensor>>) -> Result<Box<dyn MlpLayer>> {
        let new_c_fc = if let Some(ref delta) = deltas[0] {
            self.c_fc.add_delta_w(delta)?
        } else {
            self.c_fc.clone()
        };
        let new_c_proj = if let Some(ref delta) = deltas[1] {
            self.c_proj.add_delta_w(delta)?
        } else {
            self.c_proj.clone()
        };

        Ok(Box::new(Self {
            c_fc: new_c_fc,
            c_proj: new_c_proj,
            act: self.act,
            params: self.params.clone(),
        }))
    }

    fn dtype_device(&self) -> (DType, Device) {
        self.c_fc.dtype_and_device()
    }
}
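
// `MLP::forward` computes `c_proj(act(c_fc(xs)))`, round-tripping through the
// quantized activation dtype when one is required. A minimal, self-contained
// sketch of the same dataflow with plain `candle_nn::Linear` layers follows
// (hypothetical sizes; no bias, quantization, or tensor-parallel sharding,
// and gelu standing in for the configured activation):
#[cfg(test)]
mod mlp_dataflow_sketch {
    use candle_core::{Device, Result, Tensor};
    use candle_nn::{Linear, Module};

    #[test]
    fn c_fc_then_act_then_c_proj() -> Result<()> {
        let dev = Device::Cpu;
        let (h, i) = (8, 32); // stand-ins for hidden_size, intermediate_size
        let c_fc = Linear::new(Tensor::randn(0f32, 1., (i, h), &dev)?, None);
        let c_proj = Linear::new(Tensor::randn(0f32, 1., (h, i), &dev)?, None);
        let xs = Tensor::randn(0f32, 1., (2, 4, h), &dev)?; // (batch, seq, hidden)
        // Same composition as `MLP::forward` above.
        let ys = c_proj.forward(&c_fc.forward(&xs)?.gelu()?)?;
        assert_eq!(ys.dims(), &[2, 4, h]);
        Ok(())
    }
}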

struct Attention {
    q_proj: Arc<dyn QuantMethod>,
    k_proj: Arc<dyn QuantMethod>,
    v_proj: Arc<dyn QuantMethod>,
    o_proj: Arc<dyn QuantMethod>,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    rotary_emb: Arc<RotaryEmbedding>,
    paged_attn: Option<PagedAttention>,
    sdpa_params: SdpaParams,
}

impl Attention {
    fn new(
        rotary_emb: Arc<RotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        paged_attn: Option<PagedAttention>,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        let head_dim = hidden_sz / num_heads;
        let b = cfg.use_bias;
        let q_proj = ColumnParallelLayer::new(
            hidden_sz,
            num_heads * head_dim,
            &cfg.quantization_config,
            b,
            comm,
            vb.pp("q_proj"),
        )?;
        let kv_shard = mistralrs_quant::compute_kv_shard(
            cfg.num_key_value_heads,
            cfg.hidden_size / cfg.num_attention_heads,
            comm,
        );
        let k_proj = ColumnParallelLayer::new_with_shard(
            hidden_sz,
            num_kv_heads * head_dim,
            &cfg.quantization_config,
            b,
            comm,
            kv_shard,
            vb.pp("k_proj"),
        )?;
        let v_proj = ColumnParallelLayer::new_with_shard(
            hidden_sz,
            num_kv_heads * head_dim,
            &cfg.quantization_config,
            b,
            comm,
            kv_shard,
            vb.pp("v_proj"),
        )?;
        let o_proj = RowParallelLayer::new(
            num_heads * head_dim,
            hidden_sz,
            &cfg.quantization_config,
            b,
            comm,
            vb.pp("o_proj"),
        )?;
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            num_heads: num_heads / comm.world_size(),
            num_kv_heads: (num_kv_heads / comm.world_size()).max(1),
            head_dim,
            rotary_emb,
            paged_attn,
            sdpa_params: SdpaParams {
                n_kv_groups: mistralrs_quant::compute_n_kv_groups(
                    cfg.num_key_value_heads,
                    cfg.num_attention_heads,
                    comm,
                ),
                use_flash_attn: cfg.use_flash_attn,
                softcap: None,
                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                sliding_window: cfg.sliding_window,
            },
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.q_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut q = MatMul.qmethod_matmul(&xs, &*self.q_proj)?;
        let mut k = MatMul.qmethod_matmul(&xs, &*self.k_proj)?;
        let mut v = MatMul.qmethod_matmul(&xs, &*self.v_proj)?;
        if self.q_proj.quantized_act_type().is_some() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

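        // Prefill (q_len > 1): reshape to (b, seq, heads, head_dim), then swap
        // the seq and head axes. Decode (q_len == 1): with a single query token
        // the layouts coincide, so reshape straight to (b, heads, 1, head_dim)
        // and skip the transpose (see the sketch after this impl).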
        let (q, k, v) = if q_len != 1 {
            let q = q
                .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, k) = self.rotary_emb.forward(&q, &k, seqlen_offsets)?;

        let mut attn_output = match &self.paged_attn {
            Some(paged_attn) => match metadata {
                Some(((key_cache, value_cache), input_metadata)) => paged_attn.forward(
                    &q,
                    &k,
                    &v,
                    attention_mask,
                    Some(key_cache),
                    Some(value_cache),
                    input_metadata,
                    &self.sdpa_params,
                    Some(flash_params),
                )?,
                None => {
                    // Without metadata we are most likely generating an imatrix, so we don't want to populate the KV cache.
                    // Generate dummy metadata under the assumption that we are only processing prompts, not generating text.
                    let input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
                    // Sanity check.
                    assert!(attention_mask.is_some());
                    paged_attn.forward(
                        &q,
                        &k,
                        &v,
                        attention_mask,
                        None,
                        None,
                        &input_metadata,
                        &self.sdpa_params,
                        Some(flash_params),
                    )?
                }
            },
            None => {
                let (k, v) = kv_cache.append(&k, &v)?;

                Sdpa.run_attention(
                    &q,
                    &k,
                    &v,
                    attention_mask,
                    Some(flash_params),
                    &self.sdpa_params,
                )?
            }
        };

        if let Some(t) = self.q_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        attn_output = if attention_mask.is_some() {
            attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?
        } else {
            attn_output.reshape((b_sz, q_len, ()))?
        };
        let mut res = MatMul.qmethod_matmul(&attn_output, &*self.o_proj)?;
        if self.q_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}
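
// A small, self-contained check of the decode fast path above: for q_len == 1,
// reshaping straight to (b, heads, 1, head_dim) is exactly equivalent to the
// prefill route of reshape-to-(b, 1, heads, head_dim) plus transpose(1, 2).
// Sizes are hypothetical.
#[cfg(test)]
mod decode_reshape_sketch {
    use candle_core::{Device, Result, Tensor};

    #[test]
    fn single_token_reshape_matches_transpose() -> Result<()> {
        let dev = Device::Cpu;
        let (b, heads, head_dim) = (2, 4, 8);
        let q = Tensor::randn(0f32, 1., (b, 1, heads * head_dim), &dev)?;
        let fast = q.reshape((b, heads, 1, head_dim))?;
        let slow = q.reshape((b, 1, heads, head_dim))?.transpose(1, 2)?;
        let diff = (fast - slow)?.abs()?.sum_all()?.to_scalar::<f32>()?;
        assert_eq!(diff, 0.0);
        Ok(())
    }
}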

struct DecoderLayer {
    self_attn: Attention,
    mlp: Box<dyn MlpLayer>,
    input_layernorm: LayerNorm,
    post_attention_layernorm: LayerNorm,
}

impl DecoderLayer {
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<RotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        paged_attn: Option<PagedAttention>,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let self_attn = Attention::new(
            rotary_emb,
            cfg,
            mapper.set_device(layer_idx, vb.pp("self_attn"), loading_isq),
            paged_attn,
            comm,
        )?;
        let mlp = MLP::new(
            cfg,
            mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq),
            comm,
        )?;
        let input_layernorm = layer_norm(
            cfg.hidden_size,
            cfg.norm_epsilon,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let post_attention_layernorm = layer_norm(
            cfg.hidden_size,
            cfg.norm_epsilon,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        Ok(Self {
            self_attn,
            mlp: Box::new(mlp),
            input_layernorm,
            post_attention_layernorm,
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
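        // Pre-norm transformer block: each sublayer reads LayerNorm-ed
        // activations and adds its output back onto the un-normed residual:
        //   xs = xs + attn(ln_1(xs));  xs = xs + mlp(ln_2(xs))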
        let residual = xs;
        let xs = self.input_layernorm.forward(xs)?;
        let xs = self.self_attn.forward(
            &xs,
            attention_mask,
            seqlen_offsets,
            kv_cache,
            metadata,
            flash_params,
        )?;
        let xs = (xs + residual)?;
        let residual = &xs;
        let xs = self
            .mlp
            .forward(&xs.apply(&self.post_attention_layernorm)?)?;
        residual + xs
    }
}

pub struct Model {
    embed_tokens: candle_nn::Embedding,
    layers: Vec<DecoderLayer>,
    norm: LayerNorm,
    lm_head: Arc<dyn QuantMethod>,
    sliding_window: Option<usize>,
    device: Device,
    cache: EitherCache,
    max_seq_len: usize,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    cfg: ModelConfigMetadata,
}

impl Model {
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        if let Some(ref quant_cfg) = &cfg.quantization_config {
            tracing::info!(
                "Using {} quantization: {}.",
                quant_cfg.quant_method.to_string(),
                quant_cfg.get_bits_name(&vb)
            );
        }
        let mapper = normal_loading_metadata.mapper;
        let vb_m = vb.pp("model");

        let embed_tokens = embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
        )?;
        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        let vb_l = vb_m.pp("layers");
        let head_dim = cfg.hidden_size / cfg.num_attention_heads;

        let mut ropes = HashMap::new();
        for layer_idx in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(RotaryEmbedding::new(
                    cfg.rope_theta as f32,
                    head_dim,
                    cfg.max_position_embeddings,
                    device,
                    is_gptx,
                    vb_m.dtype(),
                )?),
            );
        }

        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        ) {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            let paged_attn = match &attention_mechanism {
                AttentionImplementation::Eager => None,
                AttentionImplementation::PagedAttention => {
                    Some(PagedAttention::new(head_dim, device, None)?)
                }
            };
            let comm = mapper.get_comm_for(layer_idx)?;
            layers.push(DecoderLayer::new(
                rotary_emb.clone(),
                cfg,
                vb_l.pp(layer_idx),
                &*mapper,
                layer_idx,
                normal_loading_metadata.loading_isq,
                paged_attn,
                &comm,
            )?)
        }
        let norm = layer_norm(
            cfg.hidden_size,
            cfg.norm_epsilon,
            mapper.set_nm_device(vb_m.pp("norm"), false),
        )?;
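        // This implementation always ties the LM head to the token embedding
        // matrix (`tie_word_embeddings` above is intentionally unused): the
        // head is an unquantized linear layer over `embed_tokens`' weights.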
        let lm_head = mapper.cast_nm_device(
            embed_tokens.embeddings(),
            normal_loading_metadata.loading_isq,
        )?;
        Ok(Self {
            embed_tokens,
            layers,
            norm,
            lm_head: Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
                Linear::new(lm_head, None),
            ))?),
            sliding_window: cfg.sliding_window,
            device: normal_loading_metadata.real_device,
            cache: EitherCache::Normal(NormalCache::new_sliding(
                cfg.num_hidden_layers,
                cfg.max_position_embeddings,
                cfg.sliding_window,
            )),
            max_seq_len: cfg.max_position_embeddings,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_attn_heads: cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size(),
                num_kv_heads: (cfg.num_key_value_heads / mapper.get_comm_for(0)?.world_size())
                    .max(1),
                sliding_window: cfg.sliding_window,
                k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
                v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            },
            mapper,
        })
    }

    pub fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut xs = self.embed_tokens.forward(input_ids)?;

        let cache = &mut self.cache.normal().0;
        let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
            input_ids,
            metadata
                .as_ref()
                .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
                .unwrap_or(cache as &dyn PastKvLenCache),
            self.sliding_window,
            xs.dtype(),
            self.cfg.num_attn_heads,
        )?;
        let attention_mask = attention_mask.filter(|_| {
            metadata
                .as_ref()
                .map(|(_, meta)| meta.is_first_prompt_chunk)
                .unwrap_or(true)
        });

        for (i, layer) in self.layers.iter().enumerate() {
            xs = self.mapper.map(xs, i)?;
            xs = layer.forward(
                &xs,
                attention_mask
                    .as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                &mut cache[i],
                metadata
                    .as_ref()
                    .map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
                flash_params,
            )?
        }
        let mut xs = xs.to_device(&self.device)?.apply(&self.norm)?;
        if let Some(t) = self.lm_head.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        extract_logits(&MatMul.qmethod_matmul(&xs, &*self.lm_head)?, context_lens)
    }
}
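
// A minimal sketch (not `CausalMasker` itself) of the attention pattern a
// sliding-window causal mask encodes, using the common convention that token
// i may attend to token j iff j <= i and i - j < window:
#[cfg(test)]
mod sliding_window_mask_sketch {
    #[test]
    fn window_limits_lookback() {
        let window = 3usize;
        let allowed = |i: usize, j: usize| j <= i && i - j < window;
        // Token 5 sees tokens 3..=5 but nothing earlier.
        assert!(allowed(5, 3) && allowed(5, 5));
        assert!(!allowed(5, 2));
        // Causality: no token attends to a later position.
        assert!(!allowed(2, 3));
    }
}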
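// ISQ (in-situ quantization) support: `get_layers` exposes every quantizable
// linear (the lm_head plus each layer's attention and MLP projections), while
// `residual_tensors` lists what stays unquantized (embeddings and layer norms).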
impl IsqModel for Model {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((&mut self.lm_head, None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            tensors.push((&mut layer.self_attn.q_proj, Some(i)));
            tensors.push((&mut layer.self_attn.k_proj, Some(i)));
            tensors.push((&mut layer.self_attn.v_proj, Some(i)));
            tensors.push((&mut layer.self_attn.o_proj, Some(i)));
            tensors.extend(
                layer
                    .mlp
                    .get_isq_layers()
                    .into_iter()
                    .map(|m| (m, Some(i)))
                    .collect::<Vec<_>>(),
            );
        }
        (tensors, &*self.mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        let uvb_m = uvb.pp("model");
        uvb_m.pp("embed_tokens").add(&self.embed_tokens);
        uvb_m.pp("norm").add(&self.norm);

        for (layer_idx, layer) in self.layers.iter().enumerate() {
            let uvb_l = uvb_m.pp("layers").pp(layer_idx);
            uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
            uvb_l
                .pp("post_attention_layernorm")
                .add(&layer.post_attention_layernorm);
        }

        uvb.to_safetensors()
    }
}

impl NormalModel for Model {
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.forward(
            input_ids,
            seqlen_offsets,
            context_lens,
            metadata,
            flash_params,
        )
    }
    fn xlora_forward(
        &self,
        _input_ids: &Tensor,
        _input_ids_full: &Tensor,
        _seqlen_offsets: &[usize],
        _seqlen_offsets_full: &[usize],
        _no_kv_cache: bool,
        _non_granular_state: &Option<crate::xlora_models::NonGranularState>,
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _flash_params: &FlashParams,
        _flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        unimplemented!()
    }
    fn cache(&self) -> &EitherCache {
        &self.cache
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        &mut self.cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn is_xlora(&self) -> bool {
        false
    }
    fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
}

impl AnyMoeBaseModelMixin for Model {
    fn get_mlps(&self) -> Vec<&dyn MlpLayer> {
        let mut mlps = Vec::new();
        for layer in &self.layers {
            mlps.push(&*layer.mlp);
        }
        mlps
    }
    fn get_mlps_mut(&mut self) -> Vec<&mut Box<dyn MlpLayer>> {
        let mut mlps = Vec::new();
        for layer in &mut self.layers {
            mlps.push(&mut layer.mlp);
        }
        mlps
    }
    fn create_anymoe_layers(
        &mut self,
        additional_vbs: Vec<ShardedVarBuilder>,
        config: AnyMoeConfig,
        (prefix, mlp): (String, String),
        mut layers: Vec<usize>,
        expert_type: AnyMoeExpertType,
        gate_vb: Option<ShardedVarBuilder>,
    ) -> Result<()> {
        let mut experts: Vec<Vec<Box<dyn MlpLayer>>> = Vec::new();
        if layers.is_empty() {
            layers = (0..self.layers.len()).collect::<Vec<_>>();
        }
        for _ in 0..layers.len() {
            experts.push(Vec::new());
        }
        for vb in additional_vbs {
            let vb = vb.pp(&prefix);
            for (layer, row) in experts.iter_mut().enumerate() {
                if !layers.contains(&layer) {
                    continue;
                }

                let intermediate_size = self.layers[layer].mlp.get_params()[1];
                let hidden_size = self.layers[layer].mlp.get_params()[0];
                match expert_type {
                    AnyMoeExpertType::FineTuned => {
                        let (dtype, device) = self.layers[layer].mlp.dtype_device();
                        row.push(Box::new(MLP::new(
                            &Config {
                                intermediate_size,
                                hidden_size,
                                ..Default::default()
                            },
                            vb.pp(layer).pp(&mlp).set_dtype(dtype).set_device(device),
                            &self.mapper.get_comm_for(layer)?,
                        )?));
                    }
                    AnyMoeExpertType::LoraAdapter {
                        rank,
                        alpha,
                        ref target_modules,
                    } => {
                        let vb_mlp = vb.pp(layer).pp(&mlp);

                        let c_fc_delta = if target_modules.contains(&"c_fc".to_string()) {
                            Some(get_delta_from_lora_ab!(
                                vb_mlp,
                                rank,
                                alpha,
                                (hidden_size, intermediate_size),
                                "c_fc"
                            ))
                        } else {
                            None
                        };
                        let c_proj_delta = if target_modules.contains(&"c_proj".to_string()) {
                            Some(get_delta_from_lora_ab!(
                                vb_mlp,
                                rank,
                                alpha,
                                (intermediate_size, hidden_size),
                                "c_proj"
                            ))
                        } else {
                            None
                        };

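                        // Each delta is a dense LoRA update for its projection
                        // (conventionally alpha/rank * B*A); `new_added_delta`
                        // folds the deltas into a copy of the base MLP.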
                        row.push(
                            self.layers[layer]
                                .mlp
                                .new_added_delta(vec![c_fc_delta, c_proj_delta])?,
                        );
                    }
                }
            }
        }
        for (layer, expert) in layers.into_iter().zip(experts) {
            let mut experts_all = vec![self.layers[layer].mlp.clone()];
            experts_all.extend(expert);
            let (dtype, device) = self.layers[layer].mlp.dtype_device();
            self.layers[layer].mlp = Box::new(MoeMlp::new(
                experts_all,
                config.clone(),
                dtype,
                &device,
                layer,
                gate_vb.as_ref(),
            )?);
        }
        Ok(())
    }
    fn amoe_supported(&self) -> bool {
        true
    }
}