mistralrs_core/models/phi3.rs

1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3// This implementation is based on:
4// https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/modeling_phi3.py
5use candle_core::{DType, Device, Module, Result, Tensor, D};
6use mistralrs_quant::{QuantMethod, QuantizedConfig, ReplicatedLayer, ShardedVarBuilder};
7use std::{collections::HashMap, sync::Arc};
8
9use crate::{
10    amoe::{
11        AnyMoeBaseModelMixin, AnyMoeConfig, AnyMoeExpertType, AnyMoeTrainableLayer, MlpLayer,
12        MoeMlp,
13    },
14    attention::SdpaParams,
15    device_map::DeviceMapper,
16    get_delta_from_lora_ab,
17    layers::{
18        embedding, Activation, CausalMasker, MatMul, PhiRopeConfig, PhiRopeScalingConfig,
19        PhiRotaryEmbedding, RmsNorm, Sdpa,
20    },
21    layers_masker::PastKvLenCache,
22    paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
23    pipeline::{
24        extract_logits,
25        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
26        EitherCache, IsqModel, KvCache, NormalCache, NormalLoadingMetadata, NormalModel,
27    },
28    serde_default_fn,
29    utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
30};
31
// Serde default for `Config::tie_word_embeddings` when the key is missing: `false`.
serde_default_fn!(bool, word_emb_default, false);

