mistralrs_core/vision_models/llava/llava_llm/mistral.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::sync::Arc;

/// Mistral LLM, https://github.com/mistralai/mistral-src
use candle_core::{DType, Device, Module, Result, Tensor};
use mistralrs_quant::{
    ColumnParallelLayer, QuantMethod, QuantMethodConfig, RowParallelLayer, ShardedVarBuilder,
    UnquantLinear,
};

use crate::{
    amoe::{AnyMoeBaseModelMixin, AnyMoeTrainableLayer, MlpLayer, MoeMlp},
    attention::SdpaParams,
    device_map::DeviceMapper,
    get_delta_from_lora_ab,
    layers::{self, linear_no_bias, Activation, CausalMasker, MatMul, RmsNorm, Sdpa},
    layers_masker::PastKvLenCache,
    paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
    pipeline::{
        extract_logits,
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        Cache, EitherCache, IsqModel, NormalLoadingMetadata, NormalModel,
    },
    utils::progress::NiceProgressBar,
    AnyMoeConfig, AnyMoeExpertType,
};

use super::{LLaVALLM, OrdinaryRoPE};
use crate::models::mistral::Config;

/// Gated feed-forward block (SwiGLU-style): `down_proj(act_fn(gate_proj(x)) * up_proj(x))`.
#[derive(Clone)]
#[allow(clippy::upper_case_acronyms)]
struct MLP {
    gate_proj: Arc<dyn QuantMethod>,
    up_proj: Arc<dyn QuantMethod>,
    down_proj: Arc<dyn QuantMethod>,
    act_fn: Activation,
    params: Vec<usize>,
}

impl MLP {
    fn new(cfg: &Config, vb: ShardedVarBuilder, comm: &Arc<mistralrs_quant::Comm>) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let intermediate_sz = cfg.intermediate_size;
        let gate_proj = ColumnParallelLayer::new(
            hidden_sz,
            intermediate_sz,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("gate_proj"),
        )?;
        let up_proj = ColumnParallelLayer::new(
            hidden_sz,
            intermediate_sz,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("up_proj"),
        )?;
        let down_proj = RowParallelLayer::new(
            intermediate_sz,
            hidden_sz,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("down_proj"),
        )?;
        Ok(Self {
            gate_proj,
            up_proj,
            down_proj,
            act_fn: cfg.hidden_act,
            params: vec![hidden_sz, intermediate_sz],
        })
    }
}

impl AnyMoeTrainableLayer for MLP {}

impl MlpLayer for MLP {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.gate_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let lhs = MatMul
            .qmethod_matmul(&xs, &*self.gate_proj)?
            .apply(&self.act_fn)?;
        let rhs = MatMul.qmethod_matmul(&xs, &*self.up_proj)?;
        let mut res = MatMul.qmethod_matmul(&(lhs * rhs)?, &*self.down_proj)?;
        if self.gate_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
    fn get_isq_layers(&mut self) -> Vec<&mut Arc<dyn QuantMethod>> {
        vec![&mut self.gate_proj, &mut self.up_proj, &mut self.down_proj]
    }
    fn clone(&self) -> Box<dyn MlpLayer> {
        Box::new(Clone::clone(self))
    }
    fn get_params(&self) -> &[usize] {
        &self.params
    }
    fn hidden_act(&self) -> Activation {
        self.act_fn
    }
    // Deltas are ordered: gate, up, down.
    fn new_added_delta(&self, deltas: Vec<Option<Tensor>>) -> Result<Box<dyn MlpLayer>> {
        let gate_proj = if let Some(ref delta) = deltas[0] {
            self.gate_proj.add_delta_w(delta)?
        } else {
            self.gate_proj.clone()
        };
        let up_proj = if let Some(ref delta) = deltas[1] {
            self.up_proj.add_delta_w(delta)?
        } else {
            self.up_proj.clone()
        };
        let down_proj = if let Some(ref delta) = deltas[2] {
            self.down_proj.add_delta_w(delta)?
        } else {
            self.down_proj.clone()
        };

        Ok(Box::new(Self {
            gate_proj,
            up_proj,
            down_proj,
            act_fn: self.act_fn,
            params: self.params.clone(),
        }))
    }

    fn dtype_device(&self) -> (DType, Device) {
        self.gate_proj.dtype_and_device()
    }
}

