mistralrs_core/models/
phi2.rs

1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use std::{collections::HashMap, sync::Arc};
4
5/// Phi model.
6/// https://huggingface.co/microsoft/phi-2
7/// This corresponds to the model update made with the following commit:
8/// https://huggingface.co/microsoft/phi-2/commit/cb2f4533604d8b67de604e7df03bfe6f3ca22869
9use candle_core::{DType, Device, Result, Tensor};
10use candle_nn::{Embedding, LayerNorm};
11use mistralrs_quant::{
12    ColumnParallelLayer, QuantMethod, QuantizedConfig, ReplicatedLayer, RowParallelLayer,
13    ShardedVarBuilder,
14};
15use serde::{Deserialize, Serialize};
16
17use crate::{
18    amoe::{
19        AnyMoeBaseModelMixin, AnyMoeConfig, AnyMoeExpertType, AnyMoeTrainableLayer, MlpLayer,
20        MoeMlp,
21    },
22    attention::SdpaParams,
23    device_map::DeviceMapper,
24    get_delta_from_lora_ab,
25    layers::{embedding, layer_norm, Activation, CausalMasker, MatMul, RotaryEmbedding, Sdpa},
26    layers_masker::PastKvLenCache,
27    paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
28    pipeline::{
29        extract_logits,
30        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
31        EitherCache, IsqModel, KvCache, NormalCache, NormalLoadingMetadata, NormalModel,
32    },
33    serde_default_fn,
34    utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
35};
36
37serde_default_fn!(bool, word_emb_default, false);
38
39// https://huggingface.co/microsoft/phi-2/blob/main/configuration_phi.py
// https://huggingface.co/microsoft/phi-2/blob/main/configuration_phi.py
/// Phi-2 model configuration, deserialized from the HF `config.json`.
#[derive(Debug, Clone, Deserialize, Default, Serialize)]
pub struct Config {
    pub(crate) vocab_size: usize,
    pub(crate) hidden_size: usize,
    /// MLP inner dimension (`fc1` output / `fc2` input).
    pub(crate) intermediate_size: usize,
    pub(crate) num_hidden_layers: usize,
    pub(crate) num_attention_heads: usize,
    /// Optional KV head count; falls back to `num_attention_heads` via
    /// [`Config::num_key_value_heads`].
    pub(crate) num_key_value_heads: Option<usize>,
    pub(crate) hidden_act: Activation,
    pub(crate) max_position_embeddings: usize,
    pub(crate) layer_norm_eps: f64,
    pub(crate) rope_theta: f32,
    /// Fraction of each head dim that rotary embeddings are applied to
    /// (Phi-2 uses partial RoPE; see `RotaryEmbedding::new_partial` usage).
    pub(crate) partial_rotary_factor: f64,
    /// When true, Q and K get a per-head LayerNorm (see `Attention::new`).
    pub(crate) qk_layernorm: bool,
    pub(crate) use_flash_attn: bool,
    pub(crate) quantization_config: Option<QuantizedConfig>,
    #[serde(default = "word_emb_default")]
    pub(crate) tie_word_embeddings: bool,
}
59
60impl Config {
61    pub(crate) fn num_key_value_heads(&self) -> usize {
62        self.num_key_value_heads.unwrap_or(self.num_attention_heads)
63    }
64
65    pub(crate) fn head_dim(&self) -> usize {
66        self.hidden_size / self.num_attention_heads
67    }
68}
69
/// Phi-2 feed-forward block: `fc2(act(fc1(x)))`.
#[derive(Clone)]
#[allow(clippy::upper_case_acronyms)]
struct MLP {
    fc1: Arc<dyn QuantMethod>,
    fc2: Arc<dyn QuantMethod>,
    act: Activation,
    // [hidden_size, intermediate_size]; exposed through `MlpLayer::get_params`.
    params: Vec<usize>,
}
78
79impl MLP {
80    fn new(cfg: &Config, vb: ShardedVarBuilder, comm: &Arc<mistralrs_quant::Comm>) -> Result<Self> {
81        let fc1 = ColumnParallelLayer::new(
82            cfg.hidden_size,
83            cfg.intermediate_size,
84            &cfg.quantization_config,
85            true,
86            comm,
87            vb.pp("fc1"),
88        )?;
89        let fc2 = RowParallelLayer::new(
90            cfg.intermediate_size,
91            cfg.hidden_size,
92            &cfg.quantization_config,
93            true,
94            comm,
95            vb.pp("fc2"),
96        )?;
97        Ok(Self {
98            fc1,
99            fc2,
100            // This does not match the mixformers implementation where Gelu is used rather than
101            // GeluNew.
102            act: cfg.hidden_act,
103            params: vec![cfg.hidden_size, cfg.intermediate_size],
104        })
105    }
106}
107
108impl AnyMoeTrainableLayer for MLP {}
109
110impl MlpLayer for MLP {
111    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
112        let original_dtype = xs.dtype();
113        let mut xs = xs.clone();
114        if let Some(t) = self.fc1.quantized_act_type() {
115            xs = xs.to_dtype(t)?;
116        }
117        let mut res = MatMul.qmethod_matmul(
118            &MatMul.qmethod_matmul(&xs, &*self.fc1)?.apply(&self.act)?,
119            &*self.fc2,
120        )?;
121        if self.fc1.quantized_act_type().is_some() {
122            res = res.to_dtype(original_dtype)?;
123        }
124        Ok(res)
125    }
126    fn get_isq_layers(&mut self) -> Vec<&mut Arc<dyn QuantMethod>> {
127        vec![&mut self.fc1, &mut self.fc2]
128    }
129    fn clone(&self) -> Box<dyn MlpLayer> {
130        Box::new(Clone::clone(self))
131    }
132    fn get_params(&self) -> &[usize] {
133        &self.params
134    }
135    fn hidden_act(&self) -> Activation {
136        self.act
137    }
138    // fc1, fc2
139    fn new_added_delta(&self, deltas: Vec<Option<Tensor>>) -> Result<Box<dyn MlpLayer>> {
140        let new_fc1 = if let Some(ref delta) = deltas[0] {
141            self.fc1.add_delta_w(delta)?
142        } else {
143            self.fc1.clone()
144        };
145        let new_fc2 = if let Some(ref delta) = deltas[1] {
146            self.fc2.add_delta_w(delta)?
147        } else {
148            self.fc2.clone()
149        };
150
151        Ok(Box::new(Self {
152            fc1: new_fc1,
153            fc2: new_fc2,
154            act: self.act,
155            params: self.params.clone(),
156        }))
157    }
158
159    fn dtype_device(&self) -> (DType, Device) {
160        self.fc1.dtype_and_device()
161    }
162}
163
/// Multi-head self-attention for Phi-2 with partial rotary embeddings and
/// optional per-head QK layer norms.
struct Attention {
    q_proj: Arc<dyn QuantMethod>,
    k_proj: Arc<dyn QuantMethod>,
    v_proj: Arc<dyn QuantMethod>,
    dense: Arc<dyn QuantMethod>,
    // Present only when `Config::qk_layernorm` is true.
    q_layernorm: Option<LayerNorm>,
    k_layernorm: Option<LayerNorm>,
    rotary_emb: Arc<RotaryEmbedding>,
    // Head counts are per tensor-parallel rank (already divided by world size
    // in `Attention::new`).
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    // `Some` when running with the PagedAttention backend.
    paged_attn: Option<PagedAttention>,
    sdpa_params: SdpaParams,
}
178
179impl Attention {
180    fn new(
181        cfg: &Config,
182        vb: ShardedVarBuilder,
183        rope: Arc<RotaryEmbedding>,
184        paged_attn: Option<PagedAttention>,
185        comm: &Arc<mistralrs_quant::Comm>,
186    ) -> Result<Self> {
187        let num_heads = cfg.num_attention_heads;
188        let num_kv_heads = cfg.num_key_value_heads();
189        let head_dim = cfg.head_dim();
190        let q_proj = ColumnParallelLayer::new(
191            cfg.hidden_size,
192            num_heads * head_dim,
193            &cfg.quantization_config,
194            true,
195            comm,
196            vb.pp("q_proj"),
197        )?;
198        let kv_shard = mistralrs_quant::compute_kv_shard(
199            cfg.num_attention_heads,
200            cfg.hidden_size / cfg.num_attention_heads,
201            comm,
202        );
203        let k_proj = ColumnParallelLayer::new_with_shard(
204            cfg.hidden_size,
205            num_kv_heads * head_dim,
206            &cfg.quantization_config,
207            true,
208            comm,
209            kv_shard,
210            vb.pp("k_proj"),
211        )?;
212        let v_proj = ColumnParallelLayer::new_with_shard(
213            cfg.hidden_size,
214            num_kv_heads * head_dim,
215            &cfg.quantization_config,
216            true,
217            comm,
218            kv_shard,
219            vb.pp("v_proj"),
220        )?;
221        let dense = RowParallelLayer::new(
222            num_heads * head_dim,
223            cfg.hidden_size,
224            &cfg.quantization_config,
225            true,
226            comm,
227            vb.pp("dense"),
228        )?;
229        let (q_layernorm, k_layernorm) = if cfg.qk_layernorm {
230            let q_layernorm = layer_norm(head_dim, cfg.layer_norm_eps, vb.pp("q_layernorm"))?;
231            let k_layernorm = layer_norm(head_dim, cfg.layer_norm_eps, vb.pp("k_layernorm"))?;
232            (Some(q_layernorm), Some(k_layernorm))
233        } else {
234            (None, None)
235        };
236        Ok(Self {
237            q_proj,
238            k_proj,
239            v_proj,
240            dense,
241            q_layernorm,
242            k_layernorm,
243            rotary_emb: rope,
244            num_heads: num_heads / comm.world_size(),
245            num_kv_heads: (num_kv_heads / comm.world_size()).max(1),
246            head_dim,
247            paged_attn,
248            sdpa_params: SdpaParams {
249                n_kv_groups: mistralrs_quant::compute_n_kv_groups(
250                    cfg.num_attention_heads,
251                    cfg.num_attention_heads,
252                    comm,
253                ),
254                use_flash_attn: cfg.use_flash_attn,
255                softcap: None,
256                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
257                sliding_window: None,
258            },
259        })
260    }
261
262    #[allow(clippy::too_many_arguments)]
263    fn forward(
264        &self,
265        xs: &Tensor,
266        mask: Option<&Tensor>,
267        seqlen_offsets: &[usize],
268        kv_cache: &mut KvCache,
269        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
270        flash_params: &FlashParams,
271    ) -> Result<Tensor> {
272        let (b_size, seq_len, _n_embd) = xs.dims3()?;
273
274        let original_dtype = xs.dtype();
275        let mut xs = xs.clone();
276        if let Some(t) = self.q_proj.quantized_act_type() {
277            xs = xs.to_dtype(t)?;
278        }
279        let mut q = MatMul.qmethod_matmul(&xs, &*self.q_proj)?;
280        let mut k = MatMul.qmethod_matmul(&xs, &*self.k_proj)?;
281        let mut v = MatMul.qmethod_matmul(&xs, &*self.v_proj)?;
282        if self.q_proj.quantized_act_type().is_some() {
283            q = q.to_dtype(original_dtype)?;
284            k = k.to_dtype(original_dtype)?;
285            v = v.to_dtype(original_dtype)?;
286        }
287
288        let q = match &self.q_layernorm {
289            None => q,
290            Some(ln) => q.apply(ln)?,
291        };
292        let k = match &self.k_layernorm {
293            None => k,
294            Some(ln) => k.apply(ln)?,
295        };
296
297        let (q, k, v) = if seq_len != 1 {
298            let q = q
299                .reshape((b_size, seq_len, self.num_heads, self.head_dim))?
300                .transpose(1, 2)?;
301            let k = k
302                .reshape((b_size, seq_len, self.num_kv_heads, self.head_dim))?
303                .transpose(1, 2)?;
304            let v = v
305                .reshape((b_size, seq_len, self.num_kv_heads, self.head_dim))?
306                .transpose(1, 2)?;
307            (q, k, v)
308        } else {
309            let q = q.reshape((b_size, self.num_heads, seq_len, self.head_dim))?;
310            let k = k.reshape((b_size, self.num_kv_heads, seq_len, self.head_dim))?;
311            let v = v.reshape((b_size, self.num_kv_heads, seq_len, self.head_dim))?;
312            (q, k, v)
313        };
314
315        let (q, k) = self.rotary_emb.forward(&q, &k, seqlen_offsets)?;
316
317        let mut attn_output = match &self.paged_attn {
318            Some(paged_attn) => match metadata {
319                Some(((key_cache, value_cache), input_metadata)) => paged_attn.forward(
320                    &q,
321                    &k,
322                    &v,
323                    mask,
324                    Some(key_cache),
325                    Some(value_cache),
326                    input_metadata,
327                    &self.sdpa_params,
328                    Some(flash_params),
329                )?,
330                None => {
331                    // If we don't have metadata, we are most likely generating an imatrix so we don't want to populate that.
332                    // Generating the dummy metadata with the assumption that we are not generating text (only processing prompts).
333                    let input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
334                    // Sanity check.
335                    assert!(mask.is_some());
336                    paged_attn.forward(
337                        &q,
338                        &k,
339                        &v,
340                        mask,
341                        None,
342                        None,
343                        &input_metadata,
344                        &self.sdpa_params,
345                        Some(flash_params),
346                    )?
347                }
348            },
349            None => {
350                let (k, v) = kv_cache.append(&k, &v)?;
351
352                Sdpa.run_attention(&q, &k, &v, mask, Some(flash_params), &self.sdpa_params)?
353            }
354        };
355
356        if let Some(t) = self.q_proj.quantized_act_type() {
357            attn_output = attn_output.to_dtype(t)?;
358        }
359        attn_output = if mask.is_some() {
360            attn_output
361                .transpose(1, 2)?
362                .reshape((b_size, seq_len, ()))?
363        } else {
364            attn_output.reshape((b_size, seq_len, ()))?
365        };
366        let mut res = MatMul.qmethod_matmul(&attn_output, &*self.dense)?;
367        if self.q_proj.quantized_act_type().is_some() {
368            res = res.to_dtype(original_dtype)?;
369        }
370        Ok(res)
371    }
372}
373
/// One Phi-2 decoder block. Phi-2 uses a parallel residual: attention and MLP
/// both consume the output of the single `input_layernorm` (see `forward`).
struct DecoderLayer {
    self_attn: Attention,
    mlp: Box<dyn MlpLayer>,
    input_layernorm: LayerNorm,
}
379
380impl DecoderLayer {
381    #[allow(clippy::too_many_arguments)]
382    fn new(
383        cfg: &Config,
384        vb: ShardedVarBuilder,
385        mapper: &dyn DeviceMapper,
386        layer_idx: usize,
387        loading_isq: bool,
388        rotary_emb: Arc<RotaryEmbedding>,
389        paged_attn: Option<PagedAttention>,
390        comm: &Arc<mistralrs_quant::Comm>,
391    ) -> Result<Self> {
392        let self_attn = Attention::new(
393            cfg,
394            mapper.set_device(layer_idx, vb.pp("self_attn"), loading_isq),
395            rotary_emb,
396            paged_attn,
397            comm,
398        )?;
399        let mlp = MLP::new(
400            cfg,
401            mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq),
402            comm,
403        )?;
404        let input_layernorm = layer_norm(
405            cfg.hidden_size,
406            cfg.layer_norm_eps,
407            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
408        )?;
409        Ok(Self {
410            self_attn,
411            mlp: Box::new(mlp),
412            input_layernorm,
413        })
414    }
415
416    #[allow(clippy::too_many_arguments)]
417    fn forward(
418        &self,
419        xs: &Tensor,
420        mask: Option<&Tensor>,
421        seqlen_offsets: &[usize],
422        kv_cache: &mut KvCache,
423        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
424        flash_params: &FlashParams,
425    ) -> Result<Tensor> {
426        let residual = xs;
427        let xs = xs.apply(&self.input_layernorm)?;
428        let attn_outputs =
429            self.self_attn
430                .forward(&xs, mask, seqlen_offsets, kv_cache, metadata, flash_params)?;
431        let feed_forward_hidden_states = self.mlp.forward(&xs)?;
432        attn_outputs + feed_forward_hidden_states + residual
433    }
434}
435
/// Phi-2 causal language model: token embeddings, decoder layers, final layer
/// norm and LM head.
pub struct Model {
    embed_tokens: Embedding,
    layers: Vec<DecoderLayer>,
    final_layernorm: LayerNorm,
    lm_head: Arc<dyn QuantMethod>,
    // KV cache; constructed as the `Normal` variant in `Model::new`.
    cache: EitherCache,
    // Device the hidden states are moved to before the LM head.
    device: Device,
    max_seq_len: usize,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    cfg: ModelConfigMetadata,
}
447
impl Model {
    /// Loads the full Phi-2 model from `vb`, distributing the decoder layers
    /// across devices according to the mapper in `normal_loading_metadata`.
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        if let Some(ref quant_cfg) = &cfg.quantization_config {
            tracing::info!(
                "Using {} quantization: {}.",
                quant_cfg.quant_method.to_string(),
                quant_cfg.get_bits_name(&vb)
            );
        }
        let mapper = normal_loading_metadata.mapper;
        let vb_m = vb.pp("model");