/// Phi-3 model hyperparameters, deserialized from the Hugging Face `config.json`.
///
/// Reference: https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/config.json
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize, Default)]
pub struct Config {
    /// Vocabulary size: rows of the token embedding and outputs of `lm_head`.
    pub vocab_size: usize,
    /// Activation applied to the gate half of the fused MLP projection.
    pub hidden_act: Activation,
    /// Width of the residual stream (input/output of attention and MLP).
    pub hidden_size: usize,
    /// Width of each half (gate, up) of the fused `gate_up_proj` output.
    pub intermediate_size: usize,
    /// Number of decoder layers.
    pub num_hidden_layers: usize,
    /// Number of query heads; the head dimension is `hidden_size / num_attention_heads`.
    pub num_attention_heads: usize,
    /// Number of key/value heads; `num_attention_heads / num_key_value_heads` query heads share
    /// each one.
    pub num_key_value_heads: usize,
    /// Epsilon used by every `RmsNorm` in the model.
    pub rms_norm_eps: f64,
    /// RoPE theta, forwarded to `PhiRopeConfig`.
    pub rope_theta: f64,
    pub bos_token_id: Option<u32>,
    pub eos_token_id: Option<u32>,
    /// Optional RoPE scaling, forwarded to `PhiRopeConfig`.
    pub rope_scaling: Option<PhiRopeScalingConfig>,
    /// Maximum sequence length; also used to size the KV cache.
    pub max_position_embeddings: usize,
    /// Forwarded to `SdpaParams::use_flash_attn` in every attention layer.
    pub use_flash_attn: bool,
    /// Optional sliding attention window, used by the causal mask, the KV cache and SDPA.
    pub sliding_window: Option<usize>,
    /// Forwarded to `PhiRopeConfig`.
    pub original_max_position_embeddings: usize,
    /// Optional quantization settings passed to the embedding and the attention/MLP linears.
    pub quantization_config: Option<QuantizedConfig>,
    /// When `true`, `lm_head` reuses the token-embedding matrix; defaults to `false` if absent.
    #[serde(default = "word_emb_default")]
    pub tie_word_embeddings: bool,
    /// Forwarded to `PhiRopeConfig`.
    pub partial_rotary_factor: Option<f64>,
}
58
59impl From<Config> for PhiRopeConfig {
60    fn from(val: Config) -> Self {
61        PhiRopeConfig {
62            rope_scaling: val.rope_scaling,
63            max_position_embeddings: val.max_position_embeddings,
64            original_max_position_embeddings: val.original_max_position_embeddings,
65            rope_theta: val.rope_theta,
66            head_dim: val.hidden_size / val.num_attention_heads,
67            partial_rotary_factor: val.partial_rotary_factor,
68        }
69    }
70}
71
72impl Config {
73    pub fn head_dim(&self) -> usize {
74        self.hidden_size / self.num_attention_heads
75    }
76}
77
/// Phi-3 self-attention with a fused QKV projection.
struct Attention {
    /// Fused projection whose output is `[q | k | v]` along the last dimension.
    qkv_proj: Arc<dyn QuantMethod>,
    /// Projection from `num_heads * head_dim` back to `hidden_size`.
    o_proj: Arc<dyn QuantMethod>,
    /// Number of query heads.
    num_heads: usize,
    /// Number of key/value heads.
    num_kv_heads: usize,
    /// Width of each attention head.
    head_dim: usize,
    /// Rotary embedding; `Model::new` shares one instance among layers on the same device.
    rotary_emb: Arc<PhiRotaryEmbedding>,
    /// `Some` only when PagedAttention is the selected attention implementation.
    paged_attn: Option<PagedAttention>,
    /// Parameters handed to the SDPA / PagedAttention kernels.
    sdpa_params: SdpaParams,
}
88
89impl Attention {
90    fn new(
91        rotary_emb: Arc<PhiRotaryEmbedding>,
92        cfg: &Config,
93        vb: ShardedVarBuilder,
94        paged_attn: Option<PagedAttention>,
95    ) -> Result<Self> {
96        let num_heads = cfg.num_attention_heads;
97        let num_kv_heads = cfg.num_key_value_heads;
98        let head_dim = cfg.head_dim();
99        let op_size = num_heads * head_dim + 2 * num_kv_heads * head_dim;
100
101        // No TP here.
102        let qkv_proj = mistralrs_quant::linear_no_bias(
103            cfg.hidden_size,
104            op_size,
105            &cfg.quantization_config,
106            vb.pp("qkv_proj"),
107        )?;
108
109        let o_proj = mistralrs_quant::linear_no_bias(
110            num_heads * head_dim,
111            cfg.hidden_size,
112            &cfg.quantization_config,
113            vb.pp("o_proj"),
114        )?;
115
116        Ok(Self {
117            qkv_proj,
118            o_proj,
119            rotary_emb,
120            num_heads,
121            num_kv_heads,
122            head_dim,
123            paged_attn,
124            sdpa_params: SdpaParams {
125                n_kv_groups: num_heads / num_kv_heads,
126                use_flash_attn: cfg.use_flash_attn,
127                softcap: None,
128                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
129                sliding_window: cfg.sliding_window,
130            },
131        })
132    }
133
134    #[allow(clippy::too_many_arguments)]
135    fn forward(
136        &self,
137        xs: &Tensor,
138        attention_mask: Option<&Tensor>,
139        seqlen_offsets: &[usize],
140        position_ids: &[usize],
141        kv_cache: &mut KvCache,
142        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
143        flash_params: &FlashParams,
144    ) -> Result<Tensor> {
145        let (b_sz, q_len, _) = xs.dims3()?;
146
147        let original_dtype = xs.dtype();
148        let mut xs = xs.clone();
149        if let Some(t) = self.qkv_proj.quantized_act_type() {
150            xs = xs.to_dtype(t)?;
151        }
152        let mut qkv = MatMul.qmethod_matmul(&xs, &*self.qkv_proj)?;
153        if self.qkv_proj.quantized_act_type().is_some() {
154            qkv = qkv.to_dtype(original_dtype)?;
155        }
156        let query_pos = self.num_heads * self.head_dim;
157        let q = qkv.narrow(D::Minus1, 0, query_pos)?;
158        let k = qkv.narrow(D::Minus1, query_pos, self.num_kv_heads * self.head_dim)?;
159        let v = qkv.narrow(
160            D::Minus1,
161            query_pos + self.num_kv_heads * self.head_dim,
162            self.num_kv_heads * self.head_dim,
163        )?;
164
165        let (q, k, v) = if q_len != 1 {
166            let q = q
167                .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
168                .transpose(1, 2)?;
169            let k = k
170                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
171                .transpose(1, 2)?;
172            let v = v
173                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
174                .transpose(1, 2)?;
175            (q, k, v)
176        } else {
177            let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
178            let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
179            let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
180            (q, k, v)
181        };
182
183        let (q, k) = self
184            .rotary_emb
185            .forward(&q, &k, seqlen_offsets, position_ids)?;
186
187        let mut attn_output = match &self.paged_attn {
188            Some(paged_attn) => match metadata {
189                Some(((key_cache, value_cache), input_metadata)) => paged_attn.forward(
190                    &q,
191                    &k.contiguous()?,
192                    &v.contiguous()?,
193                    attention_mask,
194                    Some(key_cache),
195                    Some(value_cache),
196                    input_metadata,
197                    &self.sdpa_params,
198                    Some(flash_params),
199                )?,
200                None => {
201                    // If we don't have metadata, we are most likely generating an imatrix so we don't want to populate that.
202                    // Generating the dummy metadata with the assumption that we are not generating text (only processing prompts).
203                    let input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
204                    // Sanity check.
205                    assert!(attention_mask.is_some());
206                    paged_attn.forward(
207                        &q,
208                        &k.contiguous()?,
209                        &v.contiguous()?,
210                        attention_mask,
211                        None,
212                        None,
213                        &input_metadata,
214                        &self.sdpa_params,
215                        Some(flash_params),
216                    )?
217                }
218            },
219            _ => {
220                let (k, v) = kv_cache.append(&k, &v)?;
221
222                Sdpa.run_attention(
223                    &q,
224                    &k,
225                    &v,
226                    attention_mask,
227                    Some(flash_params),
228                    &self.sdpa_params,
229                )?
230            }
231        };
232
233        if let Some(t) = self.qkv_proj.quantized_act_type() {
234            attn_output = attn_output.to_dtype(t)?;
235        }
236        attn_output = if attention_mask.is_some() {
237            attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?
238        } else {
239            attn_output.reshape((b_sz, q_len, ()))?
240        };
241        let mut res = MatMul.qmethod_matmul(&attn_output, &*self.o_proj)?;
242        if self.qkv_proj.quantized_act_type().is_some() {
243            res = res.to_dtype(original_dtype)?;
244        }
245        Ok(res)
246    }
247}
248
/// Phi-3 gated MLP with a fused gate/up projection.
#[derive(Clone)]
struct Mlp {
    /// Fused projection whose output is `[gate | up]`, each half `i_size` wide.
    gate_up_proj: Arc<dyn QuantMethod>,
    /// Projection from `i_size` back to the hidden size.
    down_proj: Arc<dyn QuantMethod>,
    /// Activation applied to the gate half.
    act_fn: Activation,
    /// Intermediate size (width of each half of `gate_up_proj`'s output).
    i_size: usize,
    /// `[hidden_size, intermediate_size]`, returned by `MlpLayer::get_params`.
    params: Vec<usize>,
}
257
258impl Mlp {
259    fn new(cfg: &Config, vb: ShardedVarBuilder) -> Result<Self> {
260        let hidden_size = cfg.hidden_size;
261        let i_size = cfg.intermediate_size;
262
263        // No TP here.
264        let gate_up_proj = mistralrs_quant::linear_no_bias(
265            hidden_size,
266            2 * i_size,
267            &cfg.quantization_config,
268            vb.pp("gate_up_proj"),
269        )?;
270
271        let down_proj = mistralrs_quant::linear_no_bias(
272            i_size,
273            hidden_size,
274            &cfg.quantization_config,
275            vb.pp("down_proj"),
276        )?;
277
278        Ok(Self {
279            gate_up_proj,
280            down_proj,
281            act_fn: cfg.hidden_act,
282            i_size,
283            params: vec![hidden_size, i_size],
284        })
285    }
286}
287
// Lets `Mlp` act as an AnyMoE expert; nothing from the trait is overridden here.
impl AnyMoeTrainableLayer for Mlp {}
289
290impl MlpLayer for Mlp {
291    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
292        let original_dtype = xs.dtype();
293        let mut xs = xs.clone();
294        if let Some(t) = self.gate_up_proj.quantized_act_type() {
295            xs = xs.to_dtype(t)?;
296        }
297        let up_states = MatMul.qmethod_matmul(&xs, &*self.gate_up_proj)?;
298        let gate = up_states.narrow(D::Minus1, 0, self.i_size)?;
299        let up_states = up_states.narrow(D::Minus1, self.i_size, self.i_size)?;
300        let up_states = (up_states * gate.apply(&self.act_fn))?;
301        let mut res = MatMul.qmethod_matmul(&up_states, &*self.down_proj)?;
302        if self.gate_up_proj.quantized_act_type().is_some() {
303            res = res.to_dtype(original_dtype)?;
304        }
305        Ok(res)
306    }
307    fn get_isq_layers(&mut self) -> Vec<&mut Arc<dyn QuantMethod>> {
308        vec![&mut self.gate_up_proj, &mut self.down_proj]
309    }
310    fn clone(&self) -> Box<dyn MlpLayer> {
311        Box::new(Clone::clone(self))
312    }
313    fn get_params(&self) -> &[usize] {
314        &self.params
315    }
316    fn hidden_act(&self) -> Activation {
317        self.act_fn
318    }
319    // gate_up, down
320    fn new_added_delta(&self, deltas: Vec<Option<Tensor>>) -> Result<Box<dyn MlpLayer>> {
321        let new_gate_up = if let Some(ref delta) = deltas[0] {
322            self.gate_up_proj.add_delta_w(delta)?
323        } else {
324            self.gate_up_proj.clone()
325        };
326        let new_down = if let Some(ref delta) = deltas[1] {
327            self.down_proj.add_delta_w(delta)?
328        } else {
329            self.down_proj.clone()
330        };
331
332        Ok(Box::new(Self {
333            gate_up_proj: new_gate_up,
334            down_proj: new_down,
335            act_fn: self.act_fn,
336            i_size: self.i_size,
337            params: self.params.clone(),
338        }))
339    }
340
341    fn dtype_device(&self) -> (DType, Device) {
342        self.gate_up_proj.dtype_and_device()
343    }
344}
345
/// One pre-norm Phi-3 decoder layer: attention then MLP, each with a residual connection.
struct DecoderLayer {
    self_attn: Attention,
    /// Boxed trait object so AnyMoE can replace it with a `MoeMlp`.
    mlp: Box<dyn MlpLayer>,
    /// Applied to the layer input before `self_attn`.
    input_layernorm: RmsNorm,
    /// Applied to the post-attention residual stream before `mlp`.
    post_attention_layernorm: RmsNorm,
}
352
353impl DecoderLayer {
354    fn new(
355        rotary_emb: Arc<PhiRotaryEmbedding>,
356        cfg: &Config,
357        vb: ShardedVarBuilder,
358        mapper: &dyn DeviceMapper,
359        layer_idx: usize,
360        loading_isq: bool,
361        paged_attn: Option<PagedAttention>,
362    ) -> Result<Self> {
363        let self_attn = Attention::new(
364            rotary_emb,
365            cfg,
366            mapper.set_device(layer_idx, vb.pp("self_attn"), loading_isq),
367            paged_attn,
368        )?;
369        let mlp = Mlp::new(cfg, mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq))?;
370        let input_layernorm = RmsNorm::new(
371            cfg.hidden_size,
372            cfg.rms_norm_eps,
373            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
374        )?;
375        let post_attention_layernorm = RmsNorm::new(
376            cfg.hidden_size,
377            cfg.rms_norm_eps,
378            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
379        )?;
380        Ok(Self {
381            self_attn,
382            mlp: Box::new(mlp),
383            input_layernorm,
384            post_attention_layernorm,
385        })
386    }
387
388    #[allow(clippy::too_many_arguments)]
389    fn forward(
390        &self,
391        xs: &Tensor,
392        attention_mask: Option<&Tensor>,
393        seqlen_offsets: &[usize],
394        position_ids: &[usize],
395        kv_cache: &mut KvCache,
396        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
397        flash_params: &FlashParams,
398    ) -> Result<Tensor> {
399        let residual = xs;
400        let xs = self.input_layernorm.forward(xs)?;
401        let xs = self.self_attn.forward(
402            &xs,
403            attention_mask,
404            seqlen_offsets,
405            position_ids,
406            kv_cache,
407            metadata,
408            flash_params,
409        )?;
410        let xs = (xs + residual)?;
411        let residual = &xs;
412        let xs = self
413            .mlp
414            .forward(&xs.apply(&self.post_attention_layernorm)?)?;
415        residual + xs
416    }
417}
418
/// Phi-3 causal language model: token embedding, decoder stack, final norm and LM head.
pub struct Model {
    embed_tokens: candle_nn::Embedding,
    layers: Vec<DecoderLayer>,
    /// Final norm applied before `lm_head`.
    norm: RmsNorm,
    /// Output projection; built from the embedding matrix when `tie_word_embeddings` is set.
    lm_head: Arc<dyn QuantMethod>,
    /// Device the final hidden states are moved to before `norm` and `lm_head`.
    device: Device,
    /// Per-layer KV cache; appended to only on the non-paged attention path.
    cache: EitherCache,
    max_seq_len: usize,
    /// Places layers on devices and moves activations between them in `forward`.
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    sliding_window: Option<usize>,
    /// Metadata returned by `NormalModel::config`.
    cfg: ModelConfigMetadata,
}
431
432impl Model {
433    pub fn new(
434        cfg: &Config,
435        vb: ShardedVarBuilder,
436        _is_gptx: bool,
437        normal_loading_metadata: NormalLoadingMetadata,
438        attention_mechanism: AttentionImplementation,
439    ) -> Result<Self> {
440        if let Some(ref quant_cfg) = &cfg.quantization_config {
441            tracing::info!(
442                "Using {} quantization: {}.",
443                quant_cfg.name(),
444                quant_cfg.get_bits_name(&vb)
445            );
446        }
447        let mapper = normal_loading_metadata.mapper;
448        let vb_m = vb.pp("model");
449
450        let embed_tokens = embedding(
451            cfg.vocab_size,
452            cfg.hidden_size,
453            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
454            &cfg.quantization_config,
455        )?;
456        let mut ropes = HashMap::new();
457        for layer_idx in 0..cfg.num_hidden_layers {
458            let device = mapper
459                .device_for(layer_idx, false)
460                .unwrap_or(&normal_loading_metadata.real_device);
461            ropes.insert(
462                device.location(),
463                Arc::new(PhiRotaryEmbedding::new(vb.dtype(), cfg.clone(), device)?),
464            );
465        }
466        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
467        let vb_l = vb_m.pp("layers");
468        for layer_idx in NiceProgressBar::<_, 'b'>(
469            0..cfg.num_hidden_layers,
470            "Loading repeating layers",
471            &normal_loading_metadata.multi_progress,
472        ) {
473            let device = mapper
474                .device_for(layer_idx, false)
475                .unwrap_or(&normal_loading_metadata.real_device);
476            let rotary_emb = ropes
477                .get(&device.location())
478                .expect("No RoPE for device location!")
479                .clone();
480            let paged_attn = match &attention_mechanism {
481                AttentionImplementation::Eager => None,
482                AttentionImplementation::PagedAttention => {
483                    Some(PagedAttention::new(cfg.head_dim(), device, None)?)
484                }
485            };
486            let layer = DecoderLayer::new(
487                rotary_emb.clone(),
488                cfg,
489                vb_l.pp(layer_idx),
490                &*mapper,
491                layer_idx,
492                normal_loading_metadata.loading_isq,
493                paged_attn,
494            )?;
495            layers.push(layer)
496        }
497        let norm = RmsNorm::new(
498            cfg.hidden_size,
499            cfg.rms_norm_eps,
500            mapper.set_nm_device(vb_m.pp("norm"), false),
501        )?;
502        let lm_head = if !cfg.tie_word_embeddings {
503            ReplicatedLayer::new(
504                cfg.hidden_size,
505                cfg.vocab_size,
506                &None,
507                false,
508                mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
509            )?
510        } else {
511            ReplicatedLayer::from_linear(candle_nn::Linear::new(
512                mapper.cast_nm_device(
513                    embed_tokens.embeddings(),
514                    normal_loading_metadata.loading_isq,
515                )?,
516                None,
517            ))?
518        };
519        Ok(Self {
520            embed_tokens,
521            layers,
522            norm,
523            lm_head,
524            device: normal_loading_metadata.real_device,
525            cache: EitherCache::Normal(NormalCache::new_sliding(
526                cfg.num_hidden_layers,
527                cfg.max_position_embeddings,
528                cfg.sliding_window,
529            )),
530            max_seq_len: cfg.max_position_embeddings,
531            sliding_window: cfg.sliding_window,
532            cfg: ModelConfigMetadata {
533                max_seq_len: cfg.max_position_embeddings,
534                num_layers: cfg.num_hidden_layers,
535                hidden_size: cfg.hidden_size,
536                num_attn_heads: cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size(),
537                num_kv_heads: (cfg.num_key_value_heads / mapper.get_comm_for(0)?.world_size())
538                    .max(1),
539                sliding_window: cfg.sliding_window,
540                k_head_dim: cfg.head_dim(),
541                v_head_dim: cfg.head_dim(),
542            },
543            mapper,
544        })
545    }
546
547    pub fn forward(
548        &self,
549        input_ids: &Tensor,
550        seqlen_offsets: &[usize],
551        position_ids: &[usize],
552        context_lens: Vec<(usize, usize)>,
553        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
554        flash_params: &FlashParams,
555    ) -> Result<Tensor> {
556        let mut xs = self.embed_tokens.forward(input_ids)?;
557        let cache = &mut self.cache.normal().0;
558        let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
559            input_ids,
560            metadata
561                .as_ref()
562                .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
563                .unwrap_or(cache as &dyn PastKvLenCache),
564            self.sliding_window,
565            xs.dtype(),
566            self.cfg.num_attn_heads,
567        )?;
568        // PagedAttention prompt chunking
569        let attention_mask = attention_mask.filter(|_| {
570            metadata
571                .as_ref()
572                .map(|(_, meta)| meta.is_first_prompt_chunk)
573                .unwrap_or(true)
574        });
575
576        for (i, layer) in self.layers.iter().enumerate() {
577            xs = self.mapper.map(xs, i)?;
578            xs = layer.forward(
579                &xs,
580                attention_mask
581                    .as_ref()
582                    .map(|m| m.to_device(xs.device()).unwrap())
583                    .as_ref(),
584                seqlen_offsets,
585                position_ids,
586                &mut cache[i],
587                metadata
588                    .as_ref()
589                    .map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
590                flash_params,
591            )?
592        }
593        let xs = xs.to_device(&self.device)?;
594        let mut xs = xs.apply(&self.norm)?;
595        if let Some(t) = self.lm_head.quantized_act_type() {
596            xs = xs.to_dtype(t)?;
597        }
598        extract_logits(&MatMul.qmethod_matmul(&xs, &*self.lm_head)?, context_lens)
599    }
600}
601
602impl IsqModel for Model {
603    fn get_layers(
604        &mut self,
605    ) -> (
606        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
607        &dyn DeviceMapper,
608    ) {
609        let mut tensors = Vec::new();
610        tensors.push((&mut self.lm_head, None));
611        for (i, layer) in self.layers.iter_mut().enumerate() {
612            tensors.push((&mut layer.self_attn.qkv_proj, Some(i)));
613            tensors.push((&mut layer.self_attn.o_proj, Some(i)));
614            tensors.extend(
615                layer
616                    .mlp
617                    .get_isq_layers()
618                    .into_iter()
619                    .map(|m| (m, Some(i)))
620                    .collect::<Vec<_>>(),
621            );
622        }
623        (tensors, &*self.mapper)
624    }
625
626    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
627        let uvb = UnVarBuilder::new();
628
629        let uvb_m = uvb.pp("model");
630        uvb_m.pp("embed_tokens").add(&self.embed_tokens);
631        uvb_m.pp("norm").add(&self.norm);
632
633        for (layer_idx, layer) in self.layers.iter().enumerate() {
634            let uvb_l = uvb_m.pp("layers").pp(layer_idx);
635            uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
636            uvb_l
637                .pp("post_attention_layernorm")
638                .add(&layer.post_attention_layernorm);
639        }
640
641        uvb.to_safetensors()
642    }
643
644    fn imatrix_names(&self) -> candle_core::Result<Vec<Option<String>>> {
645        // NOTE: dependant on the exact implementation in get_layers!
646        let mut names = Vec::new();
647        // lm_head
648        names.push(None);
649        for i in 0..self.layers.len() {
650            names.push(Some(format!("blk.{i}.attn_qkv.weight")));
651            names.push(Some(format!("blk.{i}.attn_output.weight")));
652            names.push(Some(format!("blk.{i}.ffn_gate.weight")));
653            names.push(Some(format!("blk.{i}.ffn_up.weight")));
654            names.push(Some(format!("blk.{i}.ffn_down.weight")));
655        }
656        Ok(names)
657    }
658}
659
660impl NormalModel for Model {
661    fn forward(
662        &self,
663        input_ids: &Tensor,
664        seqlen_offsets: &[usize],
665        context_lens: Vec<(usize, usize)>,
666        position_ids: Vec<usize>,
667        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
668        flash_params: &FlashParams,
669    ) -> Result<Tensor> {
670        self.forward(
671            input_ids,
672            seqlen_offsets,
673            &position_ids,
674            context_lens,
675            metadata,
676            flash_params,
677        )
678    }
679    fn xlora_forward(
680        &self,
681        _input_ids: &Tensor,
682        _input_ids_full: &Tensor,
683        _seqlen_offsets: &[usize],
684        _seqlen_offsets_full: &[usize],
685        _no_kv_cache: bool,
686        _non_granular_state: &Option<crate::xlora_models::NonGranularState>,
687        _context_lens: Vec<(usize, usize)>,
688        _position_ids: Vec<usize>,
689        _flash_params: &FlashParams,
690        _flash_params_full: &FlashParams,
691    ) -> Result<Tensor> {
692        unimplemented!()
693    }
694    fn cache(&self) -> &EitherCache {
695        &self.cache
696    }
697    fn cache_mut(&mut self) -> &mut EitherCache {
698        &mut self.cache
699    }
700    fn device(&self) -> &Device {
701        &self.device
702    }
703    fn is_xlora(&self) -> bool {
704        false
705    }
706    fn max_seq_len(&self) -> usize {
707        self.max_seq_len
708    }
709    fn config(&self) -> &ModelConfigMetadata {
710        &self.cfg
711    }
712}
713
714impl AnyMoeBaseModelMixin for Model {
715    fn get_mlps(&self) -> Vec<&dyn MlpLayer> {
716        let mut mlps = Vec::new();
717        for layer in &self.layers {
718            mlps.push(&*layer.mlp);
719        }
720        mlps
721    }
722    fn get_mlps_mut(&mut self) -> Vec<&mut Box<dyn MlpLayer>> {
723        let mut mlps = Vec::new();
724        for layer in &mut self.layers {
725            mlps.push(&mut layer.mlp);
726        }
727        mlps
728    }
729    fn create_anymoe_layers(
730        &mut self,
731        additional_vbs: Vec<ShardedVarBuilder>,
732        config: AnyMoeConfig,
733        (prefix, mlp): (String, String),
734        mut layers: Vec<usize>,
735        expert_type: AnyMoeExpertType,
736        gate_vb: Option<ShardedVarBuilder>,
737    ) -> Result<()> {
738        let mut experts: Vec<Vec<Box<dyn MlpLayer>>> = Vec::new();
739        if layers.is_empty() {
740            layers = (0..self.layers.len()).collect::<Vec<_>>();
741        }
742        for _ in 0..layers.len() {
743            experts.push(Vec::new());
744        }
745        for vb in additional_vbs {
746            let vb = vb.pp(&prefix);
747            for (layer, row) in experts.iter_mut().enumerate() {
748                if !layers.contains(&layer) {
749                    continue;
750                }
751
752                let intermediate_size = self.layers[layer].mlp.get_params()[1];
753                let hidden_size = self.layers[layer].mlp.get_params()[0];
754                match expert_type {
755                    AnyMoeExpertType::FineTuned => {
756                        let (dtype, device) = self.layers[layer].mlp.dtype_device();
757                        row.push(Box::new(Mlp::new(
758                            &Config {
759                                intermediate_size: self.layers[layer].mlp.get_params()[1],
760                                hidden_size: self.layers[layer].mlp.get_params()[0],
761                                ..Default::default()
762                            },
763                            vb.pp(layer).pp(&mlp).set_dtype(dtype).set_device(device),
764                        )?));
765                    }
766                    AnyMoeExpertType::LoraAdapter {
767                        rank,
768                        alpha,
769                        ref target_modules,
770                    } => {
771                        let vb_mlp = vb.pp(layer).pp(&mlp);
772
773                        let gate_up_proj_delta =
774                            if target_modules.contains(&"gate_up_proj".to_string()) {
775                                Some(get_delta_from_lora_ab!(
776                                    vb_mlp,
777                                    rank,
778                                    alpha,
779                                    (hidden_size, 2 * intermediate_size),
780                                    "gate_up_proj"
781                                ))
782                            } else {
783                                None
784                            };
785                        let down_proj_delta = if target_modules.contains(&"down_proj".to_string()) {
786                            Some(get_delta_from_lora_ab!(
787                                vb_mlp,
788                                rank,
789                                alpha,
790                                (hidden_size, intermediate_size),
791                                "down_proj"
792                            ))
793                        } else {
794                            None
795                        };
796
797                        row.push(
798                            self.layers[layer]
799                                .mlp
800                                .new_added_delta(vec![gate_up_proj_delta, down_proj_delta])?,
801                        );
802                    }
803                }
804            }
805        }
806        for (layer, expert) in layers.into_iter().zip(experts) {
807            let mut experts_all = vec![self.layers[layer].mlp.clone()];
808            experts_all.extend(expert);
809            let (dtype, device) = self.layers[layer].mlp.dtype_device();
810            self.layers[layer].mlp = Box::new(MoeMlp::new(
811                experts_all,
812                config.clone(),
813                dtype,
814                &device,
815                layer,
816                gate_vb.as_ref(),
817            )?);
818        }
819        Ok(())
820    }
821    fn amoe_supported(&self) -> bool {
822        true
823    }
824}