mistralrs_core/models/phi3_5_moe.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

// This implementation is based on:
// https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/modeling_phi3.py
use candle_core::{Device, IndexOp, Module, Result, Tensor, D};
use candle_nn::LayerNorm;
use mistralrs_quant::{
    ColumnParallelLayer, QuantMethod, QuantizedConfig, ReplicatedLayer, RowParallelLayer,
    ShardedVarBuilder,
};
use std::{collections::HashMap, sync::Arc};

use crate::{
    amoe::AnyMoeBaseModelMixin,
    attention::SdpaParams,
    device_map::DeviceMapper,
    layers::{
        self, layer_norm, Activation, CausalMasker, MatMul, PhiRopeConfig, PhiRopeScalingConfig,
        PhiRotaryEmbedding, Sdpa,
    },
    layers_masker::{masked_fill, PastKvLenCache},
    ops::NonZeroOp,
    paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
    pipeline::{
        extract_logits,
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel, KvCache, NormalCache, NormalLoadingMetadata, NormalModel,
    },
    serde_default_fn,
    utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
};

serde_default_fn!(bool, word_emb_default, false);

// https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/config.json
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize, Default)]
pub struct Config {
    pub(crate) vocab_size: usize,
    pub(crate) hidden_act: Activation,
    pub(crate) hidden_size: usize,
    pub(crate) intermediate_size: usize,
    pub(crate) num_hidden_layers: usize,
    pub(crate) num_attention_heads: usize,
    pub(crate) num_key_value_heads: usize,
    pub(crate) rms_norm_eps: f64,
    pub(crate) rope_theta: f64,
    pub(crate) rope_scaling: Option<PhiRopeScalingConfig>,
    pub(crate) max_position_embeddings: usize,
    pub(crate) use_flash_attn: bool,
    pub(crate) sliding_window: Option<usize>,
    pub(crate) original_max_position_embeddings: usize,
    pub(crate) quantization_config: Option<QuantizedConfig>,
    pub(crate) lm_head_bias: bool,
    pub(crate) attention_bias: bool,
    pub(crate) num_local_experts: usize,
    pub(crate) router_jitter_noise: f64,
    #[serde(default = "word_emb_default")]
    pub(crate) tie_word_embeddings: bool,
}

impl From<Config> for PhiRopeConfig {
    fn from(val: Config) -> Self {
        PhiRopeConfig {
            rope_scaling: val.rope_scaling,
            max_position_embeddings: val.max_position_embeddings,
            original_max_position_embeddings: val.original_max_position_embeddings,
            rope_theta: val.rope_theta,
            head_dim: val.hidden_size / val.num_attention_heads,
            partial_rotary_factor: None,
        }
    }
}

impl Config {
    pub fn head_dim(&self) -> usize {
        self.hidden_size / self.num_attention_heads
    }
}

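// Multi-head attention with tensor-parallel Q/K/V/O projections and Phi rotary embeddings
// (any long-context rope scaling is handled inside `PhiRotaryEmbedding`). Grouped-query
// attention is expressed via `n_kv_groups` in `SdpaParams`; PagedAttention is optional.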
struct Attention {
    q_proj: Arc<dyn QuantMethod>,
    k_proj: Arc<dyn QuantMethod>,
    v_proj: Arc<dyn QuantMethod>,
    o_proj: Arc<dyn QuantMethod>,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    rotary_emb: Arc<PhiRotaryEmbedding>,
    paged_attn: Option<PagedAttention>,
    sdpa_params: SdpaParams,
}

impl Attention {
    fn new(
        rotary_emb: Arc<PhiRotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        paged_attn: Option<PagedAttention>,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        let head_dim = cfg.head_dim();

        let q_proj = ColumnParallelLayer::new(
            cfg.hidden_size,
            num_heads * head_dim,
            &cfg.quantization_config,
            cfg.attention_bias,
            comm,
            vb.pp("q_proj"),
        )?;
        let kv_shard = mistralrs_quant::compute_kv_shard(
            cfg.num_key_value_heads,
            cfg.hidden_size / cfg.num_attention_heads,
            comm,
        );
        let k_proj = ColumnParallelLayer::new_with_shard(
            cfg.hidden_size,
            num_kv_heads * head_dim,
            &cfg.quantization_config,
            cfg.attention_bias,
            comm,
            kv_shard,
            vb.pp("k_proj"),
        )?;
        let v_proj = ColumnParallelLayer::new_with_shard(
            cfg.hidden_size,
            num_kv_heads * head_dim,
            &cfg.quantization_config,
            cfg.attention_bias,
            comm,
            kv_shard,
            vb.pp("v_proj"),
        )?;
        let o_proj = RowParallelLayer::new(
            num_heads * head_dim,
            cfg.hidden_size,
            &cfg.quantization_config,
            cfg.attention_bias,
            comm,
            vb.pp("o_proj"),
        )?;

        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            rotary_emb,
            num_heads: num_heads / comm.world_size(),
            num_kv_heads: (num_kv_heads / comm.world_size()).max(1),
            head_dim,
            paged_attn,
            sdpa_params: SdpaParams {
                n_kv_groups: mistralrs_quant::compute_n_kv_groups(
                    cfg.num_key_value_heads,
                    cfg.num_attention_heads,
                    comm,
                ),
                use_flash_attn: cfg.use_flash_attn,
                softcap: None,
                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                sliding_window: cfg.sliding_window,
            },
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.q_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut q = MatMul.qmethod_matmul(&xs, &*self.q_proj)?;
        let mut k = MatMul.qmethod_matmul(&xs, &*self.k_proj)?;
        let mut v = MatMul.qmethod_matmul(&xs, &*self.v_proj)?;
        if self.q_proj.quantized_act_type().is_some() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

        let (q, k, v) = if q_len != 1 {
            let q = q
                .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, k) = self
            .rotary_emb
            .forward(&q, &k, seqlen_offsets, position_ids)?;

        let mut attn_output = match &self.paged_attn {
            Some(paged_attn) => match metadata {
                Some(((key_cache, value_cache), input_metadata)) => paged_attn.forward(
                    &q,
                    &k,
                    &v,
                    attention_mask,
                    Some(key_cache),
                    Some(value_cache),
                    input_metadata,
                    &self.sdpa_params,
                    Some(flash_params),
                )?,
                None => {
                    // If we don't have metadata, we are most likely generating an imatrix, so we don't want to populate the KV cache.
                    // Generate dummy metadata under the assumption that we are not generating text (only processing prompts).
                    let input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
                    // Sanity check.
                    assert!(attention_mask.is_some());
                    paged_attn.forward(
                        &q,
                        &k,
                        &v,
                        attention_mask,
                        None,
                        None,
                        &input_metadata,
                        &self.sdpa_params,
                        Some(flash_params),
                    )?
                }
            },
            _ => {
                let (k, v) = kv_cache.append(&k, &v)?;

                Sdpa.run_attention(
                    &q,
                    &k,
                    &v,
                    attention_mask,
                    Some(flash_params),
                    &self.sdpa_params,
                )?
            }
        };

        if let Some(t) = self.q_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        attn_output = if attention_mask.is_some() {
            attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?
        } else {
            attn_output.reshape((b_sz, q_len, ()))?
        };
        let mut res = MatMul.qmethod_matmul(&attn_output, &*self.o_proj)?;
        if self.q_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

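// Per-expert feed-forward block: a gated MLP computing w2(act_fn(w1(x)) * w3(x)).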
#[derive(Clone)]
struct Mlp {
    w1: Arc<dyn QuantMethod>,
    w2: Arc<dyn QuantMethod>,
    w3: Arc<dyn QuantMethod>,
    act_fn: Activation,
}

impl Mlp {
    fn new(cfg: &Config, vb: ShardedVarBuilder, comm: &Arc<mistralrs_quant::Comm>) -> Result<Self> {
        let hidden_size = cfg.hidden_size;
        let i_size = cfg.intermediate_size;

        let w1 = ColumnParallelLayer::new(
            hidden_size,
            i_size,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("w1"),
        )?;
        let w2 = RowParallelLayer::new(
            i_size,
            hidden_size,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("w2"),
        )?;
        let w3 = ColumnParallelLayer::new(
            hidden_size,
            i_size,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("w3"),
        )?;

        Ok(Self {
            w1,
            w2,
            w3,
            act_fn: cfg.hidden_act,
        })
    }

    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.w1.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut current_hidden_states =
            MatMul.qmethod_matmul(&xs, &*self.w1)?.apply(&self.act_fn)?;
        let rhs = MatMul.qmethod_matmul(&xs, &*self.w3)?;
        current_hidden_states = current_hidden_states.broadcast_mul(&rhs)?;
        let mut res = MatMul.qmethod_matmul(&current_hidden_states, &*self.w2)?;
        if self.w1.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

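// Sparse MoE block: a dense router (`gate`) scores `num_local_experts` experts per token, the
// top-2 are selected by `sparsemixer`, and each selected expert's output is scaled by its
// routing weight and summed.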
struct MoeMlp {
    gate: candle_nn::Linear,
    experts: Vec<Mlp>,
    router_jitter_noise: f64,
    num_experts: usize,
}

impl MoeMlp {
    fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        layer_device: Device,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let num_experts = cfg.num_local_experts;
        let gate = layers::linear_no_bias(
            cfg.hidden_size,
            num_experts,
            vb.pp("gate").set_device(layer_device),
        )?;

        let experts_vb = vb.pp("experts");
        let mut experts = Vec::with_capacity(num_experts);
        for i in 0..num_experts {
            experts.push(Mlp::new(cfg, experts_vb.pp(i), comm)?);
        }

        Ok(Self {
            gate,
            experts,
            router_jitter_noise: cfg.router_jitter_noise,
            num_experts,
        })
    }

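    // Top-2 expert selection over the router `scores` (one row per token). Each of the two
    // rounds takes the argmax expert, masks out logits whose relative gap to that maximum
    // exceeds `2 * jitter_eps`, softmaxes the remaining logits, and gathers the gate weight of
    // the selected expert; before the second round the top-1 logit is set to -inf. Returns
    // `(multipliers, selected_experts)`, each with two columns (top-1, top-2) per token.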
    fn sparsemixer(&self, scores: &Tensor, jitter_eps: f64) -> Result<(Tensor, Tensor)> {
        // Compute mask for sparsity
        let selected_experts = scores.argmax_keepdim(D::Minus1)?;
        let mask_logits_threshold = scores.gather(&selected_experts, D::Minus1)?;
        let factor = scores.abs()?.broadcast_minimum(&mask_logits_threshold)?;
        let mask_logits_threshold = mask_logits_threshold
            .broadcast_sub(scores)?
            .broadcast_div(&factor)?
            .gt(2. * jitter_eps)?;

        // Apply mask
        let masked_gates = masked_fill(scores, &mask_logits_threshold, f64::NEG_INFINITY)?;

        // Compute scores
        let masked_gates = candle_nn::ops::softmax_last_dim(&masked_gates)?;
        let multiplier = masked_gates.gather(&selected_experts, D::Minus1)?;

        // Mask out the top-1 expert so the second round selects a different one
        let masked_scores = scores.scatter_add(
            &selected_experts
                .broadcast_as(scores.shape())?
                .contiguous()?,
            &(scores.ones_like()? * f64::NEG_INFINITY)?,
            D::Minus1,
        )?;

        // Compute mask for sparsity (second expert)
        let selected_experts_top2 = masked_scores.argmax_keepdim(D::Minus1)?;
        let mask_logits_threshold = masked_scores.gather(&selected_experts_top2, D::Minus1)?;
        let factor = scores.abs()?.broadcast_minimum(&mask_logits_threshold)?;
        let mask_logits_threshold = mask_logits_threshold
            .broadcast_sub(scores)?
            .broadcast_div(&factor)?
            .gt(2. * jitter_eps)?;

        // Apply mask
        let masked_gates_top2 =
            masked_fill(&masked_scores, &mask_logits_threshold, f64::NEG_INFINITY)?;
        let masked_gates_top2 = candle_nn::ops::softmax_last_dim(&masked_gates_top2)?;
        let multiplier_top2 = masked_gates_top2.gather(&selected_experts_top2, D::Minus1)?;

        let multiplier = Tensor::cat(&[multiplier, multiplier_top2], D::Minus1)?;
        let selected_experts = Tensor::cat(&[selected_experts, selected_experts_top2], D::Minus1)?;

        Ok((multiplier, selected_experts))
    }

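    // Route each token to its top-2 experts and combine the expert outputs. Hidden states are
    // flattened to `(tokens, hidden)`; routing and accumulation run on the CPU while the gate
    // and expert weights stay on their mapped devices, so activations are moved per expert.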
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let (bs, seq, hidden) = xs.dims3()?;
        let xs = xs.reshape(((), hidden))?;
        let xs_dev = xs.device();
        let xs = xs.to_device(&Device::Cpu)?;

        // The sparse MoE block accumulates hidden states on the CPU, while the gate and expert
        // MLP weights stay on their original (possibly GPU) devices.

        let router_logits = self
            .gate
            .forward(&xs.to_device(xs_dev)?)?
            .to_device(&Device::Cpu)?;
        let (routing_weights, selected_experts) = self.sparsemixer(
            &router_logits.to_device(&Device::Cpu)?,
            self.router_jitter_noise,
        )?;

        let mut final_hidden_states = Tensor::zeros((bs * seq, hidden), xs.dtype(), xs.device())?;

        // One-hot encode the selected experts to create an expert mask;
        // this is used to index which experts are activated for which tokens.
        let experts_mask =
            candle_nn::encoding::one_hot(selected_experts, self.num_experts, 1u8, 0u8)?
                .permute((2, 1, 0))?;

        // Loop over all available experts in the model and perform the computation on each expert
        for expert_idx in 0..self.num_experts {
            let expert = &self.experts[expert_idx];
            let expert_mask = experts_mask.i(expert_idx)?;
            assert_eq!(expert_mask.rank(), 2);
            let nonzero_mask = expert_mask.contiguous()?.nonzero()?;
            let idx = nonzero_mask.i((.., 0))?;
            let top_x = nonzero_mask.i((.., 1))?;

            if top_x.dim(0)? == 0 {
                continue;
            }

            // Index the correct hidden states and compute the expert hidden state for the
            // current expert. We need to make sure to multiply the output hidden states by
            // `routing_weights` on the corresponding tokens (top-1, top-2).
            let current_state = xs.index_select(&top_x, 0)?.reshape(((), hidden))?;
            let current_routing_weights = routing_weights
                .index_select(&top_x, 0)?
                .gather(&idx.unsqueeze(1)?.contiguous()?, 1)?;
            let exp_out = expert
                .forward(&current_state.to_device(xs_dev)?)?
                .to_device(&Device::Cpu)?;

            let current_hidden_states = exp_out.broadcast_mul(&current_routing_weights)?;

            final_hidden_states = final_hidden_states.index_add(
                &top_x.contiguous()?,
                &current_hidden_states.to_dtype(xs.dtype())?,
                0,
            )?;
        }

        final_hidden_states
            .reshape((bs, seq, hidden))?
            .to_device(xs_dev)
    }
}

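// A single transformer block with a pre-norm residual layout:
// x = x + attn(input_layernorm(x)); x = x + moe(post_attention_layernorm(x)).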
struct DecoderLayer {
    self_attn: Attention,
    mlp: MoeMlp,
    input_layernorm: LayerNorm,
    post_attention_layernorm: LayerNorm,
}

impl DecoderLayer {
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<PhiRotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        paged_attn: Option<PagedAttention>,
        real_device: Device,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let self_attn = Attention::new(
            rotary_emb,
            cfg,
            mapper.set_device(layer_idx, vb.pp("self_attn"), loading_isq),
            paged_attn,
            comm,
        )?;
        let mlp = MoeMlp::new(
            cfg,
            mapper.set_device(layer_idx, vb.pp("block_sparse_moe"), loading_isq),
            mapper
                .device_for(layer_idx, false)
                .cloned()
                .unwrap_or(real_device),
            comm,
        )?;
        let input_layernorm = layer_norm(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let post_attention_layernorm = layer_norm(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        Ok(Self {
            self_attn,
            mlp,
            input_layernorm,
            post_attention_layernorm,
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = self.input_layernorm.forward(xs)?;
        let xs = self.self_attn.forward(
            &xs,
            attention_mask,
            seqlen_offsets,
            position_ids,
            kv_cache,
            metadata,
            flash_params,
        )?;
        let xs = (xs + residual)?;
        let residual = &xs;
        let xs = self
            .mlp
            .forward(&xs.apply(&self.post_attention_layernorm)?)?;
        residual + xs
    }
}

pub struct Model {
    embed_tokens: candle_nn::Embedding,
    layers: Vec<DecoderLayer>,
    norm: LayerNorm,
    lm_head: Arc<dyn QuantMethod>,
    device: Device,
    cache: EitherCache,
    max_seq_len: usize,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    sliding_window: Option<usize>,
    cfg: ModelConfigMetadata,
}

impl Model {
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        _is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        if let Some(ref quant_cfg) = &cfg.quantization_config {
            tracing::info!(
                "Using {} quantization: {}.",
                quant_cfg.name(),
                quant_cfg.get_bits_name(&vb)
            );
        }
        let mapper = normal_loading_metadata.mapper;
        let vb_m = vb.pp("model");

        let embed_tokens = layers::embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
            &cfg.quantization_config,
        )?;
        let mut ropes = HashMap::new();
        for layer_idx in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(PhiRotaryEmbedding::new(vb.dtype(), cfg.clone(), device)?),
            );
        }
        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        let vb_l = vb_m.pp("layers");
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        ) {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            let paged_attn = match &attention_mechanism {
                AttentionImplementation::Eager => None,
                AttentionImplementation::PagedAttention => {
                    Some(PagedAttention::new(cfg.head_dim(), device, None)?)
                }
            };
            let comm = mapper.get_comm_for(layer_idx)?;
            let layer = DecoderLayer::new(
                rotary_emb.clone(),
                cfg,
                vb_l.pp(layer_idx),
                &*mapper,
                layer_idx,
                normal_loading_metadata.loading_isq,
                paged_attn,
                normal_loading_metadata.real_device.clone(),
                &comm,
            )?;
            layers.push(layer)
        }
        let norm = layer_norm(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb_m.pp("norm"), false),
        )?;
        let lm_head = if !cfg.tie_word_embeddings {
            ReplicatedLayer::new(
                cfg.hidden_size,
                cfg.vocab_size,
                &None,
                cfg.lm_head_bias,
                mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
            )?
        } else {
            unreachable!()
        };
        Ok(Self {
            embed_tokens,
            layers,
            norm,
            lm_head,
            device: normal_loading_metadata.real_device,
            cache: EitherCache::Normal(NormalCache::new_sliding(
                cfg.num_hidden_layers,
                cfg.max_position_embeddings,
                cfg.sliding_window,
            )),
            max_seq_len: cfg.max_position_embeddings,
            sliding_window: cfg.sliding_window,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_attn_heads: cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size(),
                num_kv_heads: (cfg.num_key_value_heads / mapper.get_comm_for(0)?.world_size())
                    .max(1),
                sliding_window: cfg.sliding_window,
                k_head_dim: cfg.head_dim(),
                v_head_dim: cfg.head_dim(),
            },
            mapper,
        })
    }

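    // Full forward pass: embed the input ids, build a (sliding-window) causal mask unless this
    // is a non-first PagedAttention prompt chunk, run each decoder layer on its mapped device,
    // apply the final norm and lm_head, and extract logits for `context_lens`.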
    pub fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
        context_lens: Vec<(usize, usize)>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut xs = self.embed_tokens.forward(input_ids)?;
        let cache = &mut self.cache.normal().0;
        let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
            input_ids,
            metadata
                .as_ref()
                .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
                .unwrap_or(cache as &dyn PastKvLenCache),
            self.sliding_window,
            xs.dtype(),
            self.cfg.num_attn_heads,
        )?;
        // PagedAttention prompt chunking
        let attention_mask = attention_mask.filter(|_| {
            metadata
                .as_ref()
                .map(|(_, meta)| meta.is_first_prompt_chunk)
                .unwrap_or(true)
        });

        for (i, layer) in self.layers.iter().enumerate() {
            xs = self.mapper.map(xs, i)?;
            xs = layer.forward(
                &xs,
                attention_mask
                    .as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                position_ids,
                &mut cache[i],
                metadata
                    .as_ref()
                    .map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
                flash_params,
            )?
        }
        let xs = xs.to_device(&self.device)?;
        let mut xs = xs.apply(&self.norm)?;
        if let Some(t) = self.lm_head.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        extract_logits(&MatMul.qmethod_matmul(&xs, &*self.lm_head)?, context_lens)
    }
}

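// ISQ integration: `get_layers` exposes the lm_head, attention projections, and every expert's
// w1/w2/w3 for in-situ quantization; `get_layers_moe_experts_only` restricts this to the expert
// weights (plus lm_head), with the attention projections then reported as residual tensors.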
impl IsqModel for Model {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((&mut self.lm_head, None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            tensors.push((&mut layer.self_attn.q_proj, Some(i)));
            tensors.push((&mut layer.self_attn.k_proj, Some(i)));
            tensors.push((&mut layer.self_attn.v_proj, Some(i)));
            tensors.push((&mut layer.self_attn.o_proj, Some(i)));
            for expert in &mut layer.mlp.experts {
                tensors.push((&mut expert.w1, Some(i)));
                tensors.push((&mut expert.w2, Some(i)));
                tensors.push((&mut expert.w3, Some(i)));
            }
        }
        (tensors, &*self.mapper)
    }
    fn get_layers_moe_experts_only(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((&mut self.lm_head, None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            for expert in &mut layer.mlp.experts {
                tensors.push((&mut expert.w1, Some(i)));
                tensors.push((&mut expert.w2, Some(i)));
                tensors.push((&mut expert.w3, Some(i)));
            }
        }
        (tensors, &*self.mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        let uvb_m = uvb.pp("model");
        uvb_m.pp("embed_tokens").add(&self.embed_tokens);
        uvb_m.pp("norm").add(&self.norm);

        for (layer_idx, layer) in self.layers.iter().enumerate() {
            let uvb_l = uvb_m.pp("layers").pp(layer_idx);
            uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
            uvb_l
                .pp("post_attention_layernorm")
                .add(&layer.post_attention_layernorm);
        }

        uvb.to_safetensors()
    }

    fn residual_tensors_moe_experts_only(&self) -> Option<Vec<(String, Tensor)>> {
        let uvb = UnVarBuilder::new();

        let uvb_m = uvb.pp("model");
        uvb_m.pp("embed_tokens").add(&self.embed_tokens);
        uvb_m.pp("norm").add(&self.norm);

        for (layer_idx, layer) in self.layers.iter().enumerate() {
            let uvb_l = uvb_m.pp("layers").pp(layer_idx);
            uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
            uvb_l
                .pp("post_attention_layernorm")
                .add(&layer.post_attention_layernorm);

            let uvb_attn = uvb_l.pp("self_attn");
            uvb_attn.pp("q_proj").add(&layer.self_attn.q_proj);
            uvb_attn.pp("k_proj").add(&layer.self_attn.k_proj);
            uvb_attn.pp("v_proj").add(&layer.self_attn.v_proj);
            uvb_attn.pp("o_proj").add(&layer.self_attn.o_proj);
        }

        Some(uvb.to_safetensors())
    }
}

impl NormalModel for Model {
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.forward(
            input_ids,
            seqlen_offsets,
            &position_ids,
            context_lens,
            metadata,
            flash_params,
        )
    }
    fn xlora_forward(
        &self,
        _input_ids: &Tensor,
        _input_ids_full: &Tensor,
        _seqlen_offsets: &[usize],
        _seqlen_offsets_full: &[usize],
        _no_kv_cache: bool,
        _non_granular_state: &Option<crate::xlora_models::NonGranularState>,
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _flash_params: &FlashParams,
        _flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        unimplemented!()
    }
    fn cache(&self) -> &EitherCache {
        &self.cache
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        &mut self.cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn is_xlora(&self) -> bool {
        false
    }
    fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
}

impl AnyMoeBaseModelMixin for Model {}