/// Causal self-attention with grouped-query KV heads, RoPE applied via `OrdinaryRoPE`,
/// an optional sliding-window KV cache, and an optional PagedAttention backend.
struct Attention {
    q_proj: Arc<dyn QuantMethod>,
    k_proj: Arc<dyn QuantMethod>,
    v_proj: Arc<dyn QuantMethod>,
    o_proj: Arc<dyn QuantMethod>,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    sliding_window: Option<usize>,
    paged_attn: Option<PagedAttention>,
    sdpa_params: SdpaParams,
}

impl Attention {
    fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        paged_attn: Option<PagedAttention>,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        let head_dim = cfg.head_dim();
        let q_proj = ColumnParallelLayer::new(
            hidden_sz,
            num_heads * head_dim,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("q_proj"),
        )?;
        let kv_shard = mistralrs_quant::compute_kv_shard(
            cfg.num_key_value_heads,
            cfg.hidden_size / cfg.num_attention_heads,
            comm,
        );
        let k_proj = ColumnParallelLayer::new_with_shard(
            hidden_sz,
            num_kv_heads * head_dim,
            &cfg.quantization_config,
            false,
            comm,
            kv_shard,
            vb.pp("k_proj"),
        )?;
        let v_proj = ColumnParallelLayer::new_with_shard(
            hidden_sz,
            num_kv_heads * head_dim,
            &cfg.quantization_config,
            false,
            comm,
            kv_shard,
            vb.pp("v_proj"),
        )?;
        let o_proj = RowParallelLayer::new(
            num_heads * head_dim,
            hidden_sz,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("o_proj"),
        )?;
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            num_heads: num_heads / comm.world_size(),
            num_kv_heads: (num_kv_heads / comm.world_size()).max(1),
            head_dim,
            sliding_window: cfg.sliding_window,
            paged_attn,
            sdpa_params: SdpaParams {
                n_kv_groups: mistralrs_quant::compute_n_kv_groups(
                    cfg.num_key_value_heads,
                    cfg.num_attention_heads,
                    comm,
                ),
                use_flash_attn: cfg.use_flash_attn,
                softcap: None,
                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                sliding_window: cfg.sliding_window,
            },
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut Option<(Tensor, Tensor)>,
        rope_parameter: (&Tensor, &Tensor),
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.q_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut q = MatMul.qmethod_matmul(&xs, &*self.q_proj)?;
        let mut k = MatMul.qmethod_matmul(&xs, &*self.k_proj)?;
        let mut v = MatMul.qmethod_matmul(&xs, &*self.v_proj)?;
        if self.q_proj.quantized_act_type().is_some() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

        let mut q = q
            .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
            .transpose(1, 2)?
            .contiguous()?;
        let mut k = k
            .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
            .transpose(1, 2)?
            .contiguous()?;
        q = OrdinaryRoPE::forward(&q, seqlen_offsets[0], rope_parameter.0, rope_parameter.1)?;
        k = OrdinaryRoPE::forward(&k, seqlen_offsets[0], rope_parameter.0, rope_parameter.1)?;
        let v = v
            .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
            .transpose(1, 2)?;

        let mut attn_output = match &self.paged_attn {
            Some(paged_attn) => match metadata {
                Some(((key_cache, value_cache), input_metadata)) => paged_attn.forward(
                    &q,
                    &k,
                    &v,
                    attention_mask,
                    Some(key_cache),
                    Some(value_cache),
                    input_metadata,
                    &self.sdpa_params,
                    Some(flash_params),
                )?,
                None => {
                    // If we don't have metadata, we are most likely generating an imatrix, so we don't want to populate the KV cache.
                    // Generate dummy metadata under the assumption that we are not generating text (only processing prompts).
                    let input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
                    // Sanity check.
                    assert!(attention_mask.is_some());
                    paged_attn.forward(
                        &q,
                        &k,
                        &v,
                        attention_mask,
                        None,
                        None,
                        &input_metadata,
                        &self.sdpa_params,
                        Some(flash_params),
                    )?
                }
            },
            None => {
                let (k, v, attn_mask) = Cache::update_kv_cache_sliding_window(
                    kv_cache,
                    k,
                    v,
                    attention_mask,
                    self.sliding_window,
                    false,
                )?;

                Sdpa.run_attention(
                    &q,
                    &k,
                    &v,
                    attn_mask.as_ref(),
                    Some(flash_params),
                    &self.sdpa_params,
                )?
            }
        };