        let embed_tokens = embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
        )?;
        let final_layernorm = layer_norm(
            cfg.hidden_size,
            cfg.layer_norm_eps,
            mapper.set_nm_device(vb_m.pp("final_layernorm"), false),
        )?;
        // Build one RoPE per device location so layers on the same device
        // share the cached cos/sin tables.
        let mut ropes = HashMap::new();
        for layer_idx in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            // Alternative rope scalings are not supported
            ropes.insert(
                device.location(),
                Arc::new(RotaryEmbedding::new_partial(
                    cfg.rope_theta,
                    // Phi-2 applies RoPE to only a fraction of the head dim.
                    (cfg.partial_rotary_factor * cfg.head_dim() as f64) as usize,
                    cfg.max_position_embeddings,
                    device,
                    is_gptx,
                    vb.dtype(),
                )?),
            );
        }
        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        let vb_m = vb_m.pp("layers");
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        ) {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            let paged_attn = match &attention_mechanism {
                AttentionImplementation::Eager => None,
                AttentionImplementation::PagedAttention => {
                    Some(PagedAttention::new(cfg.head_dim(), device, None)?)
                }
            };
            let comm = mapper.get_comm_for(layer_idx)?;
            let layer = DecoderLayer::new(
                cfg,
                vb_m.pp(layer_idx),
                &*mapper,
                layer_idx,
                normal_loading_metadata.loading_isq,
                rotary_emb,
                paged_attn,
                &comm,
            )?;
            layers.push(layer)
        }
        // Phi-2 does not tie the LM head to the embeddings, so
        // `tie_word_embeddings` is expected to always be false here.
        let lm_head = if !cfg.tie_word_embeddings {
            ReplicatedLayer::new(
                cfg.hidden_size,
                cfg.vocab_size,
                &None,
                false,
                mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
            )?
        } else {
            unreachable!()
        };
        Ok(Self {
            embed_tokens,
            layers,
            final_layernorm,
            lm_head,
            cache: EitherCache::Normal(NormalCache::new(
                cfg.num_hidden_layers,
                cfg.max_position_embeddings,
            )),
            device: normal_loading_metadata.real_device,
            max_seq_len: cfg.max_position_embeddings,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                // Per-rank head counts for tensor parallelism.
                num_kv_heads: (cfg.num_key_value_heads() / mapper.get_comm_for(0)?.world_size())
                    .max(1),
                num_attn_heads: cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size(),
                sliding_window: None,
                k_head_dim: cfg.head_dim(),
                v_head_dim: cfg.head_dim(),
            },
            mapper,
        })
    }

    /// Full forward pass: embed, run every decoder layer (moving activations
    /// between devices via the mapper), final layer norm, LM head, and logit
    /// extraction for the requested context lengths.
    pub fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut xs = input_ids.apply(&self.embed_tokens)?;
        let cache = &mut self.cache.normal().0;
        // With PagedAttention metadata, past lengths come from seqlen_offsets;
        // otherwise the normal KV cache tracks them.
        let mask = CausalMasker.make_causal_mask_matrix(
            input_ids,
            metadata
                .as_ref()
                .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
                .unwrap_or(cache as &dyn PastKvLenCache),
            xs.dtype(),
            self.cfg.num_attn_heads,
        )?;
        // PagedAttention prompt chunking
        let mask = mask.filter(|_| {
            metadata
                .as_ref()
                .map(|(_, meta)| meta.is_first_prompt_chunk)
                .unwrap_or(true)
        });
        for (i, layer) in self.layers.iter().enumerate() {
            // Move activations to the device hosting layer `i`, if different.
            xs = self.mapper.map(xs, i)?;
            xs = layer.forward(
                &xs,
                mask.as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                &mut cache[i],
                metadata
                    .as_ref()
                    .map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
                flash_params,
            )?;
        }
        let xs = xs.to_device(&self.device)?;
        let mut xs = xs.apply(&self.final_layernorm)?;
        if let Some(t) = self.lm_head.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        extract_logits(&MatMul.qmethod_matmul(&xs, &*self.lm_head)?, context_lens)
    }
}
613
614impl IsqModel for Model {
615    fn get_layers(
616        &mut self,
617    ) -> (
618        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
619        &dyn DeviceMapper,
620    ) {
621        let mut tensors = Vec::new();
622        tensors.push((&mut self.lm_head, None));
623        for (i, layer) in self.layers.iter_mut().enumerate() {
624            tensors.push((&mut layer.self_attn.q_proj, Some(i)));
625            tensors.push((&mut layer.self_attn.k_proj, Some(i)));
626            tensors.push((&mut layer.self_attn.v_proj, Some(i)));
627            tensors.push((&mut layer.self_attn.dense, Some(i)));
628            tensors.extend(
629                layer
630                    .mlp
631                    .get_isq_layers()
632                    .into_iter()
633                    .map(|m| (m, Some(i)))
634                    .collect::<Vec<_>>(),
635            );
636        }
637        (tensors, &*self.mapper)
638    }
639
640    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
641        let uvb = UnVarBuilder::new();
642
643        let uvb_m = uvb.pp("model");
644        uvb_m.pp("embed_tokens").add(&self.embed_tokens);
645        uvb_m.pp("norm").add(&self.final_layernorm);
646
647        for (layer_idx, layer) in self.layers.iter().enumerate() {
648            let uvb_l = uvb_m.pp("layers").pp(layer_idx);
649            uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
650        }
651
652        uvb.to_safetensors()
653    }
654}
655
impl NormalModel for Model {
    /// Delegates to the inherent [`Model::forward`]; `_position_ids` is unused
    /// because positions are derived from `seqlen_offsets`.
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.forward(
            input_ids,
            seqlen_offsets,
            context_lens,
            metadata,
            flash_params,
        )
    }
    /// X-LoRA is not supported for this model (see `is_xlora`).
    fn xlora_forward(
        &self,
        _input_ids: &Tensor,
        _input_ids_full: &Tensor,
        _seqlen_offsets: &[usize],
        _seqlen_offsets_full: &[usize],
        _no_kv_cache: bool,
        _non_granular_state: &Option<crate::xlora_models::NonGranularState>,
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _flash_params: &FlashParams,
        _flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        unimplemented!()
    }
    fn cache(&self) -> &EitherCache {
        &self.cache
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        &mut self.cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn is_xlora(&self) -> bool {
        false
    }
    fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
}
708
impl AnyMoeBaseModelMixin for Model {
    fn get_mlps(&self) -> Vec<&dyn MlpLayer> {
        let mut mlps = Vec::new();
        for layer in &self.layers {
            mlps.push(&*layer.mlp);
        }
        mlps
    }
    fn get_mlps_mut(&mut self) -> Vec<&mut Box<dyn MlpLayer>> {
        let mut mlps = Vec::new();
        for layer in &mut self.layers {
            mlps.push(&mut layer.mlp);
        }
        mlps
    }
    /// Replaces the MLP of each selected layer with a [`MoeMlp`] whose experts
    /// are the original MLP plus one expert per var builder in
    /// `additional_vbs` (either freshly loaded weights or LoRA deltas).
    fn create_anymoe_layers(
        &mut self,
        additional_vbs: Vec<ShardedVarBuilder>,
        config: AnyMoeConfig,
        (prefix, mlp): (String, String),
        mut layers: Vec<usize>,
        expert_type: AnyMoeExpertType,
        gate_vb: Option<ShardedVarBuilder>,
    ) -> Result<()> {
        let mut experts: Vec<Vec<Box<dyn MlpLayer>>> = Vec::new();
        // An empty `layers` list means "apply MoE to every decoder layer".
        if layers.is_empty() {
            layers = (0..self.layers.len()).collect::<Vec<_>>();
        }
        for _ in 0..layers.len() {
            experts.push(Vec::new());
        }
        for vb in additional_vbs {
            let vb = vb.pp(&prefix);
            for (layer, row) in experts.iter_mut().enumerate() {
                if !layers.contains(&layer) {
                    continue;
                }

                // get_params() is [hidden_size, intermediate_size] (see MLP::new).
                let intermediate_size = self.layers[layer].mlp.get_params()[1];
                let hidden_size = self.layers[layer].mlp.get_params()[0];
                match expert_type {
                    AnyMoeExpertType::FineTuned => {
                        // Fully independent expert weights loaded from the
                        // extra var builder, on the base MLP's dtype/device.
                        let (dtype, device) = self.layers[layer].mlp.dtype_device();
                        row.push(Box::new(MLP::new(
                            &Config {
                                intermediate_size: self.layers[layer].mlp.get_params()[1],
                                hidden_size: self.layers[layer].mlp.get_params()[0],
                                ..Default::default()
                            },
                            vb.pp(layer).pp(&mlp).set_dtype(dtype).set_device(device),
                            &self.mapper.get_comm_for(layer)?,
                        )?));
                    }
                    AnyMoeExpertType::LoraAdapter {
                        rank,
                        alpha,
                        ref target_modules,
                    } => {
                        // Expert = base MLP plus LoRA deltas on the targeted
                        // projections (fc1 and/or fc2).
                        let vb_mlp = vb.pp(layer).pp(&mlp);

                        let fc1_delta = if target_modules.contains(&"fc1".to_string()) {
                            Some(get_delta_from_lora_ab!(
                                vb_mlp,
                                rank,
                                alpha,
                                (hidden_size, intermediate_size),
                                "fc1"
                            ))
                        } else {
                            None
                        };
                        let fc2_delta = if target_modules.contains(&"fc2".to_string()) {
                            Some(get_delta_from_lora_ab!(
                                vb_mlp,
                                rank,
                                alpha,
                                (intermediate_size, hidden_size),
                                "fc2"
                            ))
                        } else {
                            None
                        };

                        row.push(
                            self.layers[layer]
                                .mlp
                                .new_added_delta(vec![fc1_delta, fc2_delta])?,
                        );
                    }
                }
            }
        }
        // Swap each selected layer's MLP for a gated mixture over
        // [original, expert_1, ..., expert_n].
        for (layer, expert) in layers.into_iter().zip(experts) {
            let mut experts_all = vec![self.layers[layer].mlp.clone()];
            experts_all.extend(expert);
            let (dtype, device) = self.layers[layer].mlp.dtype_device();
            self.layers[layer].mlp = Box::new(MoeMlp::new(
                experts_all,
                config.clone(),
                dtype,
                &device,
                layer,
                gate_vb.as_ref(),
            )?);
        }
        Ok(())
    }
    fn amoe_supported(&self) -> bool {
        true
    }
}