mistralrs_core/models/phi3.rs

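//! Phi-3 (text-only) decoder-only transformer with fused QKV and gate/up
//! projections, grouped-query attention, and optional paged attention, ISQ
//! quantization, device mapping, and AnyMoE expert support.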
#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

// This implementation is based on:
// https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/modeling_phi3.py
use candle_core::{DType, Device, Module, Result, Tensor, D};
use mistralrs_quant::{QuantMethod, QuantizedConfig, ReplicatedLayer, ShardedVarBuilder};
use std::{collections::HashMap, sync::Arc};

use crate::{
    amoe::{
        AnyMoeBaseModelMixin, AnyMoeConfig, AnyMoeExpertType, AnyMoeTrainableLayer, MlpLayer,
        MoeMlp,
    },
    attention::SdpaParams,
    device_map::DeviceMapper,
    get_delta_from_lora_ab,
    layers::{
        embedding, Activation, CausalMasker, MatMul, PhiRopeConfig, PhiRopeScalingConfig,
        PhiRotaryEmbedding, RmsNorm, Sdpa,
    },
    layers_masker::PastKvLenCache,
    paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
    pipeline::{
        extract_logits,
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel, KvCache, NormalCache, NormalLoadingMetadata, NormalModel,
    },
    serde_default_fn,
    utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
};

serde_default_fn!(bool, word_emb_default, false);

// https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/config.json
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize, Default)]
pub struct Config {
    pub vocab_size: usize,
    pub hidden_act: Activation,
    pub hidden_size: usize,
    pub intermediate_size: usize,
    pub num_hidden_layers: usize,
    pub num_attention_heads: usize,
    pub num_key_value_heads: usize,
    pub rms_norm_eps: f64,
    pub rope_theta: f64,
    pub bos_token_id: Option<u32>,
    pub eos_token_id: Option<u32>,
    pub rope_scaling: Option<PhiRopeScalingConfig>,
    pub max_position_embeddings: usize,
    pub sliding_window: Option<usize>,
    pub original_max_position_embeddings: usize,
    pub quantization_config: Option<QuantizedConfig>,
    #[serde(default = "word_emb_default")]
    pub tie_word_embeddings: bool,
    pub partial_rotary_factor: Option<f64>,
}

impl From<Config> for PhiRopeConfig {
    fn from(val: Config) -> Self {
        PhiRopeConfig {
            rope_scaling: val.rope_scaling,
            max_position_embeddings: val.max_position_embeddings,
            original_max_position_embeddings: val.original_max_position_embeddings,
            rope_theta: val.rope_theta,
            head_dim: val.hidden_size / val.num_attention_heads,
            partial_rotary_factor: val.partial_rotary_factor,
        }
    }
}

impl Config {
    pub fn head_dim(&self) -> usize {
        self.hidden_size / self.num_attention_heads
    }
}

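/// Per-layer self-attention. Q, K and V come from a single fused `qkv_proj`
/// linear; `num_kv_heads` may be smaller than `num_heads` (grouped-query
/// attention), which is reflected in `sdpa_params.n_kv_groups`.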
struct Attention {
    qkv_proj: Arc<dyn QuantMethod>,
    o_proj: Arc<dyn QuantMethod>,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    rotary_emb: Arc<PhiRotaryEmbedding>,
    paged_attn: Option<PagedAttention>,
    sdpa_params: SdpaParams,
}

impl Attention {
    fn new(
        rotary_emb: Arc<PhiRotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        paged_attn: Option<PagedAttention>,
    ) -> Result<Self> {
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        let head_dim = cfg.head_dim();
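        // Width of the fused QKV projection output: `num_heads` query heads plus
        // one K and one V block of `num_kv_heads` heads each.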
        let op_size = num_heads * head_dim + 2 * num_kv_heads * head_dim;

        // No TP here.
        let qkv_proj = mistralrs_quant::linear_no_bias(
            cfg.hidden_size,
            op_size,
            &cfg.quantization_config,
            vb.pp("qkv_proj"),
        )?;

        let o_proj = mistralrs_quant::linear_no_bias(
            num_heads * head_dim,
            cfg.hidden_size,
            &cfg.quantization_config,
            vb.pp("o_proj"),
        )?;

        Ok(Self {
            qkv_proj,
            o_proj,
            rotary_emb,
            num_heads,
            num_kv_heads,
            head_dim,
            paged_attn,
            sdpa_params: SdpaParams {
                n_kv_groups: num_heads / num_kv_heads,
                softcap: None,
                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                sliding_window: cfg.sliding_window,
            },
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.qkv_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut qkv = MatMul.qmethod_matmul(&xs, &*self.qkv_proj)?;
        if self.qkv_proj.quantized_act_type().is_some() {
            qkv = qkv.to_dtype(original_dtype)?;
        }
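        // Split the fused QKV output back into Q, K and V along the last dimension.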
        let query_pos = self.num_heads * self.head_dim;
        let q = qkv.narrow(D::Minus1, 0, query_pos)?;
        let k = qkv.narrow(D::Minus1, query_pos, self.num_kv_heads * self.head_dim)?;
        let v = qkv.narrow(
            D::Minus1,
            query_pos + self.num_kv_heads * self.head_dim,
            self.num_kv_heads * self.head_dim,
        )?;

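        // Reshape to (batch, heads, seq, head_dim). For single-token decode
        // (q_len == 1) the transpose would be a no-op, so reshape directly into
        // the target layout.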
        let (q, k, v) = if q_len != 1 {
            let q = q
                .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, k) = self
            .rotary_emb
            .forward(&q, &k, seqlen_offsets, position_ids)?;

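        // Run attention: the paged-attention kernel against the block KV cache when
        // available, otherwise the eager SDPA path over this layer's KV cache.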
        let mut attn_output = match &self.paged_attn {
            Some(paged_attn) => match metadata {
                Some(((key_cache, value_cache), input_metadata)) => paged_attn.forward(
                    &q,
                    &k.contiguous()?,
                    &v.contiguous()?,
                    attention_mask,
                    Some(key_cache),
                    Some(value_cache),
                    input_metadata,
                    &self.sdpa_params,
                    Some(flash_params),
                )?,
                None => {
                    // If we don't have paged-attention metadata, we are most likely generating an
                    // imatrix, so we don't want to populate the KV cache. Generate dummy metadata,
                    // assuming we are not generating text (only processing prompts).
                    let input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
                    // Sanity check.
                    assert!(attention_mask.is_some());
                    paged_attn.forward(
                        &q,
                        &k.contiguous()?,
                        &v.contiguous()?,
                        attention_mask,
                        None,
                        None,
                        &input_metadata,
                        &self.sdpa_params,
                        Some(flash_params),
                    )?
                }
            },
            _ => {
                let (k, v) = kv_cache.append(&k, &v)?;

                Sdpa.run_attention(
                    &q,
                    &k,
                    &v,
                    attention_mask,
                    Some(flash_params),
                    &self.sdpa_params,
                )?
            }
        };

        if let Some(t) = self.qkv_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        attn_output = if attention_mask.is_some() {
            attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?
        } else {
            attn_output.reshape((b_sz, q_len, ()))?
        };
        let mut res = MatMul.qmethod_matmul(&attn_output, &*self.o_proj)?;
        if self.qkv_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

#[derive(Clone)]
struct Mlp {
    gate_up_proj: Arc<dyn QuantMethod>,
    down_proj: Arc<dyn QuantMethod>,
    act_fn: Activation,
    i_size: usize,
    params: Vec<usize>,
}

impl Mlp {
    fn new(cfg: &Config, vb: ShardedVarBuilder) -> Result<Self> {
        let hidden_size = cfg.hidden_size;
        let i_size = cfg.intermediate_size;

        // No TP here.
        let gate_up_proj = mistralrs_quant::linear_no_bias(
            hidden_size,
            2 * i_size,
            &cfg.quantization_config,
            vb.pp("gate_up_proj"),
        )?;

        let down_proj = mistralrs_quant::linear_no_bias(
            i_size,
            hidden_size,
            &cfg.quantization_config,
            vb.pp("down_proj"),
        )?;

        Ok(Self {
            gate_up_proj,
            down_proj,
            act_fn: cfg.hidden_act,
            i_size,
            params: vec![hidden_size, i_size],
        })
    }
}

impl AnyMoeTrainableLayer for Mlp {}

impl MlpLayer for Mlp {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.gate_up_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let up_states = MatMul.qmethod_matmul(&xs, &*self.gate_up_proj)?;
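        // `gate_up_proj` packs the gate and up projections into one matrix; split
        // them and apply the gated activation (up * act(gate)) before `down_proj`.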
        let gate = up_states.narrow(D::Minus1, 0, self.i_size)?;
        let up_states = up_states.narrow(D::Minus1, self.i_size, self.i_size)?;
        let up_states = (up_states * gate.apply(&self.act_fn))?;
        let mut res = MatMul.qmethod_matmul(&up_states, &*self.down_proj)?;
        if self.gate_up_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
    fn get_isq_layers(&mut self) -> Vec<&mut Arc<dyn QuantMethod>> {
        vec![&mut self.gate_up_proj, &mut self.down_proj]
    }
    fn clone(&self) -> Box<dyn MlpLayer> {
        Box::new(Clone::clone(self))
    }
    fn get_params(&self) -> &[usize] {
        &self.params
    }
    fn hidden_act(&self) -> Activation {
        self.act_fn
    }
    // gate_up, down
    fn new_added_delta(&self, deltas: Vec<Option<Tensor>>) -> Result<Box<dyn MlpLayer>> {
        let new_gate_up = if let Some(ref delta) = deltas[0] {
            self.gate_up_proj.add_delta_w(delta)?
        } else {
            self.gate_up_proj.clone()
        };
        let new_down = if let Some(ref delta) = deltas[1] {
            self.down_proj.add_delta_w(delta)?
        } else {
            self.down_proj.clone()
        };

        Ok(Box::new(Self {
            gate_up_proj: new_gate_up,
            down_proj: new_down,
            act_fn: self.act_fn,
            i_size: self.i_size,
            params: self.params.clone(),
        }))
    }

    fn dtype_device(&self) -> (DType, Device) {
        self.gate_up_proj.dtype_and_device()
    }
}

struct DecoderLayer {
    self_attn: Attention,
    mlp: Box<dyn MlpLayer>,
    input_layernorm: RmsNorm,
    post_attention_layernorm: RmsNorm,
}

impl DecoderLayer {
    fn new(
        rotary_emb: Arc<PhiRotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        paged_attn: Option<PagedAttention>,
    ) -> Result<Self> {
        let self_attn = Attention::new(
            rotary_emb,
            cfg,
            mapper.set_device(layer_idx, vb.pp("self_attn"), loading_isq),
            paged_attn,
        )?;
        let mlp = Mlp::new(cfg, mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq))?;
        let input_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let post_attention_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        Ok(Self {
            self_attn,
            mlp: Box::new(mlp),
            input_layernorm,
            post_attention_layernorm,
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
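        // Pre-norm residual block: normalize -> attention -> residual add, then
        // normalize -> MLP -> residual add.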
        let residual = xs;
        let xs = self.input_layernorm.forward(xs)?;
        let xs = self.self_attn.forward(
            &xs,
            attention_mask,
            seqlen_offsets,
            position_ids,
            kv_cache,
            metadata,
            flash_params,
        )?;
        let xs = (xs + residual)?;
        let residual = &xs;
        let xs = self
            .mlp
            .forward(&xs.apply(&self.post_attention_layernorm)?)?;
        residual + xs
    }
}

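/// The full Phi-3 model: token embeddings, the stack of decoder layers, a final
/// RmsNorm, and the LM head, plus the KV cache and device-mapping state.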
pub struct Model {
    embed_tokens: candle_nn::Embedding,
    layers: Vec<DecoderLayer>,
    norm: RmsNorm,
    lm_head: Arc<dyn QuantMethod>,
    device: Device,
    cache: EitherCache,
    max_seq_len: usize,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    sliding_window: Option<usize>,
    cfg: ModelConfigMetadata,
}

impl Model {
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        _is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        if let Some(ref quant_cfg) = &cfg.quantization_config {
            tracing::info!(
                "Using {} quantization: {}.",
                quant_cfg.name(),
                quant_cfg.get_bits_name(&vb)
            );
        }
        let mapper = normal_loading_metadata.mapper;
        let vb_m = vb.pp("model");

        let embed_tokens = embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
            &cfg.quantization_config,
        )?;
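        // Build one rotary embedding per device location so layers mapped to
        // different devices each use a locally resident copy.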
        let mut ropes = HashMap::new();
        for layer_idx in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(PhiRotaryEmbedding::new(vb.dtype(), cfg.clone(), device)?),
            );
        }
        let vb_l = vb_m.pp("layers");
        let layers: Vec<DecoderLayer> = NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        )
        .par_iter_if_isq(|layer_idx| {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            let paged_attn = match &attention_mechanism {
                AttentionImplementation::Eager => None,
                AttentionImplementation::PagedAttention => {
                    Some(PagedAttention::new(cfg.head_dim(), device, None)?)
                }
            };
            DecoderLayer::new(
                rotary_emb.clone(),
                cfg,
                vb_l.pp(layer_idx),
                &*mapper,
                layer_idx,
                normal_loading_metadata.loading_isq,
                paged_attn,
            )
        })?;
        let norm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb_m.pp("norm"), false),
        )?;
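        // The LM head either has its own weights or, with tied word embeddings,
        // reuses the embedding matrix.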
        let lm_head = if !cfg.tie_word_embeddings {
            ReplicatedLayer::new(
                cfg.hidden_size,
                cfg.vocab_size,
                &None,
                false,
                mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
            )?
        } else {
            ReplicatedLayer::from_linear(candle_nn::Linear::new(
                mapper.cast_nm_device(
                    embed_tokens.embeddings(),
                    normal_loading_metadata.loading_isq,
                )?,
                None,
            ))?
        };
        Ok(Self {
            embed_tokens,
            layers,
            norm,
            lm_head,
            device: normal_loading_metadata.real_device,
            cache: EitherCache::Normal(NormalCache::new_sliding(
                cfg.num_hidden_layers,
                cfg.max_position_embeddings,
                cfg.sliding_window,
            )),
            max_seq_len: cfg.max_position_embeddings,
            sliding_window: cfg.sliding_window,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_attn_heads: cfg.num_attention_heads,
                num_kv_heads: cfg.num_key_value_heads,
                sliding_window: cfg.sliding_window,
                k_head_dim: cfg.head_dim(),
                v_head_dim: cfg.head_dim(),
            },
            mapper,
        })
    }

    pub fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
        context_lens: Vec<(usize, usize)>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut xs = self.embed_tokens.forward(input_ids)?;
        let cache = &mut self.cache.normal().0;
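        // Build the (sliding-window) causal mask, using the paged-attention
        // sequence offsets as the past-KV length when metadata is present and the
        // local cache otherwise.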
        let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
            input_ids,
            metadata
                .as_ref()
                .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
                .unwrap_or(cache as &dyn PastKvLenCache),
            self.sliding_window,
            xs.dtype(),
            self.cfg.num_attn_heads,
        )?;
        // PagedAttention prompt chunking
        let attention_mask = attention_mask.filter(|_| {
            metadata
                .as_ref()
                .map(|(_, meta)| meta.is_first_prompt_chunk)
                .unwrap_or(true)
        });

        for (i, layer) in self.layers.iter().enumerate() {
            xs = self.mapper.map(xs, i)?;
            xs = layer.forward(
                &xs,
                attention_mask
                    .as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                position_ids,
                &mut cache[i],
                metadata
                    .as_ref()
                    .map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
                flash_params,
            )?
        }
        let xs = xs.to_device(&self.device)?;
        let mut xs = xs.apply(&self.norm)?;
        if let Some(t) = self.lm_head.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        extract_logits(&MatMul.qmethod_matmul(&xs, &*self.lm_head)?, context_lens)
    }
}

impl IsqModel for Model {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
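        // lm_head first, then per layer: qkv_proj, o_proj, and the MLP projections.
        // `imatrix_names` depends on this exact ordering.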
        let mut tensors = Vec::new();
        tensors.push((&mut self.lm_head, None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            tensors.push((&mut layer.self_attn.qkv_proj, Some(i)));
            tensors.push((&mut layer.self_attn.o_proj, Some(i)));
            tensors.extend(
                layer
                    .mlp
                    .get_isq_layers()
                    .into_iter()
                    .map(|m| (m, Some(i)))
                    .collect::<Vec<_>>(),
            );
        }
        (tensors, &*self.mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        let uvb_m = uvb.pp("model");
        uvb_m.pp("embed_tokens").add(&self.embed_tokens);
        uvb_m.pp("norm").add(&self.norm);

        for (layer_idx, layer) in self.layers.iter().enumerate() {
            let uvb_l = uvb_m.pp("layers").pp(layer_idx);
            uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
            uvb_l
                .pp("post_attention_layernorm")
                .add(&layer.post_attention_layernorm);
        }

        uvb.to_safetensors()
    }

    fn imatrix_names(&self) -> candle_core::Result<Vec<Option<String>>> {
        // NOTE: dependent on the exact implementation in get_layers!
        let mut names = Vec::new();
        // lm_head
        names.push(None);
        for i in 0..self.layers.len() {
            names.push(Some(format!("blk.{i}.attn_qkv.weight")));
            names.push(Some(format!("blk.{i}.attn_output.weight")));
            names.push(Some(format!("blk.{i}.ffn_gate.weight")));
            names.push(Some(format!("blk.{i}.ffn_up.weight")));
            names.push(Some(format!("blk.{i}.ffn_down.weight")));
        }
        Ok(names)
    }
}

impl NormalModel for Model {
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.forward(
            input_ids,
            seqlen_offsets,
            &position_ids,
            context_lens,
            metadata,
            flash_params,
        )
    }
    fn xlora_forward(
        &self,
        _input_ids: &Tensor,
        _input_ids_full: &Tensor,
        _seqlen_offsets: &[usize],
        _seqlen_offsets_full: &[usize],
        _no_kv_cache: bool,
        _non_granular_state: &Option<crate::xlora_models::NonGranularState>,
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _flash_params: &FlashParams,
        _flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        unimplemented!()
    }
    fn cache(&self) -> &EitherCache {
        &self.cache
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        &mut self.cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn is_xlora(&self) -> bool {
        false
    }
    fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
}

impl AnyMoeBaseModelMixin for Model {
    fn get_mlps(&self) -> Vec<&dyn MlpLayer> {
        let mut mlps = Vec::new();
        for layer in &self.layers {
            mlps.push(&*layer.mlp);
        }
        mlps
    }
    fn get_mlps_mut(&mut self) -> Vec<&mut Box<dyn MlpLayer>> {
        let mut mlps = Vec::new();
        for layer in &mut self.layers {
            mlps.push(&mut layer.mlp);
        }
        mlps
    }
    fn create_anymoe_layers(
        &mut self,
        additional_vbs: Vec<ShardedVarBuilder>,
        config: AnyMoeConfig,
        (prefix, mlp): (String, String),
        mut layers: Vec<usize>,
        expert_type: AnyMoeExpertType,
        gate_vb: Option<ShardedVarBuilder>,
    ) -> Result<()> {
        let mut experts: Vec<Vec<Box<dyn MlpLayer>>> = Vec::new();
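        // An empty `layers` argument means every decoder layer receives experts.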
        if layers.is_empty() {
            layers = (0..self.layers.len()).collect::<Vec<_>>();
        }
        for _ in 0..layers.len() {
            experts.push(Vec::new());
        }
        for vb in additional_vbs {
            let vb = vb.pp(&prefix);
            for (layer, row) in experts.iter_mut().enumerate() {
                if !layers.contains(&layer) {
                    continue;
                }

                let intermediate_size = self.layers[layer].mlp.get_params()[1];
                let hidden_size = self.layers[layer].mlp.get_params()[0];
                match expert_type {
                    AnyMoeExpertType::FineTuned => {
                        let (dtype, device) = self.layers[layer].mlp.dtype_device();
                        row.push(Box::new(Mlp::new(
                            &Config {
                                intermediate_size: self.layers[layer].mlp.get_params()[1],
                                hidden_size: self.layers[layer].mlp.get_params()[0],
                                ..Default::default()
                            },
                            vb.pp(layer).pp(&mlp).set_dtype(dtype).set_device(device),
                        )?));
                    }
                    AnyMoeExpertType::LoraAdapter {
                        rank,
                        alpha,
                        ref target_modules,
                    } => {
                        let vb_mlp = vb.pp(layer).pp(&mlp);

                        let gate_up_proj_delta =
                            if target_modules.contains(&"gate_up_proj".to_string()) {
                                Some(get_delta_from_lora_ab!(
                                    vb_mlp,
                                    rank,
                                    alpha,
                                    (hidden_size, 2 * intermediate_size),
                                    "gate_up_proj"
                                ))
                            } else {
                                None
                            };
                        let down_proj_delta = if target_modules.contains(&"down_proj".to_string()) {
                            Some(get_delta_from_lora_ab!(
                                vb_mlp,
                                rank,
                                alpha,
                                (hidden_size, intermediate_size),
                                "down_proj"
                            ))
                        } else {
                            None
                        };

                        row.push(
                            self.layers[layer]
                                .mlp
                                .new_added_delta(vec![gate_up_proj_delta, down_proj_delta])?,
                        );
                    }
                }
            }
        }
        for (layer, expert) in layers.into_iter().zip(experts) {
            let mut experts_all = vec![self.layers[layer].mlp.clone()];
            experts_all.extend(expert);
            let (dtype, device) = self.layers[layer].mlp.dtype_device();
            self.layers[layer].mlp = Box::new(MoeMlp::new(
                experts_all,
                config.clone(),
                dtype,
                &device,
                layer,
                gate_vb.as_ref(),
            )?);
        }
        Ok(())
    }
    fn amoe_supported(&self) -> bool {
        true
    }
}