mistralrs_core/models/phi3_5_moe.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use candle_core::{Device, IndexOp, Module, Result, Tensor, D};
use candle_nn::LayerNorm;
use mistralrs_quant::{
    ColumnParallelLayer, NonZeroOp, QuantMethod, QuantizedConfig, ReplicatedLayer,
    RowParallelLayer, ShardedVarBuilder,
};
use std::{collections::HashMap, sync::Arc};

use crate::{
    amoe::AnyMoeBaseModelMixin,
    attention::SdpaParams,
    device_map::DeviceMapper,
    layers::{
        self, layer_norm, Activation, CausalMasker, MatMul, PhiRopeConfig, PhiRopeScalingConfig,
        PhiRotaryEmbedding, Sdpa,
    },
    layers_masker::{masked_fill, PastKvLenCache},
    paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
    pipeline::{
        extract_logits,
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel, KvCache, NormalCache, NormalLoadingMetadata, NormalModel,
    },
    serde_default_fn,
    utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
};

serde_default_fn!(bool, word_emb_default, false);

// https://huggingface.co/microsoft/Phi-3.5-MoE-instruct/blob/main/config.json
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize, Default)]
pub struct Config {
    pub(crate) vocab_size: usize,
    pub(crate) hidden_act: Activation,
    pub(crate) hidden_size: usize,
    pub(crate) intermediate_size: usize,
    pub(crate) num_hidden_layers: usize,
    pub(crate) num_attention_heads: usize,
    pub(crate) num_key_value_heads: usize,
    pub(crate) rms_norm_eps: f64,
    pub(crate) rope_theta: f64,
    pub(crate) rope_scaling: Option<PhiRopeScalingConfig>,
    pub(crate) max_position_embeddings: usize,
    pub(crate) sliding_window: Option<usize>,
    pub(crate) original_max_position_embeddings: usize,

    pub(crate) quantization_config: Option<QuantizedConfig>,
    pub(crate) lm_head_bias: bool,
    pub(crate) attention_bias: bool,
    pub(crate) num_local_experts: usize,
    pub(crate) router_jitter_noise: f64,
    #[serde(default = "word_emb_default")]
    pub(crate) tie_word_embeddings: bool,
}

impl From<Config> for PhiRopeConfig {
    fn from(val: Config) -> Self {
        PhiRopeConfig {
            rope_scaling: val.rope_scaling,
            max_position_embeddings: val.max_position_embeddings,
            original_max_position_embeddings: val.original_max_position_embeddings,
            rope_theta: val.rope_theta,
            head_dim: val.hidden_size / val.num_attention_heads,
            partial_rotary_factor: None,
        }
    }
}

impl Config {
    pub fn head_dim(&self) -> usize {
        self.hidden_size / self.num_attention_heads
    }
}

struct Attention {
    q_proj: Arc<dyn QuantMethod>,
    k_proj: Arc<dyn QuantMethod>,
    v_proj: Arc<dyn QuantMethod>,
    o_proj: Arc<dyn QuantMethod>,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    rotary_emb: Arc<PhiRotaryEmbedding>,
    paged_attn: Option<PagedAttention>,
    sdpa_params: SdpaParams,
}

impl Attention {
    fn new(
        rotary_emb: Arc<PhiRotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        paged_attn: Option<PagedAttention>,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        let head_dim = cfg.head_dim();

        let q_proj = ColumnParallelLayer::new(
            cfg.hidden_size,
            num_heads * head_dim,
            &cfg.quantization_config,
            cfg.attention_bias,
            comm,
            vb.pp("q_proj"),
        )?;
        let kv_shard = mistralrs_quant::compute_kv_shard(
            cfg.num_key_value_heads,
            cfg.hidden_size / cfg.num_attention_heads,
            comm,
        );
        let k_proj = ColumnParallelLayer::new_with_shard(
            cfg.hidden_size,
            num_kv_heads * head_dim,
            &cfg.quantization_config,
            cfg.attention_bias,
            comm,
            kv_shard,
            vb.pp("k_proj"),
        )?;
        let v_proj = ColumnParallelLayer::new_with_shard(
            cfg.hidden_size,
            num_kv_heads * head_dim,
            &cfg.quantization_config,
            cfg.attention_bias,
            comm,
            kv_shard,
            vb.pp("v_proj"),
        )?;
        let o_proj = RowParallelLayer::new(
            num_heads * head_dim,
            cfg.hidden_size,
            &cfg.quantization_config,
            cfg.attention_bias,
            comm,
            vb.pp("o_proj"),
        )?;

        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            rotary_emb,
            num_heads: num_heads / comm.world_size(),
            num_kv_heads: (num_kv_heads / comm.world_size()).max(1),
            head_dim,
            paged_attn,
            sdpa_params: SdpaParams {
                n_kv_groups: mistralrs_quant::compute_n_kv_groups(
                    cfg.num_key_value_heads,
                    cfg.num_attention_heads,
                    comm,
                ),
                softcap: None,
                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                sliding_window: cfg.sliding_window,
            },
        })
    }

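    // Attention forward pass: project q/k/v (casting to the quantized activation dtype when
    // needed), reshape to per-head layout, apply the Phi rotary embedding, then run either
    // PagedAttention or plain SDPA against the KV cache before the output projection.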
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.q_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut q = MatMul.qmethod_matmul(&xs, &*self.q_proj)?;
        let mut k = MatMul.qmethod_matmul(&xs, &*self.k_proj)?;
        let mut v = MatMul.qmethod_matmul(&xs, &*self.v_proj)?;
        if self.q_proj.quantized_act_type().is_some() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

        let (q, k, v) = if q_len != 1 {
            let q = q
                .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, k) = self
            .rotary_emb
            .forward(&q, &k, seqlen_offsets, position_ids)?;

        let mut attn_output = match &self.paged_attn {
            Some(paged_attn) => match metadata {
                Some(((key_cache, value_cache), input_metadata)) => paged_attn.forward(
                    &q,
                    &k,
                    &v,
                    attention_mask,
                    Some(key_cache),
                    Some(value_cache),
                    input_metadata,
                    &self.sdpa_params,
                    Some(flash_params),
                )?,
                None => {
                    // If we don't have metadata, we are most likely generating an imatrix, so we don't want to populate the KV cache.
                    // Generate dummy metadata under the assumption that we are not generating text (only processing prompts).
                    let input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
                    // Sanity check.
                    assert!(attention_mask.is_some());
                    paged_attn.forward(
                        &q,
                        &k,
                        &v,
                        attention_mask,
                        None,
                        None,
                        &input_metadata,
                        &self.sdpa_params,
                        Some(flash_params),
                    )?
                }
            },
            _ => {
                let (k, v) = kv_cache.append(&k, &v)?;

                Sdpa.run_attention(
                    &q,
                    &k,
                    &v,
                    attention_mask,
                    Some(flash_params),
                    &self.sdpa_params,
                )?
            }
        };

        if let Some(t) = self.q_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        attn_output = if attention_mask.is_some() {
            attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?
        } else {
            attn_output.reshape((b_sz, q_len, ()))?
        };
        let mut res = MatMul.qmethod_matmul(&attn_output, &*self.o_proj)?;
        if self.q_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

#[derive(Clone)]
struct Mlp {
    w1: Arc<dyn QuantMethod>,
    w2: Arc<dyn QuantMethod>,
    w3: Arc<dyn QuantMethod>,
    act_fn: Activation,
}

impl Mlp {
    fn new(cfg: &Config, vb: ShardedVarBuilder, comm: &Arc<mistralrs_quant::Comm>) -> Result<Self> {
        let hidden_size = cfg.hidden_size;
        let i_size = cfg.intermediate_size;

        let w1 = ColumnParallelLayer::new(
            hidden_size,
            i_size,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("w1"),
        )?;
        let w2 = RowParallelLayer::new(
            i_size,
            hidden_size,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("w2"),
        )?;
        let w3 = ColumnParallelLayer::new(
            hidden_size,
            i_size,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("w3"),
        )?;

        Ok(Self {
            w1,
            w2,
            w3,
            act_fn: cfg.hidden_act,
        })
    }

    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.w1.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut current_hidden_states =
            MatMul.qmethod_matmul(&xs, &*self.w1)?.apply(&self.act_fn)?;
        let rhs = MatMul.qmethod_matmul(&xs, &*self.w3)?;
        current_hidden_states = current_hidden_states.broadcast_mul(&rhs)?;
        let mut res = MatMul.qmethod_matmul(&current_hidden_states, &*self.w2)?;
        if self.w1.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

struct MoeMlp {
    gate: candle_nn::Linear,
    experts: Vec<Mlp>,
    router_jitter_noise: f64,
    num_experts: usize,
}

impl MoeMlp {
    fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        layer_device: Device,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let num_experts = cfg.num_local_experts;
        let gate = layers::linear_no_bias(
            cfg.hidden_size,
            num_experts,
            vb.pp("gate").set_device(layer_device),
        )?;

        let experts_vb = vb.pp("experts");
        let mut experts = Vec::with_capacity(num_experts);
        for i in 0..num_experts {
            experts.push(Mlp::new(cfg, experts_vb.pp(i), comm)?);
        }

        Ok(Self {
            gate,
            experts,
            router_jitter_noise: cfg.router_jitter_noise,
            num_experts,
        })
    }

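    // Sparsemixer top-2 routing: select the top-1 expert, masking out experts whose logits fall
    // more than `2 * jitter_eps` (relatively) below the maximum before the softmax; then set the
    // top-1 logit to -inf and repeat to select the second expert. Returns the softmax gate
    // multipliers and the selected expert indices, concatenated along the last dimension.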
    fn sparsemixer(&self, scores: &Tensor, jitter_eps: f64) -> Result<(Tensor, Tensor)> {
        // Compute mask for sparsity
        let selected_experts = scores.argmax_keepdim(D::Minus1)?;
        let mask_logits_threshold = scores.gather(&selected_experts, D::Minus1)?;
        let factor = scores.abs()?.broadcast_minimum(&mask_logits_threshold)?;
        let mask_logits_threshold = mask_logits_threshold
            .broadcast_sub(scores)?
            .broadcast_div(&factor)?
            .gt(2. * jitter_eps)?;

        // Apply mask
        let masked_gates = masked_fill(scores, &mask_logits_threshold, f64::NEG_INFINITY)?;

        // Compute scores
        let masked_gates = candle_nn::ops::softmax_last_dim(&masked_gates)?;
        let multiplier = masked_gates.gather(&selected_experts, D::Minus1)?;

        // Mask out first expert
        let masked_scores = scores.scatter_add(
            &selected_experts
                .broadcast_as(scores.shape())?
                .contiguous()?,
            &(scores.ones_like()? * f64::NEG_INFINITY)?,
            D::Minus1,
        )?;

        // Compute mask for sparsity
        let selected_experts_top2 = masked_scores.argmax_keepdim(D::Minus1)?;
        let mask_logits_threshold = masked_scores.gather(&selected_experts_top2, D::Minus1)?;
        let factor = scores.abs()?.broadcast_minimum(&mask_logits_threshold)?;
        let mask_logits_threshold = mask_logits_threshold
            .broadcast_sub(scores)?
            .broadcast_div(&factor)?
            .gt(2. * jitter_eps)?;

        // Apply mask
        let masked_gates_top2 =
            masked_fill(&masked_scores, &mask_logits_threshold, f64::NEG_INFINITY)?;
        let masked_gates_top2 = candle_nn::ops::softmax_last_dim(&masked_gates_top2)?;
        let multiplier_top2 = masked_gates_top2.gather(&selected_experts_top2, D::Minus1)?;

        let multiplier = Tensor::cat(&[multiplier, multiplier_top2], D::Minus1)?;
        let selected_experts = Tensor::cat(&[selected_experts, selected_experts_top2], D::Minus1)?;

        Ok((multiplier, selected_experts))
    }

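    // MoE dispatch: flatten the tokens, compute router logits with the gate on its own device,
    // run sparsemixer routing on the CPU, then send each token only through its two selected
    // experts and accumulate the expert outputs weighted by the routing gates.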
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let (bs, seq, hidden) = xs.dims3()?;
        let xs = xs.reshape(((), hidden))?;
        let xs_dev = xs.device();
        let xs = xs.to_device(&Device::Cpu)?;

        // The sparse MoE block accumulates hidden states on the CPU, but the MLP and gate weights are left on their original device (possibly a GPU).

        let router_logits = self
            .gate
            .forward(&xs.to_device(xs_dev)?)?
            .to_device(&Device::Cpu)?;
        let (routing_weights, selected_experts) = self.sparsemixer(
            &router_logits.to_device(&Device::Cpu)?,
            self.router_jitter_noise,
        )?;

        let mut final_hidden_states = Tensor::zeros((bs * seq, hidden), xs.dtype(), xs_dev)?;

        // One-hot encode the selected experts to create an expert mask;
        // this will be used to easily index which expert to activate.
        let experts_mask =
            candle_nn::encoding::one_hot(selected_experts, self.num_experts, 1u8, 0u8)?
                .permute((2, 1, 0))?;

        // Loop over all available experts in the model and perform the computation on each expert.
        for expert_idx in 0..self.num_experts {
            let expert = &self.experts[expert_idx];
            let expert_mask = experts_mask.i(expert_idx)?;
            assert_eq!(expert_mask.rank(), 2);
            let nonzero_mask = expert_mask.contiguous()?.nonzero()?;
            let idx = nonzero_mask.i((.., 0))?;
            let top_x = nonzero_mask.i((.., 1))?;

            if top_x.dim(0)? == 0 {
                continue;
            }

            // Index the correct hidden states and compute the expert hidden state
            // for the current expert. We need to make sure to multiply the output hidden
            // states by `routing_weights` on the corresponding tokens (top-1, top-2).
            let current_state = xs.index_select(&top_x, 0)?.reshape((1, (), hidden))?;
            let current_routing_weights = routing_weights
                .index_select(&top_x, 0)?
                .gather(&idx.unsqueeze(1)?.contiguous()?, 1)?;
            let exp_out = expert
                .forward(&current_state.to_device(xs_dev)?)?
                .to_device(&Device::Cpu)?;

            let current_hidden_states = exp_out.broadcast_mul(&current_routing_weights)?;

            final_hidden_states = final_hidden_states.index_add(
                &top_x.contiguous()?.to_device(xs_dev)?,
                &current_hidden_states
                    .squeeze(0)?
                    .to_dtype(xs.dtype())?
                    .to_device(xs_dev)?,
                0,
            )?;
        }

        final_hidden_states
            .reshape((bs, seq, hidden))?
            .to_device(xs_dev)
    }
}

struct DecoderLayer {
    self_attn: Attention,
    mlp: MoeMlp,
    input_layernorm: LayerNorm,
    post_attention_layernorm: LayerNorm,
}

impl DecoderLayer {
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<PhiRotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        paged_attn: Option<PagedAttention>,
        real_device: Device,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let self_attn = Attention::new(
            rotary_emb,
            cfg,
            mapper.set_device(layer_idx, vb.pp("self_attn"), loading_isq),
            paged_attn,
            comm,
        )?;
        let mlp = MoeMlp::new(
            cfg,
            mapper.set_device(layer_idx, vb.pp("block_sparse_moe"), loading_isq),
            mapper
                .device_for(layer_idx, false)
                .cloned()
                .unwrap_or(real_device),
            comm,
        )?;
        let input_layernorm = layer_norm(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let post_attention_layernorm = layer_norm(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        Ok(Self {
            self_attn,
            mlp,
            input_layernorm,
            post_attention_layernorm,
        })
    }

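    // Pre-norm residual block: self-attention over the input_layernorm output plus the residual,
    // then the sparse MoE MLP over the post_attention_layernorm output plus the residual.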
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = self.input_layernorm.forward(xs)?;
        let xs = self.self_attn.forward(
            &xs,
            attention_mask,
            seqlen_offsets,
            position_ids,
            kv_cache,
            metadata,
            flash_params,
        )?;
        let xs = (xs + residual)?;
        let residual = &xs;
        let xs = self
            .mlp
            .forward(&xs.apply(&self.post_attention_layernorm)?)?;
        residual + xs
    }
}

pub struct Model {
    embed_tokens: candle_nn::Embedding,
    layers: Vec<DecoderLayer>,
    norm: LayerNorm,
    lm_head: Arc<dyn QuantMethod>,
    device: Device,
    cache: EitherCache,
    max_seq_len: usize,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    sliding_window: Option<usize>,
    cfg: ModelConfigMetadata,
}

impl Model {
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        _is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        if let Some(ref quant_cfg) = &cfg.quantization_config {
            tracing::info!(
                "Using {} quantization: {}.",
                quant_cfg.name(),
                quant_cfg.get_bits_name(&vb)
            );
        }
        let mapper = normal_loading_metadata.mapper;
        let vb_m = vb.pp("model");

        let embed_tokens = layers::embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
            &cfg.quantization_config,
        )?;
        let mut ropes = HashMap::new();
        for layer_idx in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(PhiRotaryEmbedding::new(vb.dtype(), cfg.clone(), device)?),
            );
        }
        let vb_l = vb_m.pp("layers");
        let layers: Vec<DecoderLayer> = NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        )
        .par_iter_if_isq(|layer_idx| {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            let paged_attn = match &attention_mechanism {
                AttentionImplementation::Eager => None,
                AttentionImplementation::PagedAttention => {
                    Some(PagedAttention::new(cfg.head_dim(), device, None)?)
                }
            };
            let comm = mapper.get_comm_for(layer_idx)?;
            DecoderLayer::new(
                rotary_emb.clone(),
                cfg,
                vb_l.pp(layer_idx),
                &*mapper,
                layer_idx,
                normal_loading_metadata.loading_isq,
                paged_attn,
                normal_loading_metadata.real_device.clone(),
                &comm,
            )
        })?;
        let norm = layer_norm(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb_m.pp("norm"), false),
        )?;
        let lm_head = if !cfg.tie_word_embeddings {
            ReplicatedLayer::new(
                cfg.hidden_size,
                cfg.vocab_size,
                &cfg.quantization_config,
                cfg.lm_head_bias,
                mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
            )?
        } else {
            unreachable!()
        };
        Ok(Self {
            embed_tokens,
            layers,
            norm,
            lm_head,
            device: normal_loading_metadata.real_device,
            cache: EitherCache::Normal(NormalCache::new_sliding(
                cfg.num_hidden_layers,
                cfg.max_position_embeddings,
                cfg.sliding_window,
            )),
            max_seq_len: cfg.max_position_embeddings,
            sliding_window: cfg.sliding_window,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_attn_heads: cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size(),
                num_kv_heads: (cfg.num_key_value_heads / mapper.get_comm_for(0)?.world_size())
                    .max(1),
                sliding_window: cfg.sliding_window,
                k_head_dim: cfg.head_dim(),
                v_head_dim: cfg.head_dim(),
            },
            mapper,
        })
    }

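    // Full forward pass: embed the input ids, build a sliding-window causal mask (skipped for
    // non-first PagedAttention prompt chunks), run each decoder layer on its mapped device,
    // apply the final norm and lm_head, and extract logits for the requested context lengths.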
    pub fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
        context_lens: Vec<(usize, usize)>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut xs = self.embed_tokens.forward(input_ids)?;
        let cache = &mut self.cache.normal().0;
        let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
            input_ids,
            metadata
                .as_ref()
                .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
                .unwrap_or(cache as &dyn PastKvLenCache),
            self.sliding_window,
            xs.dtype(),
            self.cfg.num_attn_heads,
        )?;
        // PagedAttention prompt chunking
        let attention_mask = attention_mask.filter(|_| {
            metadata
                .as_ref()
                .map(|(_, meta)| meta.is_first_prompt_chunk)
                .unwrap_or(true)
        });

        for (i, layer) in self.layers.iter().enumerate() {
            xs = self.mapper.map(xs, i)?;
            xs = layer.forward(
                &xs,
                attention_mask
                    .as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                position_ids,
                &mut cache[i],
                metadata
                    .as_ref()
                    .map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
                flash_params,
            )?
        }
        let xs = xs.to_device(&self.device)?;
        let mut xs = xs.apply(&self.norm)?;
        if let Some(t) = self.lm_head.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        extract_logits(&MatMul.qmethod_matmul(&xs, &*self.lm_head)?, context_lens)
    }
}

impl IsqModel for Model {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((&mut self.lm_head, None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            tensors.push((&mut layer.self_attn.q_proj, Some(i)));
            tensors.push((&mut layer.self_attn.k_proj, Some(i)));
            tensors.push((&mut layer.self_attn.v_proj, Some(i)));
            tensors.push((&mut layer.self_attn.o_proj, Some(i)));
            for expert in &mut layer.mlp.experts {
                tensors.push((&mut expert.w1, Some(i)));
                tensors.push((&mut expert.w2, Some(i)));
                tensors.push((&mut expert.w3, Some(i)));
            }
        }
        (tensors, &*self.mapper)
    }
    fn get_layers_moe_experts_only(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((&mut self.lm_head, None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            for expert in &mut layer.mlp.experts {
                tensors.push((&mut expert.w1, Some(i)));
                tensors.push((&mut expert.w2, Some(i)));
                tensors.push((&mut expert.w3, Some(i)));
            }
        }
        (tensors, &*self.mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        let uvb_m = uvb.pp("model");
        uvb_m.pp("embed_tokens").add(&self.embed_tokens);
        uvb_m.pp("norm").add(&self.norm);

        for (layer_idx, layer) in self.layers.iter().enumerate() {
            let uvb_l = uvb_m.pp("layers").pp(layer_idx);
            uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
            uvb_l
                .pp("post_attention_layernorm")
                .add(&layer.post_attention_layernorm);
        }

        uvb.to_safetensors()
    }

    fn residual_tensors_moe_experts_only(&self) -> Option<Vec<(String, Tensor)>> {
        let uvb = UnVarBuilder::new();

        let uvb_m = uvb.pp("model");
        uvb_m.pp("embed_tokens").add(&self.embed_tokens);
        uvb_m.pp("norm").add(&self.norm);

        for (layer_idx, layer) in self.layers.iter().enumerate() {
            let uvb_l = uvb_m.pp("layers").pp(layer_idx);
            uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
            uvb_l
                .pp("post_attention_layernorm")
                .add(&layer.post_attention_layernorm);

            let uvb_attn = uvb_l.pp("self_attn");
            uvb_attn.pp("q_proj").add(&layer.self_attn.q_proj);
            uvb_attn.pp("k_proj").add(&layer.self_attn.k_proj);
            uvb_attn.pp("v_proj").add(&layer.self_attn.v_proj);
            uvb_attn.pp("o_proj").add(&layer.self_attn.o_proj);
        }

        Some(uvb.to_safetensors())
    }
}

impl NormalModel for Model {
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.forward(
            input_ids,
            seqlen_offsets,
            &position_ids,
            context_lens,
            metadata,
            flash_params,
        )
    }
    fn xlora_forward(
        &self,
        _input_ids: &Tensor,
        _input_ids_full: &Tensor,
        _seqlen_offsets: &[usize],
        _seqlen_offsets_full: &[usize],
        _no_kv_cache: bool,
        _non_granular_state: &Option<crate::xlora_models::NonGranularState>,
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _flash_params: &FlashParams,
        _flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        unimplemented!()
    }
    fn cache(&self) -> &EitherCache {
        &self.cache
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        &mut self.cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn is_xlora(&self) -> bool {
        false
    }
    fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
}

impl AnyMoeBaseModelMixin for Model {}