        if let Some(t) = self.q_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        attn_output = if attention_mask.is_some() {
            attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?
        } else {
            attn_output.reshape((b_sz, q_len, ()))?
        };
        let mut res = MatMul.qmethod_matmul(&attn_output, &*self.o_proj)?;
        if self.q_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

/// Pre-norm decoder layer: RMSNorm -> self-attention -> residual add, then RMSNorm -> MLP -> residual add.
struct DecoderLayer {
    self_attn: Attention,
    mlp: Box<dyn MlpLayer>,
    input_layernorm: RmsNorm,
    post_attention_layernorm: RmsNorm,
    rope_parameter: (Tensor, Tensor),
}

impl DecoderLayer {
    #[allow(clippy::too_many_arguments)]
    fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        paged_attn: Option<PagedAttention>,
        rope_parameter: (Tensor, Tensor),
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let self_attn = Attention::new(
            cfg,
            mapper.set_device(layer_idx, vb.pp("self_attn"), loading_isq),
            paged_attn,
            comm,
        )?;
        let mlp = MLP::new(
            cfg,
            mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq),
            comm,
        )?;
        let input_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let post_attention_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        Ok(Self {
            self_attn,
            mlp: Box::new(mlp),
            input_layernorm,
            post_attention_layernorm,
            rope_parameter,
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut Option<(Tensor, Tensor)>,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = xs;
        let mut xs = self.input_layernorm.forward(xs)?;
        xs = self.self_attn.forward(
            &xs,
            attention_mask,
            seqlen_offsets,
            kv_cache,
            (&self.rope_parameter.0, &self.rope_parameter.1),
            metadata,
            flash_params,
        )?;
        xs = (xs + residual)?;
        let residual = &xs;
        let xs = self
            .mlp
            .forward(&xs.apply(&self.post_attention_layernorm)?)?;
        residual + xs
    }
}

/// Mistral decoder stack used as the LLaVA language model. It can be driven either from
/// token ids (`forward`) or from precomputed input embeddings (`forward_embeds`).
pub struct Model {
    embed_tokens: candle_nn::Embedding,
    layers: Vec<DecoderLayer>,
    norm: RmsNorm,
    lm_head: Arc<dyn QuantMethod>,
    sliding_window: Option<usize>,
    device: Device,
    cache: EitherCache,
    max_seq_len: usize,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    cfg: ModelConfigMetadata,
}

impl Model {
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        let vb_m = vb.pp("model");
        let vb_lm_head = vb.pp("lm_head");
        Self::new_inner(
            cfg,
            vb_m,
            vb_lm_head,
            is_gptx,
            normal_loading_metadata,
            attention_mechanism,
        )
    }

