mistralrs_core/models/gemma2.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::{collections::HashMap, sync::Arc};

use candle_core::{Device, Module, Result, Tensor};
use candle_nn::Linear;
use mistralrs_quant::{
    ColumnParallelLayer, QuantMethod, QuantMethodConfig, QuantizedConfig, RowParallelLayer,
    ShardedVarBuilder, UnquantLinear,
};

use crate::{
    amoe::{AnyMoeBaseModelMixin, AnyMoeConfig, AnyMoeExpertType, MlpLayer, MoeMlp},
    attention::SdpaParams,
    device_map::DeviceMapper,
    get_delta_from_lora_ab,
    layers::{embedding, Activation, CausalMasker, MatMul, Mlp, RmsNorm, RotaryEmbedding, Sdpa},
    paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
    pipeline::{
        extract_logits,
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel, KvCache, NormalCache, NormalLoadingMetadata, NormalModel,
    },
    utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
};

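/// Gemma 2 model configuration. Notable Gemma 2 specifics: both `hidden_act` and
/// `hidden_activation` may be present, attention and final logits can be
/// soft-capped, the attention scale is derived from `query_pre_attn_scalar`, and
/// sliding-window attention of width `sliding_window` is used on alternating layers.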
#[derive(Debug, Clone, Default, serde::Serialize)]
pub struct Config {
    pub attention_bias: bool,
    pub head_dim: usize,
    // The code gemma configs include both hidden_act and hidden_activation.
    pub hidden_act: Option<Activation>,
    pub hidden_activation: Option<Activation>,
    pub hidden_size: usize,
    pub intermediate_size: usize,
    pub num_attention_heads: usize,
    pub num_hidden_layers: usize,
    pub num_key_value_heads: usize,
    pub rms_norm_eps: f64,
    pub rope_theta: f64,
    pub vocab_size: usize,
    pub sliding_window: usize,
    pub attn_logit_softcapping: Option<f64>,
    pub final_logit_softcapping: Option<f64>,
    pub query_pre_attn_scalar: usize,
    pub max_position_embeddings: usize,
    pub quantization_config: Option<QuantizedConfig>,
    pub use_flash_attn: bool,
    #[allow(dead_code)]
    pub tie_word_embeddings: bool,
}

impl Config {
    pub fn hidden_act(&self) -> Result<Activation> {
        match (self.hidden_act, self.hidden_activation) {
            (None, Some(act)) | (Some(act), None) => Ok(act),
            (Some(act), Some(_)) => {
                // If both are set, just use hidden_act.
                Ok(act)
            }
            (None, None) => candle_core::bail!("neither hidden_act nor hidden_activation is set"),
        }
    }
}
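
/// Per-layer self-attention: grouped-query attention with RoPE, optional logit
/// soft-capping, and (on every other layer) sliding-window attention. The
/// projections are column/row-parallel so heads can be sharded across the
/// tensor-parallel `Comm`.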
struct Attention {
    q_proj: Arc<dyn QuantMethod>,
    k_proj: Arc<dyn QuantMethod>,
    v_proj: Arc<dyn QuantMethod>,
    o_proj: Arc<dyn QuantMethod>,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    rotary_emb: Arc<RotaryEmbedding>,
    use_sliding_window: bool,
    paged_attn: Option<PagedAttention>,
    sdpa_params: SdpaParams,
}

impl Attention {
    fn new(
        rotary_emb: Arc<RotaryEmbedding>,
        cfg: &Config,
        layer_idx: usize,
        vb: ShardedVarBuilder,
        paged_attn: Option<PagedAttention>,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        let head_dim = cfg.head_dim;
        let bias = cfg.attention_bias;
        let q_proj = ColumnParallelLayer::new(
            hidden_sz,
            num_heads * head_dim,
            &cfg.quantization_config,
            bias,
            comm,
            vb.pp("q_proj"),
        )?;
        let kv_shard = mistralrs_quant::compute_kv_shard(
            cfg.num_key_value_heads,
            cfg.hidden_size / cfg.num_attention_heads,
            comm,
        );
        let k_proj = ColumnParallelLayer::new_with_shard(
            hidden_sz,
            num_kv_heads * head_dim,
            &cfg.quantization_config,
            bias,
            comm,
            kv_shard,
            vb.pp("k_proj"),
        )?;
        let v_proj = ColumnParallelLayer::new_with_shard(
            hidden_sz,
            num_kv_heads * head_dim,
            &cfg.quantization_config,
            bias,
            comm,
            kv_shard,
            vb.pp("v_proj"),
        )?;
        let o_proj = RowParallelLayer::new(
            num_heads * head_dim,
            hidden_sz,
            &cfg.quantization_config,
            bias,
            comm,
            vb.pp("o_proj"),
        )?;
        let sliding_window = if layer_idx % 2 == 0 {
            // ^ Order is SWA, global, SWA
            Some(cfg.sliding_window)
        } else {
            None
        };
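        // Head counts are divided across the tensor-parallel group. Note the Gemma 2
        // specifics in the SDPA params: the softmax scale comes from
        // `query_pre_attn_scalar` rather than `head_dim`, and attention logits may be
        // soft-capped via `attn_logit_softcapping`.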
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            num_heads: num_heads / comm.world_size(),
            num_kv_heads: (num_kv_heads / comm.world_size()).max(1),
            head_dim,
            rotary_emb,
            use_sliding_window: layer_idx % 2 == 0, // Order is SWA, global, SWA
            paged_attn,
            sdpa_params: SdpaParams {
                n_kv_groups: mistralrs_quant::compute_n_kv_groups(
                    cfg.num_key_value_heads,
                    cfg.num_attention_heads,
                    comm,
                ),
                use_flash_attn: cfg.use_flash_attn,
                softcap: cfg.attn_logit_softcapping.map(|x| x as f32),
                softmax_scale: 1.0 / (cfg.query_pre_attn_scalar as f32).sqrt(),
                sliding_window,
            },
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        sliding_attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.q_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut q = MatMul.qmethod_matmul(&xs, &*self.q_proj)?;
        let mut k = MatMul.qmethod_matmul(&xs, &*self.k_proj)?;
        let mut v = MatMul.qmethod_matmul(&xs, &*self.v_proj)?;
        if self.q_proj.quantized_act_type().is_some() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

        let (q, k, v) = if q_len != 1 {
            let q = q
                .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, k) = self.rotary_emb.forward(&q, &k, seqlen_offsets)?;

        let mask = if self.use_sliding_window {
            sliding_attention_mask
        } else {
            attention_mask
        };

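        // Three execution paths: paged attention with the provided KV-cache blocks,
        // paged attention with dummy metadata (no cache writes, e.g. imatrix
        // collection over prompts), or eager SDPA backed by the per-layer `KvCache`.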
        let mut attn_output = match &self.paged_attn {
            Some(paged_attn) => match metadata {
                Some(((key_cache, value_cache), input_metadata)) => paged_attn.forward(
                    &q,
                    &k,
                    &v,
                    attention_mask,
                    Some(key_cache),
                    Some(value_cache),
                    input_metadata,
                    &self.sdpa_params,
                    Some(flash_params),
                )?,
                None => {
                    // If we don't have metadata, we are most likely generating an imatrix, so we don't want to populate the paged KV cache.
                    // Generate dummy metadata, assuming that we are not generating text (only processing prompts).
                    let input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
                    // Sanity check.
                    assert!(attention_mask.is_some());
                    paged_attn.forward(
                        &q,
                        &k,
                        &v,
                        attention_mask,
                        None,
                        None,
                        &input_metadata,
                        &self.sdpa_params,
                        Some(flash_params),
                    )?
                }
            },
            None => {
                // self.sdpa_params.sliding_window is None if !self.use_sliding_window
                let (k, v) = kv_cache.append(&k, &v)?;

                Sdpa.run_attention(&q, &k, &v, mask, Some(flash_params), &self.sdpa_params)?
            }
        };

        if let Some(t) = self.q_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        attn_output = if attention_mask.is_some() {
            attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?
        } else {
            attn_output.reshape((b_sz, q_len, ()))?
        };
        let mut res = MatMul.qmethod_matmul(&attn_output, &*self.o_proj)?;
        if self.q_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

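/// One Gemma 2 decoder block. Unlike most Llama-style blocks it normalizes on both
/// sides of each sub-block: input/post-attention layernorms around self-attention
/// and pre/post-feedforward layernorms around the MLP.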
struct DecoderLayer {
    self_attn: Attention,
    mlp: Box<dyn MlpLayer>,
    input_layernorm: RmsNorm,
    post_attention_layernorm: RmsNorm,
    pre_feedforward_layernorm: RmsNorm,
    post_feedforward_layernorm: RmsNorm,
}

impl DecoderLayer {
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<RotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        paged_attn: Option<PagedAttention>,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let self_attn = Attention::new(
            rotary_emb,
            cfg,
            layer_idx,
            mapper.set_device(layer_idx, vb.pp("self_attn"), loading_isq),
            paged_attn,
            comm,
        )?;
        let mlp = Mlp::new(
            mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq),
            cfg.hidden_size,
            cfg.intermediate_size,
            &cfg.quantization_config,
            cfg.hidden_act()?,
            comm,
        )?;
        let input_layernorm = RmsNorm::new_gemma(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let post_attention_layernorm = RmsNorm::new_gemma(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        let pre_feedforward_layernorm = RmsNorm::new_gemma(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("pre_feedforward_layernorm"), false),
        )?;
        let post_feedforward_layernorm = RmsNorm::new_gemma(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_feedforward_layernorm"), false),
        )?;
        Ok(Self {
            self_attn,
            mlp: Box::new(mlp),
            input_layernorm,
            post_attention_layernorm,
            pre_feedforward_layernorm,
            post_feedforward_layernorm,
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        sliding_attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = self.input_layernorm.forward(xs)?;
        let xs = self
            .self_attn
            .forward(
                &xs,
                attention_mask,
                sliding_attention_mask,
                seqlen_offsets,
                kv_cache,
                metadata,
                flash_params,
            )?
            .apply(&self.post_attention_layernorm)?;
        let xs = (xs + residual)?;
        let residual = &xs;
        let xs = self
            .mlp
            .forward(&xs.apply(&self.pre_feedforward_layernorm)?)?
            .apply(&self.post_feedforward_layernorm)?;
        residual + xs
    }
}

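/// The full Gemma 2 causal LM: scaled token embeddings, the decoder stack, a final
/// Gemma-style RmsNorm, and an lm_head tied to the embedding weights, with optional
/// final logit soft-capping.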
pub struct Model {
    embed_tokens: candle_nn::Embedding,
    layers: Vec<DecoderLayer>,
    norm: RmsNorm,
    lm_head: Arc<dyn QuantMethod>,
    hidden_size: usize,
    device: Device,
    cache: EitherCache,
    max_seq_len: usize,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    sliding_window: usize,
    final_logit_softcapping: Option<f64>,
    cfg: ModelConfigMetadata,
}

impl Model {
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        if let Some(ref quant_cfg) = &cfg.quantization_config {
            tracing::info!(
                "Using {} quantization: {}.",
                quant_cfg.name(),
                quant_cfg.get_bits_name(&vb)
            );
        }
        let mapper = normal_loading_metadata.mapper;

        let vb_m = vb.pp("model");
        let embed_tokens = embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
            &cfg.quantization_config,
        )?;
        let mut ropes = HashMap::new();
        for layer_idx in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(RotaryEmbedding::new(
                    cfg.rope_theta as f32,
                    cfg.head_dim,
                    cfg.max_position_embeddings,
                    device,
                    is_gptx,
                    vb_m.dtype(),
                )?),
            );
        }
        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        let vb_l = vb_m.pp("layers");
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        ) {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            let paged_attn = match &attention_mechanism {
                AttentionImplementation::Eager => None,
                AttentionImplementation::PagedAttention => {
                    Some(PagedAttention::new(cfg.head_dim, device, None)?)
                }
            };
            let comm = mapper.get_comm_for(layer_idx)?;
            let layer = DecoderLayer::new(
                rotary_emb.clone(),
                cfg,
                vb_l.pp(layer_idx),
                &*mapper,
                layer_idx,
                normal_loading_metadata.loading_isq,
                paged_attn,
                &comm,
            )?;
            layers.push(layer)
        }
        let norm = RmsNorm::new_gemma(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb_m.pp("norm"), false),
        )?;
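        // Gemma 2 ties the output head to the token embeddings: reuse the embedding
        // matrix as an unquantized linear lm_head instead of loading a separate weight.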
        let lm_head = mapper.cast_nm_device(
            embed_tokens.embeddings(),
            normal_loading_metadata.loading_isq,
        )?;
        Ok(Self {
            embed_tokens,
            layers,
            norm,
            lm_head: Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
                Linear::new(lm_head, None),
            ))?),
            device: normal_loading_metadata.real_device,
            hidden_size: cfg.hidden_size,
            cache: EitherCache::Normal(NormalCache::new_sliding(
                cfg.num_hidden_layers,
                cfg.max_position_embeddings,
                Some(cfg.sliding_window),
            )),
            max_seq_len: cfg.max_position_embeddings,
            sliding_window: cfg.sliding_window,
            final_logit_softcapping: cfg.final_logit_softcapping,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_attn_heads: cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size(),
                num_kv_heads: (cfg.num_key_value_heads / mapper.get_comm_for(0)?.world_size())
                    .max(1),
                sliding_window: None,
                k_head_dim: cfg.head_dim,
                v_head_dim: cfg.head_dim,
            },
            mapper,
        })
    }

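    /// Forward pass: embed and scale tokens by sqrt(hidden_size), build the global
    /// and sliding-window causal masks, run every decoder layer against its KV
    /// cache, then apply the final norm, the tied lm_head, and (if configured)
    /// final logit soft-capping before extracting the requested logits.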
    pub fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let xs = self.embed_tokens.forward(input_ids)?;
        let mut xs = (xs * (self.hidden_size as f64).sqrt())?;
        let cache = &mut self.cache.normal().0;
        let attention_mask = CausalMasker.make_causal_mask_matrix(
            input_ids,
            &*cache,
            xs.dtype(),
            self.cfg.num_attn_heads,
        )?;
        // PagedAttention prompt chunking
        let attention_mask = attention_mask.filter(|_| {
            metadata
                .as_ref()
                .map(|(_, meta)| meta.is_first_prompt_chunk)
                .unwrap_or(true)
        });
        let sliding_attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
            input_ids,
            &*cache,
            Some(self.sliding_window),
            xs.dtype(),
            self.cfg.num_attn_heads,
        )?;
        // PagedAttention prompt chunking
        let sliding_attention_mask = sliding_attention_mask.filter(|_| {
            metadata
                .as_ref()
                .map(|(_, meta)| meta.is_first_prompt_chunk)
                .unwrap_or(true)
        });
        for (i, layer) in self.layers.iter().enumerate() {
            xs = self.mapper.map(xs, i)?;
            xs = layer.forward(
                &xs,
                attention_mask
                    .as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                sliding_attention_mask
                    .as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                &mut cache[i],
                metadata
                    .as_ref()
                    .map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
                flash_params,
            )?;
        }
        let xs = xs.to_device(&self.device)?;
        let mut xs = xs.apply(&self.norm)?;
        if let Some(t) = self.lm_head.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }

        let mut xs = MatMul.qmethod_matmul(&xs, &*self.lm_head)?;

        if let Some(final_logit_softcapping) = self.final_logit_softcapping {
            xs = (xs / final_logit_softcapping)?;
            xs = xs.tanh()?;
            xs = (xs * final_logit_softcapping)?;
        }

        extract_logits(&xs, context_lens)
    }
}

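/// In-situ quantization (ISQ) support: expose the lm_head and the per-layer
/// attention/MLP projections for requantization, plus the residual tensors and
/// imatrix names that go with them.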
impl IsqModel for Model {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((&mut self.lm_head, None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            tensors.push((&mut layer.self_attn.q_proj, Some(i)));
            tensors.push((&mut layer.self_attn.k_proj, Some(i)));
            tensors.push((&mut layer.self_attn.v_proj, Some(i)));
            tensors.push((&mut layer.self_attn.o_proj, Some(i)));
            tensors.extend(
                layer
                    .mlp
                    .get_isq_layers()
                    .into_iter()
                    .map(|m| (m, Some(i)))
                    .collect::<Vec<_>>(),
            );
        }
        (tensors, &*self.mapper)
    }

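    /// Tensors that are not ISQ-quantized: the embedding matrix and every
    /// Gemma-style RmsNorm, with the Gemma weight transform undone (`undo_gemma`)
    /// so they serialize in the original checkpoint layout.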
    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        let uvb_m = uvb.pp("model");
        uvb_m.pp("embed_tokens").add(&self.embed_tokens);
        uvb_m.pp("norm").add(&self.norm.undo_gemma().unwrap());

        for (layer_idx, layer) in self.layers.iter().enumerate() {
            let uvb_l = uvb_m.pp("layers").pp(layer_idx);
            uvb_l
                .pp("input_layernorm")
                .add(&layer.input_layernorm.undo_gemma().unwrap());
            uvb_l
                .pp("post_attention_layernorm")
                .add(&layer.post_attention_layernorm.undo_gemma().unwrap());
            uvb_l
                .pp("pre_feedforward_layernorm")
                .add(&layer.pre_feedforward_layernorm.undo_gemma().unwrap());
            uvb_l
                .pp("post_feedforward_layernorm")
                .add(&layer.post_feedforward_layernorm.undo_gemma().unwrap());
        }

        uvb.to_safetensors()
    }

    fn imatrix_names(&self) -> candle_core::Result<Vec<Option<String>>> {
        // NOTE: dependent on the exact implementation in get_layers!
        let mut names = Vec::new();
        // lm_head
        names.push(None);
        for i in 0..self.layers.len() {
            names.push(Some(format!("blk.{i}.attn_q.weight")));
            names.push(Some(format!("blk.{i}.attn_k.weight")));
            names.push(Some(format!("blk.{i}.attn_v.weight")));
            names.push(Some(format!("blk.{i}.attn_output.weight")));
            names.push(Some(format!("blk.{i}.ffn_gate.weight")));
            names.push(Some(format!("blk.{i}.ffn_up.weight")));
            names.push(Some(format!("blk.{i}.ffn_down.weight")));
        }
        Ok(names)
    }
}

impl NormalModel for Model {
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.forward(
            input_ids,
            seqlen_offsets,
            context_lens,
            metadata,
            flash_params,
        )
    }
    fn xlora_forward(
        &self,
        _input_ids: &Tensor,
        _input_ids_full: &Tensor,
        _seqlen_offsets: &[usize],
        _seqlen_offsets_full: &[usize],
        _no_kv_cache: bool,
        _non_granular_state: &Option<crate::xlora_models::NonGranularState>,
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _flash_params: &FlashParams,
        _flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        unimplemented!()
    }
    fn cache(&self) -> &EitherCache {
        &self.cache
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        &mut self.cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn is_xlora(&self) -> bool {
        false
    }
    fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
}

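/// AnyMoe support: each layer's MLP can be replaced by a `MoeMlp` whose experts are
/// either fully fine-tuned copies or LoRA-delta variants of the original MLP.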
impl AnyMoeBaseModelMixin for Model {
    fn get_mlps(&self) -> Vec<&dyn MlpLayer> {
        let mut mlps = Vec::new();
        for layer in &self.layers {
            mlps.push(&*layer.mlp);
        }
        mlps
    }
    fn get_mlps_mut(&mut self) -> Vec<&mut Box<dyn MlpLayer>> {
        let mut mlps = Vec::new();
        for layer in &mut self.layers {
            mlps.push(&mut layer.mlp);
        }
        mlps
    }
    fn create_anymoe_layers(
        &mut self,
        additional_vbs: Vec<ShardedVarBuilder>,
        config: AnyMoeConfig,
        (prefix, mlp): (String, String),
        mut layers: Vec<usize>,
        expert_type: AnyMoeExpertType,
        gate_vb: Option<ShardedVarBuilder>,
    ) -> Result<()> {
        let mut experts: Vec<Vec<Box<dyn MlpLayer>>> = Vec::new();
        if layers.is_empty() {
            layers = (0..self.layers.len()).collect::<Vec<_>>();
        }
        for _ in 0..layers.len() {
            experts.push(Vec::new());
        }
        for vb in additional_vbs {
            let vb = vb.pp(&prefix);
            for (layer, row) in experts.iter_mut().enumerate() {
                if !layers.contains(&layer) {
                    continue;
                }

                let intermediate_size = self.layers[layer].mlp.get_params()[1];
                let hidden_size = self.layers[layer].mlp.get_params()[0];
                match expert_type {
                    AnyMoeExpertType::FineTuned => {
                        let (dtype, device) = self.layers[layer].mlp.dtype_device();
                        row.push(Box::new(Mlp::replicate(
                            self.layers[layer].mlp.get_params(),
                            vb.pp(layer).pp(&mlp).set_dtype(dtype).set_device(device),
                            self.layers[layer].mlp.hidden_act(),
                            &self.mapper.get_comm_for(layer)?,
                        )?));
                    }
                    AnyMoeExpertType::LoraAdapter {
                        rank,
                        alpha,
                        ref target_modules,
                    } => {
                        let vb_mlp = vb.pp(layer).pp(&mlp);

                        let gate_proj_delta = if target_modules.contains(&"gate_proj".to_string()) {
                            Some(get_delta_from_lora_ab!(
                                vb_mlp,
                                rank,
                                alpha,
                                (hidden_size, intermediate_size),
                                "gate_proj"
                            ))
                        } else {
                            None
                        };
                        let up_proj_delta = if target_modules.contains(&"up_proj".to_string()) {
                            Some(get_delta_from_lora_ab!(
                                vb_mlp,
                                rank,
                                alpha,
                                (hidden_size, intermediate_size),
                                "up_proj"
                            ))
                        } else {
                            None
                        };
                        let down_proj_delta = if target_modules.contains(&"down_proj".to_string()) {
                            Some(get_delta_from_lora_ab!(
                                vb_mlp,
                                rank,
                                alpha,
                                (intermediate_size, hidden_size),
                                "down_proj"
                            ))
                        } else {
                            None
                        };

                        row.push(self.layers[layer].mlp.new_added_delta(vec![
                            gate_proj_delta,
                            up_proj_delta,
                            down_proj_delta,
                        ])?);
                    }
                }
            }
        }
        for (layer, expert) in layers.into_iter().zip(experts) {
            let mut experts_all = vec![self.layers[layer].mlp.clone()];
            experts_all.extend(expert);
            let (dtype, device) = self.layers[layer].mlp.dtype_device();
            self.layers[layer].mlp = Box::new(MoeMlp::new(
                experts_all,
                config.clone(),
                dtype,
                &device,
                layer,
                gate_vb.as_ref(),
            )?);
        }
        Ok(())
    }
    fn amoe_supported(&self) -> bool {
        true
    }
}