mistralrs_core/models/
gemma2.rs

1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use std::{collections::HashMap, sync::Arc};
4
5use candle_core::{Device, Module, Result, Tensor};
6use candle_nn::Linear;
7use mistralrs_quant::{
8    ColumnParallelLayer, QuantMethod, QuantMethodConfig, QuantizedConfig, RowParallelLayer,
9    ShardedVarBuilder, UnquantLinear,
10};
11
12use crate::{
13    amoe::{AnyMoeBaseModelMixin, AnyMoeConfig, AnyMoeExpertType, MlpLayer, MoeMlp},
14    attention::SdpaParams,
15    device_map::DeviceMapper,
16    get_delta_from_lora_ab,
17    layers::{embedding, Activation, CausalMasker, MatMul, Mlp, RmsNorm, RotaryEmbedding, Sdpa},
18    paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
19    pipeline::{
20        extract_logits,
21        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
22        EitherCache, IsqModel, KvCache, NormalCache, NormalLoadingMetadata, NormalModel,
23    },
24    utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
25};
26
/// Gemma 2 hyperparameters consumed by [`Model::new`].
#[derive(Debug, Clone, Default, serde::Serialize)]
pub struct Config {
    /// Whether the q/k/v/o attention projections are created with a bias term.
    pub attention_bias: bool,
    /// Per-head dimension used to size the q/k/v/o projections and RoPE. This is a separate
    /// field and is not derived from `hidden_size / num_attention_heads`.
    pub head_dim: usize,
    // The code gemma configs include both hidden_act and hidden_activation.
    // `Config::hidden_act()` resolves which one is used (hidden_act wins if both are set).
    pub hidden_act: Option<Activation>,
    pub hidden_activation: Option<Activation>,
    pub hidden_size: usize,
    pub intermediate_size: usize,
    pub num_attention_heads: usize,
    pub num_hidden_layers: usize,
    pub num_key_value_heads: usize,
    pub rms_norm_eps: f64,
    pub rope_theta: f64,
    pub vocab_size: usize,
    /// Window size used by the sliding-window attention layers (the even-indexed layers).
    pub sliding_window: usize,
    /// Optional soft-cap for attention logits, forwarded to SDPA as `softcap`.
    pub attn_logit_softcapping: Option<f64>,
    /// Optional soft-cap for the final logits: `cap * tanh(logits / cap)`.
    pub final_logit_softcapping: Option<f64>,
    /// The attention softmax scale is `1 / sqrt(query_pre_attn_scalar)`.
    pub query_pre_attn_scalar: usize,
    pub max_position_embeddings: usize,
    pub quantization_config: Option<QuantizedConfig>,
    pub use_flash_attn: bool,
    // Not read in this file: the LM head is always built from the embedding matrix.
    #[allow(dead_code)]
    pub tie_word_embeddings: bool,
}
52
53impl Config {
54    pub fn hidden_act(&self) -> Result<Activation> {
55        match (self.hidden_act, self.hidden_activation) {
56            (None, Some(act)) | (Some(act), None) => Ok(act),
57            (Some(act), Some(_)) => {
58                // If both are set just use hidden_act
59                Ok(act)
60            }
61            (None, None) => candle_core::bail!("none of hidden_act and hidden_activation are set"),
62        }
63    }
64}
/// Self-attention for one Gemma 2 decoder layer: q/k/v/o projections, RoPE, grouped KV heads
/// (via `sdpa_params.n_kv_groups`), optional logit soft-capping and, on alternating layers,
/// sliding-window masking.
struct Attention {
    q_proj: Arc<dyn QuantMethod>,
    k_proj: Arc<dyn QuantMethod>,
    v_proj: Arc<dyn QuantMethod>,
    o_proj: Arc<dyn QuantMethod>,
    // Head counts stored here are per tensor-parallel rank (divided by the world size).
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    rotary_emb: Arc<RotaryEmbedding>,
    // True on sliding-window layers; `forward` uses it to pick the sliding mask.
    use_sliding_window: bool,
    // `Some` when the model was built with `AttentionImplementation::PagedAttention`.
    paged_attn: Option<PagedAttention>,
    sdpa_params: SdpaParams,
}
78
79impl Attention {
80    fn new(
81        rotary_emb: Arc<RotaryEmbedding>,
82        cfg: &Config,
83        layer_idx: usize,
84        vb: ShardedVarBuilder,
85        paged_attn: Option<PagedAttention>,
86        comm: &Arc<mistralrs_quant::Comm>,
87    ) -> Result<Self> {
88        let hidden_sz = cfg.hidden_size;
89        let num_heads = cfg.num_attention_heads;
90        let num_kv_heads = cfg.num_key_value_heads;
91        let head_dim = cfg.head_dim;
92        let bias = cfg.attention_bias;
93        let q_proj = ColumnParallelLayer::new(
94            hidden_sz,
95            num_heads * head_dim,
96            &cfg.quantization_config,
97            bias,
98            comm,
99            vb.pp("q_proj"),
100        )?;
101        let kv_shard = mistralrs_quant::compute_kv_shard(
102            cfg.num_key_value_heads,
103            cfg.hidden_size / cfg.num_attention_heads,
104            comm,
105        );
106        let k_proj = ColumnParallelLayer::new_with_shard(
107            hidden_sz,
108            num_kv_heads * head_dim,
109            &cfg.quantization_config,
110            bias,
111            comm,
112            kv_shard,
113            vb.pp("k_proj"),
114        )?;
115        let v_proj = ColumnParallelLayer::new_with_shard(
116            hidden_sz,
117            num_kv_heads * head_dim,
118            &cfg.quantization_config,
119            bias,
120            comm,
121            kv_shard,
122            vb.pp("v_proj"),
123        )?;
124        let o_proj = RowParallelLayer::new(
125            num_heads * head_dim,
126            hidden_sz,
127            &cfg.quantization_config,
128            bias,
129            comm,
130            vb.pp("o_proj"),
131        )?;
132        let sliding_window = if layer_idx % 2 == 0 {
133            // ^ Order is SWA, global, SWA
134            Some(cfg.sliding_window)
135        } else {
136            None
137        };
138        Ok(Self {
139            q_proj,
140            k_proj,
141            v_proj,
142            o_proj,
143            num_heads: num_heads / comm.world_size(),
144            num_kv_heads: (num_kv_heads / comm.world_size()).max(1),
145            head_dim,
146            rotary_emb,
147            use_sliding_window: layer_idx % 2 == 0, // Order is SWA, global, SWA
148            paged_attn,
149            sdpa_params: SdpaParams {
150                n_kv_groups: mistralrs_quant::compute_n_kv_groups(
151                    cfg.num_key_value_heads,
152                    cfg.num_attention_heads,
153                    comm,
154                ),
155                use_flash_attn: cfg.use_flash_attn,
156                softcap: cfg.attn_logit_softcapping.map(|x| x as f32),
157                softmax_scale: 1.0 / (cfg.query_pre_attn_scalar as f32).sqrt(),
158                sliding_window,
159            },
160        })
161    }
162
163    #[allow(clippy::too_many_arguments)]
164    fn forward(
165        &self,
166        xs: &Tensor,
167        attention_mask: Option<&Tensor>,
168        sliding_attention_mask: Option<&Tensor>,
169        seqlen_offsets: &[usize],
170        kv_cache: &mut KvCache,
171        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
172        flash_params: &FlashParams,
173    ) -> Result<Tensor> {
174        let (b_sz, q_len, _) = xs.dims3()?;
175
176        let original_dtype = xs.dtype();
177        let mut xs = xs.clone();
178        if let Some(t) = self.q_proj.quantized_act_type() {
179            xs = xs.to_dtype(t)?;
180        }
181        let mut q = MatMul.qmethod_matmul(&xs, &*self.q_proj)?;
182        let mut k = MatMul.qmethod_matmul(&xs, &*self.k_proj)?;
183        let mut v = MatMul.qmethod_matmul(&xs, &*self.v_proj)?;
184        if self.q_proj.quantized_act_type().is_some() {
185            q = q.to_dtype(original_dtype)?;
186            k = k.to_dtype(original_dtype)?;
187            v = v.to_dtype(original_dtype)?;
188        }
189
190        let (q, k, v) = if q_len != 1 {
191            let q = q
192                .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
193                .transpose(1, 2)?;
194            let k = k
195                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
196                .transpose(1, 2)?;
197            let v = v
198                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
199                .transpose(1, 2)?;
200            (q, k, v)
201        } else {
202            let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
203            let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
204            let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
205            (q, k, v)
206        };
207
208        let (q, k) = self.rotary_emb.forward(&q, &k, seqlen_offsets)?;
209
210        let mask = if self.use_sliding_window {
211            sliding_attention_mask
212        } else {
213            attention_mask
214        };
215
216        let mut attn_output = match &self.paged_attn {
217            Some(paged_attn) => match metadata {
218                Some(((key_cache, value_cache), input_metadata)) => paged_attn.forward(
219                    &q,
220                    &k,
221                    &v,
222                    attention_mask,
223                    Some(key_cache),
224                    Some(value_cache),
225                    input_metadata,
226                    &self.sdpa_params,
227                    Some(flash_params),
228                )?,
229                None => {
230                    // If we don't have metadata, we are most likely generating an imatrix so we don't want to populate that.
231                    // Generating the dummy metadata with the assumption that we are not generating text (only processing prompts).
232                    let input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
233                    // Sanity check.
234                    assert!(attention_mask.is_some());
235                    paged_attn.forward(
236                        &q,
237                        &k,
238                        &v,
239                        attention_mask,
240                        None,
241                        None,
242                        &input_metadata,
243                        &self.sdpa_params,
244                        Some(flash_params),
245                    )?
246                }
247            },
248            None => {
249                // self.sliding_window is None if !self.use_sliding_window
250                let (k, v) = kv_cache.append(&k, &v)?;
251
252                Sdpa.run_attention(&q, &k, &v, mask, Some(flash_params), &self.sdpa_params)?
253            }
254        };
255
256        if let Some(t) = self.q_proj.quantized_act_type() {
257            attn_output = attn_output.to_dtype(t)?;
258        }
259        attn_output = if attention_mask.is_some() {
260            attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?
261        } else {
262            attn_output.reshape((b_sz, q_len, ()))?
263        };
264        let mut res = MatMul.qmethod_matmul(&attn_output, &*self.o_proj)?;
265        if self.q_proj.quantized_act_type().is_some() {
266            res = res.to_dtype(original_dtype)?;
267        }
268        Ok(res)
269    }
270}
271
/// One Gemma 2 decoder block. Both the attention and the MLP sub-blocks are wrapped in a
/// pre-norm and a post-norm before their residual add (see `DecoderLayer::forward`).
struct DecoderLayer {
    self_attn: Attention,
    // Boxed trait object so AnyMoE can replace it with a `MoeMlp`.
    mlp: Box<dyn MlpLayer>,
    // Applied to the layer input before attention.
    input_layernorm: RmsNorm,
    // Applied to the attention output before the first residual add.
    post_attention_layernorm: RmsNorm,
    // Applied to the MLP input.
    pre_feedforward_layernorm: RmsNorm,
    // Applied to the MLP output before the second residual add.
    post_feedforward_layernorm: RmsNorm,
}
280
281impl DecoderLayer {
282    #[allow(clippy::too_many_arguments)]
283    fn new(
284        rotary_emb: Arc<RotaryEmbedding>,
285        cfg: &Config,
286        vb: ShardedVarBuilder,
287        mapper: &dyn DeviceMapper,
288        layer_idx: usize,
289        loading_isq: bool,
290        paged_attn: Option<PagedAttention>,
291        comm: &Arc<mistralrs_quant::Comm>,
292    ) -> Result<Self> {
293        let self_attn = Attention::new(
294            rotary_emb,
295            cfg,
296            layer_idx,
297            mapper.set_device(layer_idx, vb.pp("self_attn"), loading_isq),
298            paged_attn,
299            comm,
300        )?;
301        let mlp = Mlp::new(
302            mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq),
303            cfg.hidden_size,
304            cfg.intermediate_size,
305            &cfg.quantization_config,
306            cfg.hidden_act()?,
307            comm,
308        )?;
309        let input_layernorm = RmsNorm::new_gemma(
310            cfg.hidden_size,
311            cfg.rms_norm_eps,
312            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
313        )?;
314        let post_attention_layernorm = RmsNorm::new_gemma(
315            cfg.hidden_size,
316            cfg.rms_norm_eps,
317            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
318        )?;
319        let pre_feedforward_layernorm = RmsNorm::new_gemma(
320            cfg.hidden_size,
321            cfg.rms_norm_eps,
322            mapper.set_device(layer_idx, vb.pp("pre_feedforward_layernorm"), false),
323        )?;
324        let post_feedforward_layernorm = RmsNorm::new_gemma(
325            cfg.hidden_size,
326            cfg.rms_norm_eps,
327            mapper.set_device(layer_idx, vb.pp("post_feedforward_layernorm"), false),
328        )?;
329        Ok(Self {
330            self_attn,
331            mlp: Box::new(mlp),
332            input_layernorm,
333            post_attention_layernorm,
334            pre_feedforward_layernorm,
335            post_feedforward_layernorm,
336        })
337    }
338
339    #[allow(clippy::too_many_arguments)]
340    fn forward(
341        &self,
342        xs: &Tensor,
343        attention_mask: Option<&Tensor>,
344        sliding_attention_mask: Option<&Tensor>,
345        seqlen_offsets: &[usize],
346        kv_cache: &mut KvCache,
347        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
348        flash_params: &FlashParams,
349    ) -> Result<Tensor> {
350        let residual = xs;
351        let xs = self.input_layernorm.forward(xs)?;
352        let xs = self
353            .self_attn
354            .forward(
355                &xs,
356                attention_mask,
357                sliding_attention_mask,
358                seqlen_offsets,
359                kv_cache,
360                metadata,
361                flash_params,
362            )?
363            .apply(&self.post_attention_layernorm)?;
364        let xs = (xs + residual)?;
365        let residual = &xs;
366        let xs = self
367            .mlp
368            .forward(&xs.apply(&self.pre_feedforward_layernorm)?)?
369            .apply(&self.post_feedforward_layernorm)?;
370        residual + xs
371    }
372}
373
/// Gemma 2 causal language model: token embedding, decoder stack, final norm and an LM head
/// built from the embedding matrix.
pub struct Model {
    embed_tokens: candle_nn::Embedding,
    layers: Vec<DecoderLayer>,
    norm: RmsNorm,
    // Built from `embed_tokens.embeddings()` in `Model::new` (tied embeddings).
    lm_head: Arc<dyn QuantMethod>,
    // Embeddings are scaled by sqrt(hidden_size) in `forward`.
    hidden_size: usize,
    device: Device,
    cache: EitherCache,
    max_seq_len: usize,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    // Window used to build the sliding-window causal mask.
    sliding_window: usize,
    final_logit_softcapping: Option<f64>,
    cfg: ModelConfigMetadata,
}
388
389impl Model {
390    pub fn new(
391        cfg: &Config,
392        vb: ShardedVarBuilder,
393        is_gptx: bool,
394        normal_loading_metadata: NormalLoadingMetadata,
395        attention_mechanism: AttentionImplementation,
396    ) -> Result<Self> {
397        if let Some(ref quant_cfg) = &cfg.quantization_config {
398            tracing::info!(
399                "Using {} quantization: {}.",
400                quant_cfg.quant_method.to_string(),
401                quant_cfg.get_bits_name(&vb)
402            );
403        }
404        let mapper = normal_loading_metadata.mapper;
405
406        let vb_m = vb.pp("model");
407        let embed_tokens = embedding(
408            cfg.vocab_size,
409            cfg.hidden_size,
410            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
411        )?;
412        let mut ropes = HashMap::new();
413        for layer_idx in 0..cfg.num_hidden_layers {
414            let device = mapper
415                .device_for(layer_idx, false)
416                .unwrap_or(&normal_loading_metadata.real_device);
417            ropes.insert(
418                device.location(),
419                Arc::new(RotaryEmbedding::new(
420                    cfg.rope_theta as f32,
421                    cfg.head_dim,
422                    cfg.max_position_embeddings,
423                    device,
424                    is_gptx,
425                    vb_m.dtype(),
426                )?),
427            );
428        }
429        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
430        let vb_l = vb_m.pp("layers");
431        for layer_idx in NiceProgressBar::<_, 'b'>(
432            0..cfg.num_hidden_layers,
433            "Loading repeating layers",
434            &normal_loading_metadata.multi_progress,
435        ) {
436            let device = mapper
437                .device_for(layer_idx, false)
438                .unwrap_or(&normal_loading_metadata.real_device);
439            let rotary_emb = ropes
440                .get(&device.location())
441                .expect("No RoPE for device location!")
442                .clone();
443            let paged_attn = match &attention_mechanism {
444                AttentionImplementation::Eager => None,
445                AttentionImplementation::PagedAttention => {
446                    Some(PagedAttention::new(cfg.head_dim, device, None)?)
447                }
448            };
449            let comm = mapper.get_comm_for(layer_idx)?;
450            let layer = DecoderLayer::new(
451                rotary_emb.clone(),
452                cfg,
453                vb_l.pp(layer_idx),
454                &*mapper,
455                layer_idx,
456                normal_loading_metadata.loading_isq,
457                paged_attn,
458                &comm,
459            )?;
460            layers.push(layer)
461        }
462        let norm = RmsNorm::new_gemma(
463            cfg.hidden_size,
464            cfg.rms_norm_eps,
465            mapper.set_nm_device(vb_m.pp("norm"), false),
466        )?;
467        let lm_head = mapper.cast_nm_device(
468            embed_tokens.embeddings(),
469            normal_loading_metadata.loading_isq,
470        )?;
471        Ok(Self {
472            embed_tokens,
473            layers,
474            norm,
475            lm_head: Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(
476                Linear::new(lm_head, None),
477            ))?),
478            device: normal_loading_metadata.real_device,
479            hidden_size: cfg.hidden_size,
480            cache: EitherCache::Normal(NormalCache::new_sliding(
481                cfg.num_hidden_layers,
482                cfg.max_position_embeddings,
483                Some(cfg.sliding_window),
484            )),
485            max_seq_len: cfg.max_position_embeddings,
486            sliding_window: cfg.sliding_window,
487            final_logit_softcapping: cfg.final_logit_softcapping,
488            cfg: ModelConfigMetadata {
489                max_seq_len: cfg.max_position_embeddings,
490                num_layers: cfg.num_hidden_layers,
491                hidden_size: cfg.hidden_size,
492                num_attn_heads: cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size(),
493                num_kv_heads: (cfg.num_key_value_heads / mapper.get_comm_for(0)?.world_size())
494                    .max(1),
495                sliding_window: None,
496                k_head_dim: cfg.head_dim,
497                v_head_dim: cfg.head_dim,
498            },
499            mapper,
500        })
501    }
502
503    pub fn forward(
504        &self,
505        input_ids: &Tensor,
506        seqlen_offsets: &[usize],
507        context_lens: Vec<(usize, usize)>,
508        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
509        flash_params: &FlashParams,
510    ) -> Result<Tensor> {
511        let xs = self.embed_tokens.forward(input_ids)?;
512        let mut xs = (xs * (self.hidden_size as f64).sqrt())?;
513        let cache = &mut self.cache.normal().0;
514        let attention_mask = CausalMasker.make_causal_mask_matrix(
515            input_ids,
516            &*cache,
517            xs.dtype(),
518            self.cfg.num_attn_heads,
519        )?;
520        // PagedAttention prompt chunking
521        let attention_mask = attention_mask.filter(|_| {
522            metadata
523                .as_ref()
524                .map(|(_, meta)| meta.is_first_prompt_chunk)
525                .unwrap_or(true)
526        });
527        let sliding_attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
528            input_ids,
529            &*cache,
530            Some(self.sliding_window),
531            xs.dtype(),
532            self.cfg.num_attn_heads,
533        )?;
534        // PagedAttention prompt chunking
535        let sliding_attention_mask = sliding_attention_mask.filter(|_| {
536            metadata
537                .as_ref()
538                .map(|(_, meta)| meta.is_first_prompt_chunk)
539                .unwrap_or(true)
540        });
541        for (i, layer) in self.layers.iter().enumerate() {
542            xs = self.mapper.map(xs, i)?;
543            xs = layer.forward(
544                &xs,
545                attention_mask
546                    .as_ref()
547                    .map(|m| m.to_device(xs.device()).unwrap())
548                    .as_ref(),
549                sliding_attention_mask
550                    .as_ref()
551                    .map(|m| m.to_device(xs.device()).unwrap())
552                    .as_ref(),
553                seqlen_offsets,
554                &mut cache[i],
555                metadata
556                    .as_ref()
557                    .map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
558                flash_params,
559            )?;
560        }
561        let xs = xs.to_device(&self.device)?;
562        let mut xs = xs.apply(&self.norm)?;
563        if let Some(t) = self.lm_head.quantized_act_type() {
564            xs = xs.to_dtype(t)?;
565        }
566
567        let mut xs = MatMul.qmethod_matmul(&xs, &*self.lm_head)?;
568
569        if let Some(final_logit_softcapping) = self.final_logit_softcapping {
570            xs = (xs / final_logit_softcapping)?;
571            xs = xs.tanh()?;
572            xs = (xs * final_logit_softcapping)?;
573        }
574
575        extract_logits(&xs, context_lens)
576    }
577}
578
579impl IsqModel for Model {
580    fn get_layers(
581        &mut self,
582    ) -> (
583        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
584        &dyn DeviceMapper,
585    ) {
586        let mut tensors = Vec::new();
587        tensors.push((&mut self.lm_head, None));
588        for (i, layer) in self.layers.iter_mut().enumerate() {
589            tensors.push((&mut layer.self_attn.q_proj, Some(i)));
590            tensors.push((&mut layer.self_attn.k_proj, Some(i)));
591            tensors.push((&mut layer.self_attn.v_proj, Some(i)));
592            tensors.push((&mut layer.self_attn.o_proj, Some(i)));
593            tensors.extend(
594                layer
595                    .mlp
596                    .get_isq_layers()
597                    .into_iter()
598                    .map(|m| (m, Some(i)))
599                    .collect::<Vec<_>>(),
600            );
601        }
602        (tensors, &*self.mapper)
603    }
604
605    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
606        let uvb = UnVarBuilder::new();
607
608        let uvb_m = uvb.pp("model");
609        uvb_m.pp("embed_tokens").add(&self.embed_tokens);
610        uvb_m.pp("norm").add(&self.norm.undo_gemma().unwrap());
611
612        for (layer_idx, layer) in self.layers.iter().enumerate() {
613            let uvb_l = uvb_m.pp("layers").pp(layer_idx);
614            uvb_l
615                .pp("input_layernorm")
616                .add(&layer.input_layernorm.undo_gemma().unwrap());
617            uvb_l
618                .pp("post_attention_layernorm")
619                .add(&layer.post_attention_layernorm.undo_gemma().unwrap());
620            uvb_l
621                .pp("pre_feedforward_layernorm")
622                .add(&layer.pre_feedforward_layernorm.undo_gemma().unwrap());
623            uvb_l
624                .pp("post_feedforward_layernorm")
625                .add(&layer.post_feedforward_layernorm.undo_gemma().unwrap());
626        }
627
628        uvb.to_safetensors()
629    }
630
631    fn imatrix_names(&self) -> candle_core::Result<Vec<Option<String>>> {
632        // NOTE: dependant on the exact implementation in get_layers!
633        let mut names = Vec::new();
634        // lm_head
635        names.push(None);
636        for i in 0..self.layers.len() {
637            names.push(Some(format!("blk.{i}.attn_q.weight")));
638            names.push(Some(format!("blk.{i}.attn_k.weight")));
639            names.push(Some(format!("blk.{i}.attn_v.weight")));
640            names.push(Some(format!("blk.{i}.attn_output.weight")));
641            names.push(Some(format!("blk.{i}.ffn_gate.weight")));
642            names.push(Some(format!("blk.{i}.ffn_up.weight")));
643            names.push(Some(format!("blk.{i}.ffn_down.weight")));
644        }
645        Ok(names)
646    }
647}
648
649impl NormalModel for Model {
650    fn forward(
651        &self,
652        input_ids: &Tensor,
653        seqlen_offsets: &[usize],
654        context_lens: Vec<(usize, usize)>,
655        _position_ids: Vec<usize>,
656        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
657        flash_params: &FlashParams,
658    ) -> Result<Tensor> {
659        self.forward(
660            input_ids,
661            seqlen_offsets,
662            context_lens,
663            metadata,
664            flash_params,
665        )
666    }
667    fn xlora_forward(
668        &self,
669        _input_ids: &Tensor,
670        _input_ids_full: &Tensor,
671        _seqlen_offsets: &[usize],
672        _seqlen_offsets_full: &[usize],
673        _no_kv_cache: bool,
674        _non_granular_state: &Option<crate::xlora_models::NonGranularState>,
675        _context_lens: Vec<(usize, usize)>,
676        _position_ids: Vec<usize>,
677        _flash_params: &FlashParams,
678        _flash_params_full: &FlashParams,
679    ) -> Result<Tensor> {
680        unimplemented!()
681    }
682    fn cache(&self) -> &EitherCache {
683        &self.cache
684    }
685    fn cache_mut(&mut self) -> &mut EitherCache {
686        &mut self.cache
687    }
688    fn device(&self) -> &Device {
689        &self.device
690    }
691    fn is_xlora(&self) -> bool {
692        false
693    }
694    fn max_seq_len(&self) -> usize {
695        self.max_seq_len
696    }
697    fn config(&self) -> &ModelConfigMetadata {
698        &self.cfg
699    }
700}
701
702impl AnyMoeBaseModelMixin for Model {
703    fn get_mlps(&self) -> Vec<&dyn MlpLayer> {
704        let mut mlps = Vec::new();
705        for layer in &self.layers {
706            mlps.push(&*layer.mlp);
707        }
708        mlps
709    }
710    fn get_mlps_mut(&mut self) -> Vec<&mut Box<dyn MlpLayer>> {
711        let mut mlps = Vec::new();
712        for layer in &mut self.layers {
713            mlps.push(&mut layer.mlp);
714        }
715        mlps
716    }
717    fn create_anymoe_layers(
718        &mut self,
719        additional_vbs: Vec<ShardedVarBuilder>,
720        config: AnyMoeConfig,
721        (prefix, mlp): (String, String),
722        mut layers: Vec<usize>,
723        expert_type: AnyMoeExpertType,
724        gate_vb: Option<ShardedVarBuilder>,
725    ) -> Result<()> {
726        let mut experts: Vec<Vec<Box<dyn MlpLayer>>> = Vec::new();
727        if layers.is_empty() {
728            layers = (0..self.layers.len()).collect::<Vec<_>>();
729        }
730        for _ in 0..layers.len() {
731            experts.push(Vec::new());
732        }
733        for vb in additional_vbs {
734            let vb = vb.pp(&prefix);
735            for (layer, row) in experts.iter_mut().enumerate() {
736                if !layers.contains(&layer) {
737                    continue;
738                }
739
740                let intermediate_size = self.layers[layer].mlp.get_params()[1];
741                let hidden_size = self.layers[layer].mlp.get_params()[0];
742                match expert_type {
743                    AnyMoeExpertType::FineTuned => {
744                        let (dtype, device) = self.layers[layer].mlp.dtype_device();
745                        row.push(Box::new(Mlp::replicate(
746                            self.layers[layer].mlp.get_params(),
747                            vb.pp(layer).pp(&mlp).set_dtype(dtype).set_device(device),
748                            self.layers[layer].mlp.hidden_act(),
749                            &self.mapper.get_comm_for(layer)?,
750                        )?));
751                    }
752                    AnyMoeExpertType::LoraAdapter {
753                        rank,
754                        alpha,
755                        ref target_modules,
756                    } => {
757                        let vb_mlp = vb.pp(layer).pp(&mlp);
758
759                        let gate_proj_delta = if target_modules.contains(&"gate_proj".to_string()) {
760                            Some(get_delta_from_lora_ab!(
761                                vb_mlp,
762                                rank,
763                                alpha,
764                                (hidden_size, intermediate_size),
765                                "gate_proj"
766                            ))
767                        } else {
768                            None
769                        };
770                        let up_proj_delta = if target_modules.contains(&"up_proj".to_string()) {
771                            Some(get_delta_from_lora_ab!(
772                                vb_mlp,
773                                rank,
774                                alpha,
775                                (hidden_size, intermediate_size),
776                                "up_proj"
777                            ))
778                        } else {
779                            None
780                        };
781                        let down_proj_delta = if target_modules.contains(&"down_proj".to_string()) {
782                            Some(get_delta_from_lora_ab!(
783                                vb_mlp,
784                                rank,
785                                alpha,
786                                (intermediate_size, hidden_size),
787                                "down_proj"
788                            ))
789                        } else {
790                            None
791                        };
792
793                        row.push(self.layers[layer].mlp.new_added_delta(vec![
794                            gate_proj_delta,
795                            up_proj_delta,
796                            down_proj_delta,
797                        ])?);
798                    }
799                }
800            }
801        }
802        for (layer, expert) in layers.into_iter().zip(experts) {
803            let mut experts_all = vec![self.layers[layer].mlp.clone()];
804            experts_all.extend(expert);
805            let (dtype, device) = self.layers[layer].mlp.dtype_device();
806            self.layers[layer].mlp = Box::new(MoeMlp::new(
807                experts_all,
808                config.clone(),
809                dtype,
810                &device,
811                layer,
812                gate_vb.as_ref(),
813            )?);
814        }
815        Ok(())
816    }
817    fn amoe_supported(&self) -> bool {
818        true
819    }
820}