    pub fn new_inner(
        cfg: &Config,
        vb_m: ShardedVarBuilder,
        vb_lm_head: ShardedVarBuilder,
        _is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        if let Some(ref quant_cfg) = &cfg.quantization_config {
            tracing::info!(
                "Using {} quantization: {}.",
                quant_cfg.name(),
                quant_cfg.get_bits_name(&vb_m)
            );
        }
        let mapper = normal_loading_metadata.mapper;
        let embed_tokens = layers::embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
            &cfg.quantization_config,
        )?;
        let head_dim = cfg.hidden_size / cfg.num_attention_heads;
        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        let vb_l = vb_m.pp("layers");
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        ) {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rope_parameters = OrdinaryRoPE::create_parameters(
                head_dim,
                cfg.max_position_embeddings,
                cfg.rope_theta as f32,
                vb_m.dtype(),
                device,
            )?;
            let paged_attn = match &attention_mechanism {
                AttentionImplementation::Eager => None,
                AttentionImplementation::PagedAttention => {
                    Some(PagedAttention::new(head_dim, device, None)?)
                }
            };
            let comm = mapper.get_comm_for(layer_idx)?;
            let layer = DecoderLayer::new(
                cfg,
                vb_l.pp(layer_idx),
                &*mapper,
                layer_idx,
                normal_loading_metadata.loading_isq,
                paged_attn,
                rope_parameters,
                &comm,
            )?;
            layers.push(layer);
        }
        let norm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb_m.pp("norm"), false),
        )?;
        let lm_head = linear_no_bias(
            cfg.hidden_size,
            cfg.vocab_size,
            mapper.set_nm_device(vb_lm_head, normal_loading_metadata.loading_isq),
        )?;
        Ok(Self {
            embed_tokens,
            layers,
            norm,
            lm_head: Arc::new(UnquantLinear::new(QuantMethodConfig::Unquantized(lm_head))?),
            sliding_window: cfg.sliding_window,
            device: normal_loading_metadata.real_device,
            cache: EitherCache::Full(Cache::new(cfg.num_hidden_layers, false)),
            max_seq_len: cfg.max_position_embeddings,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_attn_heads: cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size(),
                num_kv_heads: (cfg.num_key_value_heads / mapper.get_comm_for(0)?.world_size())
                    .max(1),
                sliding_window: cfg.sliding_window,
                k_head_dim: cfg.head_dim(),
                v_head_dim: cfg.head_dim(),
            },
            mapper,
        })
    }

    pub fn get_input_embeddings(&self, input_ids: &Tensor) -> Result<Tensor> {
        self.embed_tokens.forward(input_ids)
    }

    pub fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.forward_embeds(
            input_ids,
            self.embed_tokens.forward(input_ids)?,
            seqlen_offsets,
            context_lens,
            metadata,
            flash_params,
        )
    }

    #[allow(clippy::too_many_arguments)]
    pub fn forward_embeds(
        &self,
        input_ids: &Tensor,
        input_embeds: Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut xs = input_embeds;
        let mut cache = self.cache.full().lock();
        let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
            input_ids,
            metadata
                .as_ref()
                .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
                .unwrap_or(&*cache as &dyn PastKvLenCache),
            self.sliding_window,
            xs.dtype(),
            self.cfg.num_attn_heads,
        )?;
        // When running with PagedAttention metadata, only apply the causal mask on the first prompt chunk.
        let attention_mask = attention_mask.filter(|_| {
            metadata
                .as_ref()
                .map(|(_, meta)| meta.is_first_prompt_chunk)
                .unwrap_or(true)
        });
        for (i, layer) in self.layers.iter().enumerate() {
            xs = self.mapper.map(xs, i)?;
            xs = layer.forward(
                &xs,
                attention_mask
                    .as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                &mut cache[i],
                metadata
                    .as_ref()
                    .map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
                flash_params,
            )?;
        }
        xs = xs.to_device(&self.device)?;
        xs = xs.apply(&self.norm)?;
        if let Some(t) = self.lm_head.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        extract_logits(&MatMul.qmethod_matmul(&xs, &*self.lm_head)?, context_lens)
    }
}

impl IsqModel for Model {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((&mut self.lm_head, None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            tensors.push((&mut layer.self_attn.q_proj, Some(i)));
            tensors.push((&mut layer.self_attn.k_proj, Some(i)));
            tensors.push((&mut layer.self_attn.v_proj, Some(i)));
            tensors.push((&mut layer.self_attn.o_proj, Some(i)));
            tensors.extend(
                layer
                    .mlp
                    .get_isq_layers()
                    .into_iter()
                    .map(|m| (m, Some(i)))
                    .collect::<Vec<_>>(),
            );
        }
        (tensors, &*self.mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        Vec::new()
    }
}

impl LLaVALLM for Model {
    fn embed(&self, input_ids: &Tensor) -> Result<Tensor> {
        self.get_input_embeddings(input_ids)
    }

