mistralrs_core/models/phi3.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

// This implementation is based on:
// https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/modeling_phi3.py
use candle_core::{DType, Device, Module, Result, Tensor, D};
use mistralrs_quant::{QuantMethod, QuantizedConfig, ReplicatedLayer, ShardedVarBuilder};
use std::{collections::HashMap, sync::Arc};

use crate::{
    amoe::{
        AnyMoeBaseModelMixin, AnyMoeConfig, AnyMoeExpertType, AnyMoeTrainableLayer, MlpLayer,
        MoeMlp,
    },
    attention::SdpaParams,
    device_map::DeviceMapper,
    get_delta_from_lora_ab,
    layers::{
        embedding, Activation, CausalMasker, MatMul, PhiRopeConfig, PhiRopeScalingConfig,
        PhiRotaryEmbedding, RmsNorm, Sdpa,
    },
    layers_masker::PastKvLenCache,
    paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
    pipeline::{
        extract_logits,
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel, KvCache, NormalCache, NormalLoadingMetadata, NormalModel,
    },
    serde_default_fn,
    utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
};

serde_default_fn!(bool, word_emb_default, false);

// https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/config.json
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize, Default)]
pub struct Config {
    pub vocab_size: usize,
    pub hidden_act: Activation,
    pub hidden_size: usize,
    pub intermediate_size: usize,
    pub num_hidden_layers: usize,
    pub num_attention_heads: usize,
    pub num_key_value_heads: usize,
    pub rms_norm_eps: f64,
    pub rope_theta: f64,
    pub bos_token_id: Option<u32>,
    pub eos_token_id: Option<u32>,
    pub rope_scaling: Option<PhiRopeScalingConfig>,
    pub max_position_embeddings: usize,
    pub sliding_window: Option<usize>,
    pub original_max_position_embeddings: usize,
    pub quantization_config: Option<QuantizedConfig>,
    #[serde(default = "word_emb_default")]
    pub tie_word_embeddings: bool,
    pub partial_rotary_factor: Option<f64>,
}

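// The rotary embedding only needs the RoPE-related fields, so the full `Config` converts into a `PhiRopeConfig`.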
impl From<Config> for PhiRopeConfig {
    fn from(val: Config) -> Self {
        PhiRopeConfig {
            rope_scaling: val.rope_scaling,
            max_position_embeddings: val.max_position_embeddings,
            original_max_position_embeddings: val.original_max_position_embeddings,
            rope_theta: val.rope_theta,
            head_dim: val.hidden_size / val.num_attention_heads,
            partial_rotary_factor: val.partial_rotary_factor,
        }
    }
}

impl Config {
    pub fn head_dim(&self) -> usize {
        self.hidden_size / self.num_attention_heads
    }
}

struct Attention {
    qkv_proj: Arc<dyn QuantMethod>,
    o_proj: Arc<dyn QuantMethod>,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    rotary_emb: Arc<PhiRotaryEmbedding>,
    paged_attn: Option<PagedAttention>,
    sdpa_params: SdpaParams,
}

impl Attention {
    fn new(
        rotary_emb: Arc<PhiRotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        paged_attn: Option<PagedAttention>,
    ) -> Result<Self> {
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        let head_dim = cfg.head_dim();
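        // Width of the fused QKV projection: num_heads for Q plus num_kv_heads each for K and V.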
        let op_size = num_heads * head_dim + 2 * num_kv_heads * head_dim;

        // No TP here.
        let qkv_proj = mistralrs_quant::linear_no_bias(
            cfg.hidden_size,
            op_size,
            &cfg.quantization_config,
            vb.pp("qkv_proj"),
        )?;

        let o_proj = mistralrs_quant::linear_no_bias(
            num_heads * head_dim,
            cfg.hidden_size,
            &cfg.quantization_config,
            vb.pp("o_proj"),
        )?;

        Ok(Self {
            qkv_proj,
            o_proj,
            rotary_emb,
            num_heads,
            num_kv_heads,
            head_dim,
            paged_attn,
            sdpa_params: SdpaParams {
                n_kv_groups: num_heads / num_kv_heads,
                softcap: None,
                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                sliding_window: cfg.sliding_window,
            },
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.qkv_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut qkv = MatMul.qmethod_matmul(&xs, &*self.qkv_proj)?;
        if self.qkv_proj.quantized_act_type().is_some() {
            qkv = qkv.to_dtype(original_dtype)?;
        }
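        // Split the fused QKV output into query, key and value along the last dimension.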
        let query_pos = self.num_heads * self.head_dim;
        let q = qkv.narrow(D::Minus1, 0, query_pos)?;
        let k = qkv.narrow(D::Minus1, query_pos, self.num_kv_heads * self.head_dim)?;
        let v = qkv.narrow(
            D::Minus1,
            query_pos + self.num_kv_heads * self.head_dim,
            self.num_kv_heads * self.head_dim,
        )?;

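        // Prompt phase (q_len > 1): reshape then transpose to (b, heads, q_len, head_dim).
        // Decode phase (q_len == 1): the transpose is a no-op, so reshape directly.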
        let (q, k, v) = if q_len != 1 {
            let q = q
                .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, k) = self
            .rotary_emb
            .forward(&q, &k, seqlen_offsets, position_ids)?;

        let mut attn_output = match &self.paged_attn {
            Some(paged_attn) => match metadata {
                Some(((key_cache, value_cache), input_metadata)) => paged_attn.forward(
                    &q,
                    &k.contiguous()?,
                    &v.contiguous()?,
                    attention_mask,
                    Some(key_cache),
                    Some(value_cache),
                    input_metadata,
                    &self.sdpa_params,
                    Some(flash_params),
                )?,
                None => {
                    // Without metadata we are most likely generating an imatrix, so we don't want to populate the paged KV cache.
                    // Generate dummy metadata under the assumption that we are only processing prompts, not generating text.
                    let input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
                    // Sanity check.
                    assert!(attention_mask.is_some());
                    paged_attn.forward(
                        &q,
                        &k.contiguous()?,
                        &v.contiguous()?,
                        attention_mask,
                        None,
                        None,
                        &input_metadata,
                        &self.sdpa_params,
                        Some(flash_params),
                    )?
                }
            },
            _ => {
                let (k, v) = kv_cache.append(&k, &v)?;

                Sdpa.run_attention(
                    &q,
                    &k,
                    &v,
                    attention_mask,
                    Some(flash_params),
                    &self.sdpa_params,
                )?
            }
        };

        if let Some(t) = self.qkv_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
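        // For the masked (prompt) case the output is (b, heads, q_len, head_dim) and needs a
        // transpose before flattening; in single-token decode a direct reshape suffices.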
        attn_output = if attention_mask.is_some() {
            attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?
        } else {
            attn_output.reshape((b_sz, q_len, ()))?
        };
        let mut res = MatMul.qmethod_matmul(&attn_output, &*self.o_proj)?;
        if self.qkv_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

#[derive(Clone)]
struct Mlp {
    gate_up_proj: Arc<dyn QuantMethod>,
    down_proj: Arc<dyn QuantMethod>,
    act_fn: Activation,
    i_size: usize,
    params: Vec<usize>,
}

impl Mlp {
    fn new(cfg: &Config, vb: ShardedVarBuilder) -> Result<Self> {
        let hidden_size = cfg.hidden_size;
        let i_size = cfg.intermediate_size;

        // No TP here.
        let gate_up_proj = mistralrs_quant::linear_no_bias(
            hidden_size,
            2 * i_size,
            &cfg.quantization_config,
            vb.pp("gate_up_proj"),
        )?;

        let down_proj = mistralrs_quant::linear_no_bias(
            i_size,
            hidden_size,
            &cfg.quantization_config,
            vb.pp("down_proj"),
        )?;

        Ok(Self {
            gate_up_proj,
            down_proj,
            act_fn: cfg.hidden_act,
            i_size,
            params: vec![hidden_size, i_size],
        })
    }
}

impl AnyMoeTrainableLayer for Mlp {}

impl MlpLayer for Mlp {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.gate_up_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let up_states = MatMul.qmethod_matmul(&xs, &*self.gate_up_proj)?;
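        // gate_up_proj packs [gate, up] along the last dim; SwiGLU-style: up * act(gate).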
        let gate = up_states.narrow(D::Minus1, 0, self.i_size)?;
        let up_states = up_states.narrow(D::Minus1, self.i_size, self.i_size)?;
        let up_states = (up_states * gate.apply(&self.act_fn))?;
        let mut res = MatMul.qmethod_matmul(&up_states, &*self.down_proj)?;
        if self.gate_up_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
    fn get_isq_layers(&mut self) -> Vec<&mut Arc<dyn QuantMethod>> {
        vec![&mut self.gate_up_proj, &mut self.down_proj]
    }
    fn clone(&self) -> Box<dyn MlpLayer> {
        Box::new(Clone::clone(self))
    }
    fn get_params(&self) -> &[usize] {
        &self.params
    }
    fn hidden_act(&self) -> Activation {
        self.act_fn
    }
    // gate_up, down
    fn new_added_delta(&self, deltas: Vec<Option<Tensor>>) -> Result<Box<dyn MlpLayer>> {
        let new_gate_up = if let Some(ref delta) = deltas[0] {
            self.gate_up_proj.add_delta_w(delta)?
        } else {
            self.gate_up_proj.clone()
        };
        let new_down = if let Some(ref delta) = deltas[1] {
            self.down_proj.add_delta_w(delta)?
        } else {
            self.down_proj.clone()
        };

        Ok(Box::new(Self {
            gate_up_proj: new_gate_up,
            down_proj: new_down,
            act_fn: self.act_fn,
            i_size: self.i_size,
            params: self.params.clone(),
        }))
    }

    fn dtype_device(&self) -> (DType, Device) {
        self.gate_up_proj.dtype_and_device()
    }
}

struct DecoderLayer {
    self_attn: Attention,
    mlp: Box<dyn MlpLayer>,
    input_layernorm: RmsNorm,
    post_attention_layernorm: RmsNorm,
}

impl DecoderLayer {
    fn new(
        rotary_emb: Arc<PhiRotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        paged_attn: Option<PagedAttention>,
    ) -> Result<Self> {
        let self_attn = Attention::new(
            rotary_emb,
            cfg,
            mapper.set_device(layer_idx, vb.pp("self_attn"), loading_isq),
            paged_attn,
        )?;
        let mlp = Mlp::new(cfg, mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq))?;
        let input_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let post_attention_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        Ok(Self {
            self_attn,
            mlp: Box::new(mlp),
            input_layernorm,
            post_attention_layernorm,
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = self.input_layernorm.forward(xs)?;
        let xs = self.self_attn.forward(
            &xs,
            attention_mask,
            seqlen_offsets,
            position_ids,
            kv_cache,
            metadata,
            flash_params,
        )?;
        let xs = (xs + residual)?;
        let residual = &xs;
        let xs = self
            .mlp
            .forward(&xs.apply(&self.post_attention_layernorm)?)?;
        residual + xs
    }
}

pub struct Model {
    embed_tokens: candle_nn::Embedding,
    layers: Vec<DecoderLayer>,
    norm: RmsNorm,
    lm_head: Arc<dyn QuantMethod>,
    device: Device,
    cache: EitherCache,
    max_seq_len: usize,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    sliding_window: Option<usize>,
    cfg: ModelConfigMetadata,
}

impl Model {
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        _is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        if let Some(ref quant_cfg) = &cfg.quantization_config {
            tracing::info!(
                "Using {} quantization: {}.",
                quant_cfg.name(),
                quant_cfg.get_bits_name(&vb)
            );
        }
        let mapper = normal_loading_metadata.mapper;
        let vb_m = vb.pp("model");

        let embed_tokens = embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
            &cfg.quantization_config,
        )?;
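        // Build one rotary embedding per device so layers mapped to the same device share it.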
        let mut ropes = HashMap::new();
        for layer_idx in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(PhiRotaryEmbedding::new(vb.dtype(), cfg.clone(), device)?),
            );
        }
        let vb_l = vb_m.pp("layers");
        let layers: Vec<DecoderLayer> = NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        )
        .par_iter_if_isq(|layer_idx| {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            let paged_attn = match &attention_mechanism {
                AttentionImplementation::Eager => None,
                AttentionImplementation::PagedAttention => {
                    Some(PagedAttention::new(cfg.head_dim(), device, None)?)
                }
            };
            DecoderLayer::new(
                rotary_emb.clone(),
                cfg,
                vb_l.pp(layer_idx),
                &*mapper,
                layer_idx,
                normal_loading_metadata.loading_isq,
                paged_attn,
            )
        })?;
        let norm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb_m.pp("norm"), false),
        )?;
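        // With tied word embeddings, reuse the embedding matrix as the LM head instead of loading lm_head weights.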
        let lm_head = if !cfg.tie_word_embeddings {
            ReplicatedLayer::new(
                cfg.hidden_size,
                cfg.vocab_size,
                &None,
                false,
                mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
            )?
        } else {
            ReplicatedLayer::from_linear(candle_nn::Linear::new(
                mapper.cast_nm_device(
                    embed_tokens.embeddings(),
                    normal_loading_metadata.loading_isq,
                )?,
                None,
            ))?
        };
        Ok(Self {
            embed_tokens,
            layers,
            norm,
            lm_head,
            device: normal_loading_metadata.real_device,
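            // Per-layer KV cache sized to the maximum context, honoring the sliding window if configured.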
            cache: EitherCache::Normal(NormalCache::new_sliding(
                cfg.num_hidden_layers,
                cfg.max_position_embeddings,
                cfg.sliding_window,
            )),
            max_seq_len: cfg.max_position_embeddings,
            sliding_window: cfg.sliding_window,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_attn_heads: cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size(),
                num_kv_heads: (cfg.num_key_value_heads / mapper.get_comm_for(0)?.world_size())
                    .max(1),
                sliding_window: cfg.sliding_window,
                k_head_dim: cfg.head_dim(),
                v_head_dim: cfg.head_dim(),
            },
            mapper,
        })
    }

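    /// Runs the forward pass over `input_ids` and returns the logits selected by `context_lens`.
    ///
    /// `seqlen_offsets` holds the per-sequence KV-cache offsets, `position_ids` the RoPE positions,
    /// and `metadata` carries the PagedAttention KV caches and input metadata when that backend is
    /// active.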
    pub fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
        context_lens: Vec<(usize, usize)>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut xs = self.embed_tokens.forward(input_ids)?;
        let cache = &mut self.cache.normal().0;
        let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
            input_ids,
            metadata
                .as_ref()
                .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
                .unwrap_or(cache as &dyn PastKvLenCache),
            self.sliding_window,
            xs.dtype(),
            self.cfg.num_attn_heads,
        )?;
        // PagedAttention prompt chunking
        let attention_mask = attention_mask.filter(|_| {
            metadata
                .as_ref()
                .map(|(_, meta)| meta.is_first_prompt_chunk)
                .unwrap_or(true)
        });

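        // Run each decoder layer, moving the activations and mask to that layer's device as needed.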
        for (i, layer) in self.layers.iter().enumerate() {
            xs = self.mapper.map(xs, i)?;
            xs = layer.forward(
                &xs,
                attention_mask
                    .as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                position_ids,
                &mut cache[i],
                metadata
                    .as_ref()
                    .map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
                flash_params,
            )?
        }
        let xs = xs.to_device(&self.device)?;
        let mut xs = xs.apply(&self.norm)?;
        if let Some(t) = self.lm_head.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        extract_logits(&MatMul.qmethod_matmul(&xs, &*self.lm_head)?, context_lens)
    }
}

impl IsqModel for Model {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((&mut self.lm_head, None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            tensors.push((&mut layer.self_attn.qkv_proj, Some(i)));
            tensors.push((&mut layer.self_attn.o_proj, Some(i)));
            tensors.extend(
                layer
                    .mlp
                    .get_isq_layers()
                    .into_iter()
                    .map(|m| (m, Some(i)))
                    .collect::<Vec<_>>(),
            );
        }
        (tensors, &*self.mapper)
    }

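    // Non-quantized tensors (embeddings and norms) that are serialized as-is alongside the ISQ weights.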
    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        let uvb_m = uvb.pp("model");
        uvb_m.pp("embed_tokens").add(&self.embed_tokens);
        uvb_m.pp("norm").add(&self.norm);

        for (layer_idx, layer) in self.layers.iter().enumerate() {
            let uvb_l = uvb_m.pp("layers").pp(layer_idx);
            uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
            uvb_l
                .pp("post_attention_layernorm")
                .add(&layer.post_attention_layernorm);
        }

        uvb.to_safetensors()
    }

    fn imatrix_names(&self) -> candle_core::Result<Vec<Option<String>>> {
        // NOTE: dependent on the exact implementation in get_layers!
        let mut names = Vec::new();
        // lm_head
        names.push(None);
        for i in 0..self.layers.len() {
            names.push(Some(format!("blk.{i}.attn_qkv.weight")));
            names.push(Some(format!("blk.{i}.attn_output.weight")));
            names.push(Some(format!("blk.{i}.ffn_gate.weight")));
            names.push(Some(format!("blk.{i}.ffn_up.weight")));
            names.push(Some(format!("blk.{i}.ffn_down.weight")));
        }
        Ok(names)
    }
}

impl NormalModel for Model {
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.forward(
            input_ids,
            seqlen_offsets,
            &position_ids,
            context_lens,
            metadata,
            flash_params,
        )
    }
    fn xlora_forward(
        &self,
        _input_ids: &Tensor,
        _input_ids_full: &Tensor,
        _seqlen_offsets: &[usize],
        _seqlen_offsets_full: &[usize],
        _no_kv_cache: bool,
        _non_granular_state: &Option<crate::xlora_models::NonGranularState>,
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _flash_params: &FlashParams,
        _flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        unimplemented!()
    }
    fn cache(&self) -> &EitherCache {
        &self.cache
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        &mut self.cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn is_xlora(&self) -> bool {
        false
    }
    fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
}

impl AnyMoeBaseModelMixin for Model {
    fn get_mlps(&self) -> Vec<&dyn MlpLayer> {
        let mut mlps = Vec::new();
        for layer in &self.layers {
            mlps.push(&*layer.mlp);
        }
        mlps
    }
    fn get_mlps_mut(&mut self) -> Vec<&mut Box<dyn MlpLayer>> {
        let mut mlps = Vec::new();
        for layer in &mut self.layers {
            mlps.push(&mut layer.mlp);
        }
        mlps
    }
    fn create_anymoe_layers(
        &mut self,
        additional_vbs: Vec<ShardedVarBuilder>,
        config: AnyMoeConfig,
        (prefix, mlp): (String, String),
        mut layers: Vec<usize>,
        expert_type: AnyMoeExpertType,
        gate_vb: Option<ShardedVarBuilder>,
    ) -> Result<()> {
        let mut experts: Vec<Vec<Box<dyn MlpLayer>>> = Vec::new();
        if layers.is_empty() {
            layers = (0..self.layers.len()).collect::<Vec<_>>();
        }
        for _ in 0..layers.len() {
            experts.push(Vec::new());
        }
        for vb in additional_vbs {
            let vb = vb.pp(&prefix);
            for (layer, row) in experts.iter_mut().enumerate() {
                if !layers.contains(&layer) {
                    continue;
                }

                let intermediate_size = self.layers[layer].mlp.get_params()[1];
                let hidden_size = self.layers[layer].mlp.get_params()[0];
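                // FineTuned experts are full MLPs loaded from the additional VarBuilder; LoraAdapter
                // experts apply LoRA deltas on top of a clone of the base layer's MLP.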
                match expert_type {
                    AnyMoeExpertType::FineTuned => {
                        let (dtype, device) = self.layers[layer].mlp.dtype_device();
                        row.push(Box::new(Mlp::new(
                            &Config {
                                intermediate_size: self.layers[layer].mlp.get_params()[1],
                                hidden_size: self.layers[layer].mlp.get_params()[0],
                                ..Default::default()
                            },
                            vb.pp(layer).pp(&mlp).set_dtype(dtype).set_device(device),
                        )?));
                    }
                    AnyMoeExpertType::LoraAdapter {
                        rank,
                        alpha,
                        ref target_modules,
                    } => {
                        let vb_mlp = vb.pp(layer).pp(&mlp);

                        let gate_up_proj_delta =
                            if target_modules.contains(&"gate_up_proj".to_string()) {
                                Some(get_delta_from_lora_ab!(
                                    vb_mlp,
                                    rank,
                                    alpha,
                                    (hidden_size, 2 * intermediate_size),
                                    "gate_up_proj"
                                ))
                            } else {
                                None
                            };
                        let down_proj_delta = if target_modules.contains(&"down_proj".to_string()) {
                            Some(get_delta_from_lora_ab!(
                                vb_mlp,
                                rank,
                                alpha,
                                (hidden_size, intermediate_size),
                                "down_proj"
                            ))
                        } else {
                            None
                        };

                        row.push(
                            self.layers[layer]
                                .mlp
                                .new_added_delta(vec![gate_up_proj_delta, down_proj_delta])?,
                        );
                    }
                }
            }
        }
        for (layer, expert) in layers.into_iter().zip(experts) {
            let mut experts_all = vec![self.layers[layer].mlp.clone()];
            experts_all.extend(expert);
            let (dtype, device) = self.layers[layer].mlp.dtype_device();
            self.layers[layer].mlp = Box::new(MoeMlp::new(
                experts_all,
                config.clone(),
                dtype,
                &device,
                layer,
                gate_vb.as_ref(),
            )?);
        }
        Ok(())
    }
    fn amoe_supported(&self) -> bool {
        true
    }
}