    fn forward_input_embed(
        &self,
        input_ids: &Tensor,
        input_embed: Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.forward_embeds(
            input_ids,
            input_embed,
            seqlen_offsets,
            context_lens,
            metadata,
            flash_params,
        )
    }
}

impl NormalModel for Model {
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.forward(
            input_ids,
            seqlen_offsets,
            context_lens,
            metadata,
            flash_params,
        )
    }
    fn xlora_forward(
        &self,
        _input_ids: &Tensor,
        _input_ids_full: &Tensor,
        _seqlen_offsets: &[usize],
        _seqlen_offsets_full: &[usize],
        _no_kv_cache: bool,
        _non_granular_state: &Option<crate::xlora_models::NonGranularState>,
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _flash_params: &FlashParams,
        _flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        unimplemented!()
    }
    fn cache(&self) -> &EitherCache {
        &self.cache
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        &mut self.cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn is_xlora(&self) -> bool {
        false
    }
    fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
}

impl AnyMoeBaseModelMixin for Model {
    fn get_mlps(&self) -> Vec<&dyn MlpLayer> {
        let mut mlps = Vec::new();
        for layer in &self.layers {
            mlps.push(&*layer.mlp);
        }
        mlps
    }
    fn get_mlps_mut(&mut self) -> Vec<&mut Box<dyn MlpLayer>> {
        let mut mlps = Vec::new();
        for layer in &mut self.layers {
            mlps.push(&mut layer.mlp);
        }
        mlps
    }
    fn create_anymoe_layers(
        &mut self,
        additional_vbs: Vec<ShardedVarBuilder>,
        config: AnyMoeConfig,
        (prefix, mlp): (String, String),
        mut layers: Vec<usize>,
        expert_type: AnyMoeExpertType,
        gate_vb: Option<ShardedVarBuilder>,
    ) -> Result<()> {
        let mut experts: Vec<Vec<Box<dyn MlpLayer>>> = Vec::new();
        if layers.is_empty() {
            layers = (0..self.layers.len()).collect::<Vec<_>>();
        }
        for _ in 0..layers.len() {
            experts.push(Vec::new());
        }
        for vb in additional_vbs {
            let vb = vb.pp(&prefix);
            for (layer, row) in experts.iter_mut().enumerate() {
                if !layers.contains(&layer) {
                    continue;
                }

                let intermediate_size = self.layers[layer].mlp.get_params()[1];
                let hidden_size = self.layers[layer].mlp.get_params()[0];
                match expert_type {
                    AnyMoeExpertType::FineTuned => {
                        let (dtype, device) = self.layers[layer].mlp.dtype_device();
                        row.push(Box::new(MLP::new(
                            &Config {
                                intermediate_size: self.layers[layer].mlp.get_params()[1],
                                hidden_size: self.layers[layer].mlp.get_params()[0],
                                ..Default::default()
                            },
                            vb.pp(layer).pp(&mlp).set_dtype(dtype).set_device(device),
                            &self.mapper.get_comm_for(layer)?,
                        )?));
                    }
                    AnyMoeExpertType::LoraAdapter {
                        rank,
                        alpha,
                        ref target_modules,
                    } => {
                        let vb_mlp = vb.pp(layer).pp(&mlp);

                        let gate_proj_delta = if target_modules.contains(&"gate_proj".to_string()) {
                            Some(get_delta_from_lora_ab!(
                                vb_mlp,
                                rank,
                                alpha,
                                (hidden_size, intermediate_size),
                                "gate_proj"
                            ))
                        } else {
                            None
                        };
                        let up_proj_delta = if target_modules.contains(&"up_proj".to_string()) {
                            Some(get_delta_from_lora_ab!(
                                vb_mlp,
                                rank,
                                alpha,
                                (hidden_size, intermediate_size),
                                "up_proj"
                            ))
                        } else {
                            None
                        };
                        let down_proj_delta = if target_modules.contains(&"down_proj".to_string()) {
                            Some(get_delta_from_lora_ab!(
                                vb_mlp,
                                rank,
                                alpha,
                                (intermediate_size, hidden_size),
                                "down_proj"
                            ))
                        } else {
                            None
                        };

                        row.push(self.layers[layer].mlp.new_added_delta(vec![
                            gate_proj_delta,
                            up_proj_delta,
                            down_proj_delta,
                        ])?);
                    }
                }
            }
        }
        for (layer, expert) in layers.into_iter().zip(experts) {
            let mut experts_all = vec![self.layers[layer].mlp.clone()];
            experts_all.extend(expert);
            let (dtype, device) = self.layers[layer].mlp.dtype_device();
            self.layers[layer].mlp = Box::new(MoeMlp::new(
                experts_all,
                config.clone(),
                dtype,
                &device,
                layer,
                gate_vb.as_ref(),
            )?);
        }
        Ok(())
    }
    fn amoe_supported(&self) -> bool {
        true
    }
}