mistralrs_core/models/phi3_5_moe.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
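//! Phi-3.5-MoE: a sparse mixture-of-experts variant of Phi-3. Each decoder
//! layer pairs multi-head attention with a block-sparse MoE MLP whose router
//! ("sparsemixer") selects the top-2 experts per token.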

use candle_core::{Device, IndexOp, Module, Result, Tensor, D};
use candle_nn::LayerNorm;
use mistralrs_quant::{
    ColumnParallelLayer, NonZeroOp, QuantMethod, QuantizedConfig, ReplicatedLayer,
    RowParallelLayer, ShardedVarBuilder,
};
use std::{collections::HashMap, sync::Arc};

use crate::{
    amoe::AnyMoeBaseModelMixin,
    attention::SdpaParams,
    device_map::DeviceMapper,
    layers::{
        self, layer_norm, Activation, CausalMasker, MatMul, PhiRopeConfig, PhiRopeScalingConfig,
        PhiRotaryEmbedding, Sdpa,
    },
    layers_masker::{masked_fill, PastKvLenCache},
    paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
    pipeline::{
        extract_logits,
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel, KvCache, NormalCache, NormalLoadingMetadata, NormalModel,
    },
    serde_default_fn,
    utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
};

serde_default_fn!(bool, word_emb_default, false);

// https://huggingface.co/microsoft/Phi-3.5-MoE-instruct/blob/main/config.json
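/// Model configuration, deserialized from the model repository's `config.json`.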
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize, Default)]
pub struct Config {
    pub(crate) vocab_size: usize,
    pub(crate) hidden_act: Activation,
    pub(crate) hidden_size: usize,
    pub(crate) intermediate_size: usize,
    pub(crate) num_hidden_layers: usize,
    pub(crate) num_attention_heads: usize,
    pub(crate) num_key_value_heads: usize,
    pub(crate) rms_norm_eps: f64,
    pub(crate) rope_theta: f64,
    pub(crate) rope_scaling: Option<PhiRopeScalingConfig>,
    pub(crate) max_position_embeddings: usize,
    pub(crate) sliding_window: Option<usize>,
    pub(crate) original_max_position_embeddings: usize,

    pub(crate) quantization_config: Option<QuantizedConfig>,
    pub(crate) lm_head_bias: bool,
    pub(crate) attention_bias: bool,
    pub(crate) num_local_experts: usize,
    pub(crate) router_jitter_noise: f64,
    #[serde(default = "word_emb_default")]
    pub(crate) tie_word_embeddings: bool,
}

impl From<Config> for PhiRopeConfig {
    fn from(val: Config) -> Self {
        PhiRopeConfig {
            rope_scaling: val.rope_scaling,
            max_position_embeddings: val.max_position_embeddings,
            original_max_position_embeddings: val.original_max_position_embeddings,
            rope_theta: val.rope_theta,
            head_dim: val.hidden_size / val.num_attention_heads,
            partial_rotary_factor: None,
        }
    }
}

impl Config {
    pub fn head_dim(&self) -> usize {
        self.hidden_size / self.num_attention_heads
    }
}

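/// Multi-head self-attention with tensor-parallel Q/K/V/O projections and
/// optional PagedAttention support.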
struct Attention {
    q_proj: Arc<dyn QuantMethod>,
    k_proj: Arc<dyn QuantMethod>,
    v_proj: Arc<dyn QuantMethod>,
    o_proj: Arc<dyn QuantMethod>,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    rotary_emb: Arc<PhiRotaryEmbedding>,
    paged_attn: Option<PagedAttention>,
    sdpa_params: SdpaParams,
}

impl Attention {
    fn new(
        rotary_emb: Arc<PhiRotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        paged_attn: Option<PagedAttention>,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        let head_dim = cfg.head_dim();

        let q_proj = ColumnParallelLayer::new(
            cfg.hidden_size,
            num_heads * head_dim,
            &cfg.quantization_config,
            cfg.attention_bias,
            comm,
            vb.pp("q_proj"),
        )?;
        let kv_shard = mistralrs_quant::compute_kv_shard(
            cfg.num_key_value_heads,
            cfg.hidden_size / cfg.num_attention_heads,
            comm,
        );
        let k_proj = ColumnParallelLayer::new_with_shard(
            cfg.hidden_size,
            num_kv_heads * head_dim,
            &cfg.quantization_config,
            cfg.attention_bias,
            comm,
            kv_shard,
            vb.pp("k_proj"),
        )?;
        let v_proj = ColumnParallelLayer::new_with_shard(
            cfg.hidden_size,
            num_kv_heads * head_dim,
            &cfg.quantization_config,
            cfg.attention_bias,
            comm,
            kv_shard,
            vb.pp("v_proj"),
        )?;
        let o_proj = RowParallelLayer::new(
            num_heads * head_dim,
            cfg.hidden_size,
            &cfg.quantization_config,
            cfg.attention_bias,
            comm,
            vb.pp("o_proj"),
        )?;

        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            rotary_emb,
            num_heads: num_heads / comm.world_size(),
            num_kv_heads: (num_kv_heads / comm.world_size()).max(1),
            head_dim,
            paged_attn,
            sdpa_params: SdpaParams {
                n_kv_groups: mistralrs_quant::compute_n_kv_groups(
                    cfg.num_key_value_heads,
                    cfg.num_attention_heads,
                    comm,
                ),
                softcap: None,
                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                sliding_window: cfg.sliding_window,
            },
        })
    }

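    /// Computes attention for one layer. With PagedAttention, `metadata`
    /// carries the paged KV cache blocks; in the eager path, keys and values
    /// are appended to the per-layer `kv_cache` before SDPA.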
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.q_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut q = MatMul.qmethod_matmul(&xs, &*self.q_proj)?;
        let mut k = MatMul.qmethod_matmul(&xs, &*self.k_proj)?;
        let mut v = MatMul.qmethod_matmul(&xs, &*self.v_proj)?;
        if self.q_proj.quantized_act_type().is_some() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

        let (q, k, v) = if q_len != 1 {
            let q = q
                .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, k) = self
            .rotary_emb
            .forward(&q, &k, seqlen_offsets, position_ids)?;

        let mut attn_output = match &self.paged_attn {
            Some(paged_attn) => match metadata {
                Some(((key_cache, value_cache), input_metadata)) => paged_attn.forward(
                    &q,
                    &k,
                    &v,
                    attention_mask,
                    Some(key_cache),
                    Some(value_cache),
                    input_metadata,
                    &self.sdpa_params,
                    Some(flash_params),
                )?,
                None => {
                    // If no metadata was provided, we are most likely generating an imatrix, so we do not want to populate the KV cache.
                    // Generate dummy metadata under the assumption that we are not generating text (only processing prompts).
                    let input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
                    // Sanity check.
                    assert!(attention_mask.is_some());
                    paged_attn.forward(
                        &q,
                        &k,
                        &v,
                        attention_mask,
                        None,
                        None,
                        &input_metadata,
                        &self.sdpa_params,
                        Some(flash_params),
                    )?
                }
            },
            _ => {
                let (k, v) = kv_cache.append(&k, &v)?;

                Sdpa.run_attention(
                    &q,
                    &k,
                    &v,
                    attention_mask,
                    Some(flash_params),
                    &self.sdpa_params,
                )?
            }
        };

        if let Some(t) = self.q_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        attn_output = if attention_mask.is_some() {
            attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?
        } else {
            attn_output.reshape((b_sz, q_len, ()))?
        };
        let mut res = MatMul.qmethod_matmul(&attn_output, &*self.o_proj)?;
        if self.q_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

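/// A single expert: a gated MLP (`w1`/`w3` up projections, `w2` down
/// projection) with the configured activation.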
#[derive(Clone)]
struct Mlp {
    w1: Arc<dyn QuantMethod>,
    w2: Arc<dyn QuantMethod>,
    w3: Arc<dyn QuantMethod>,
    act_fn: Activation,
}

impl Mlp {
    fn new(cfg: &Config, vb: ShardedVarBuilder, comm: &Arc<mistralrs_quant::Comm>) -> Result<Self> {
        let hidden_size = cfg.hidden_size;
        let i_size = cfg.intermediate_size;

        let w1 = ColumnParallelLayer::new(
            hidden_size,
            i_size,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("w1"),
        )?;
        let w2 = RowParallelLayer::new(
            i_size,
            hidden_size,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("w2"),
        )?;
        let w3 = ColumnParallelLayer::new(
            hidden_size,
            i_size,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("w3"),
        )?;

        Ok(Self {
            w1,
            w2,
            w3,
            act_fn: cfg.hidden_act,
        })
    }

    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.w1.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut current_hidden_states =
            MatMul.qmethod_matmul(&xs, &*self.w1)?.apply(&self.act_fn)?;
        let rhs = MatMul.qmethod_matmul(&xs, &*self.w3)?;
        current_hidden_states = current_hidden_states.broadcast_mul(&rhs)?;
        let mut res = MatMul.qmethod_matmul(&current_hidden_states, &*self.w2)?;
        if self.w1.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

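/// Block-sparse MoE MLP: a linear router (`gate`) over `num_local_experts`
/// experts, with top-2 routing performed by `sparsemixer`.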
struct MoeMlp {
    gate: candle_nn::Linear,
    experts: Vec<Mlp>,
    router_jitter_noise: f64,
    num_experts: usize,
}

impl MoeMlp {
    fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        layer_device: Device,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let num_experts = cfg.num_local_experts;
        let gate = layers::linear_no_bias(
            cfg.hidden_size,
            num_experts,
            vb.pp("gate").set_device(layer_device),
        )?;

        let experts_vb = vb.pp("experts");
        let mut experts = Vec::with_capacity(num_experts);
        for i in 0..num_experts {
            experts.push(Mlp::new(cfg, experts_vb.pp(i), comm)?);
        }

        Ok(Self {
            gate,
            experts,
            router_jitter_noise: cfg.router_jitter_noise,
            num_experts,
        })
    }

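    /// Top-2 "sparsemixer" routing: picks the highest-scoring expert, masks it
    /// out, picks the runner-up, and returns the per-token softmax multipliers
    /// together with the selected expert indices. For `scores` of shape
    /// `(num_tokens, num_experts)`, both returned tensors have shape
    /// `(num_tokens, 2)`.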
    fn sparsemixer(&self, scores: &Tensor, jitter_eps: f64) -> Result<(Tensor, Tensor)> {
        // Compute mask for sparsity
        let selected_experts = scores.argmax_keepdim(D::Minus1)?;
        let mask_logits_threshold = scores.gather(&selected_experts, D::Minus1)?;
        let factor = scores.abs()?.broadcast_minimum(&mask_logits_threshold)?;
        let mask_logits_threshold = mask_logits_threshold
            .broadcast_sub(scores)?
            .broadcast_div(&factor)?
            .gt(2. * jitter_eps)?;

        // Apply mask
        let masked_gates = masked_fill(scores, &mask_logits_threshold, f64::NEG_INFINITY)?;

        // Compute scores
        let masked_gates = candle_nn::ops::softmax_last_dim(&masked_gates)?;
        let multiplier = masked_gates.gather(&selected_experts, D::Minus1)?;

        // Mask out first expert
        let masked_scores = scores.scatter_add(
            &selected_experts
                .broadcast_as(scores.shape())?
                .contiguous()?,
            &(scores.ones_like()? * f64::NEG_INFINITY)?,
            D::Minus1,
        )?;

        // Compute mask for sparsity
        let selected_experts_top2 = masked_scores.argmax_keepdim(D::Minus1)?;
        let mask_logits_threshold = masked_scores.gather(&selected_experts_top2, D::Minus1)?;
        let factor = scores.abs()?.broadcast_minimum(&mask_logits_threshold)?;
        let mask_logits_threshold = mask_logits_threshold
            .broadcast_sub(scores)?
            .broadcast_div(&factor)?
            .gt(2. * jitter_eps)?;

        // Apply mask
        let masked_gates_top2 =
            masked_fill(&masked_scores, &mask_logits_threshold, f64::NEG_INFINITY)?;
        let masked_gates_top2 = candle_nn::ops::softmax_last_dim(&masked_gates_top2)?;
        let multiplier_top2 = masked_gates_top2.gather(&selected_experts_top2, D::Minus1)?;

        let multiplier = Tensor::cat(&[multiplier, multiplier_top2], D::Minus1)?;
        let selected_experts = Tensor::cat(&[selected_experts, selected_experts_top2], D::Minus1)?;

        Ok((multiplier, selected_experts))
    }

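    /// Routes every token to its top-2 experts and accumulates the
    /// routing-weighted expert outputs. Routing and accumulation happen on the
    /// CPU; each expert MLP runs on the layer's mapped device.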
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let (bs, seq, hidden) = xs.dims3()?;
        let xs = xs.reshape(((), hidden))?;
        let xs_dev = xs.device();
        let xs = xs.to_device(&Device::Cpu)?;

        // The sparse MoE block routes tokens and accumulates hidden states on the CPU; the gate and expert MLP weights stay on their original (possibly GPU) devices.

        let router_logits = self
            .gate
            .forward(&xs.to_device(xs_dev)?)?
            .to_device(&Device::Cpu)?;
        let (routing_weights, selected_experts) = self.sparsemixer(
            &router_logits.to_device(&Device::Cpu)?,
            self.router_jitter_noise,
        )?;

        let mut final_hidden_states = Tensor::zeros((bs * seq, hidden), xs.dtype(), xs.device())?;

        // One-hot encode the selected experts to create an expert mask;
        // this is used to easily index which expert to activate.
        let experts_mask =
            candle_nn::encoding::one_hot(selected_experts, self.num_experts, 1u8, 0u8)?
                .permute((2, 1, 0))?;

        // Loop over all available experts in the model and perform the computation on each expert
        for expert_idx in 0..self.num_experts {
            let expert = &self.experts[expert_idx];
            let expert_mask = experts_mask.i(expert_idx)?;
            assert_eq!(expert_mask.rank(), 2);
            let nonzero_mask = expert_mask.contiguous()?.nonzero()?;
            let idx = nonzero_mask.i((.., 0))?;
            let top_x = nonzero_mask.i((.., 1))?;

            if top_x.dim(0)? == 0 {
                continue;
            }

            // Index the correct hidden states and compute the expert hidden state
            // for the current expert. We need to make sure to multiply the output hidden
            // states by `routing_weights` on the corresponding tokens (top-1 and top-2).
            let current_state = xs.index_select(&top_x, 0)?.reshape((1, (), hidden))?;
            let current_routing_weights = routing_weights
                .index_select(&top_x, 0)?
                .gather(&idx.unsqueeze(1)?.contiguous()?, 1)?;
            let exp_out = expert
                .forward(&current_state.to_device(xs_dev)?)?
                .to_device(&Device::Cpu)?;

            let current_hidden_states = exp_out.broadcast_mul(&current_routing_weights)?;

            final_hidden_states = final_hidden_states.index_add(
                &top_x.contiguous()?,
                &current_hidden_states.squeeze(0)?.to_dtype(xs.dtype())?,
                0,
            )?;
        }

        final_hidden_states
            .reshape((bs, seq, hidden))?
            .to_device(xs_dev)
    }
}

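/// One transformer decoder layer: pre-norm self-attention followed by a
/// pre-norm block-sparse MoE MLP, each with a residual connection.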
struct DecoderLayer {
    self_attn: Attention,
    mlp: MoeMlp,
    input_layernorm: LayerNorm,
    post_attention_layernorm: LayerNorm,
}

impl DecoderLayer {
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<PhiRotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        paged_attn: Option<PagedAttention>,
        real_device: Device,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let self_attn = Attention::new(
            rotary_emb,
            cfg,
            mapper.set_device(layer_idx, vb.pp("self_attn"), loading_isq),
            paged_attn,
            comm,
        )?;
        let mlp = MoeMlp::new(
            cfg,
            mapper.set_device(layer_idx, vb.pp("block_sparse_moe"), loading_isq),
            mapper
                .device_for(layer_idx, false)
                .cloned()
                .unwrap_or(real_device),
            comm,
        )?;
        let input_layernorm = layer_norm(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let post_attention_layernorm = layer_norm(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        Ok(Self {
            self_attn,
            mlp,
            input_layernorm,
            post_attention_layernorm,
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = self.input_layernorm.forward(xs)?;
        let xs = self.self_attn.forward(
            &xs,
            attention_mask,
            seqlen_offsets,
            position_ids,
            kv_cache,
            metadata,
            flash_params,
        )?;
        let xs = (xs + residual)?;
        let residual = &xs;
        let xs = self
            .mlp
            .forward(&xs.apply(&self.post_attention_layernorm)?)?;
        residual + xs
    }
}

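/// The full Phi-3.5-MoE model: token embeddings, a stack of decoder layers,
/// a final layer norm, and the LM head.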
pub struct Model {
    embed_tokens: candle_nn::Embedding,
    layers: Vec<DecoderLayer>,
    norm: LayerNorm,
    lm_head: Arc<dyn QuantMethod>,
    device: Device,
    cache: EitherCache,
    max_seq_len: usize,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    sliding_window: Option<usize>,
    cfg: ModelConfigMetadata,
}

impl Model {
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        _is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        if let Some(ref quant_cfg) = &cfg.quantization_config {
            tracing::info!(
                "Using {} quantization: {}.",
                quant_cfg.name(),
                quant_cfg.get_bits_name(&vb)
            );
        }
        let mapper = normal_loading_metadata.mapper;
        let vb_m = vb.pp("model");

        let embed_tokens = layers::embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
            &cfg.quantization_config,
        )?;
        let mut ropes = HashMap::new();
        for layer_idx in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(PhiRotaryEmbedding::new(vb.dtype(), cfg.clone(), device)?),
            );
        }
        let vb_l = vb_m.pp("layers");
        let layers: Vec<DecoderLayer> = NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        )
        .par_iter_if_isq(|layer_idx| {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            let paged_attn = match &attention_mechanism {
                AttentionImplementation::Eager => None,
                AttentionImplementation::PagedAttention => {
                    Some(PagedAttention::new(cfg.head_dim(), device, None)?)
                }
            };
            let comm = mapper.get_comm_for(layer_idx)?;
            DecoderLayer::new(
                rotary_emb.clone(),
                cfg,
                vb_l.pp(layer_idx),
                &*mapper,
                layer_idx,
                normal_loading_metadata.loading_isq,
                paged_attn,
                normal_loading_metadata.real_device.clone(),
                &comm,
            )
        })?;
        let norm = layer_norm(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb_m.pp("norm"), false),
        )?;
        let lm_head = if !cfg.tie_word_embeddings {
            ReplicatedLayer::new(
                cfg.hidden_size,
                cfg.vocab_size,
                &cfg.quantization_config,
                cfg.lm_head_bias,
                mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
            )?
        } else {
            unreachable!()
        };
        Ok(Self {
            embed_tokens,
            layers,
            norm,
            lm_head,
            device: normal_loading_metadata.real_device,
            cache: EitherCache::Normal(NormalCache::new_sliding(
                cfg.num_hidden_layers,
                cfg.max_position_embeddings,
                cfg.sliding_window,
            )),
            max_seq_len: cfg.max_position_embeddings,
            sliding_window: cfg.sliding_window,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_attn_heads: cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size(),
                num_kv_heads: (cfg.num_key_value_heads / mapper.get_comm_for(0)?.world_size())
                    .max(1),
                sliding_window: cfg.sliding_window,
                k_head_dim: cfg.head_dim(),
                v_head_dim: cfg.head_dim(),
            },
            mapper,
        })
    }

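    /// Runs the forward pass over `input_ids` and returns logits extracted at
    /// the positions given by `context_lens`.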
    pub fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
        context_lens: Vec<(usize, usize)>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut xs = self.embed_tokens.forward(input_ids)?;
        let cache = &mut self.cache.normal().0;
        let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
            input_ids,
            metadata
                .as_ref()
                .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
                .unwrap_or(cache as &dyn PastKvLenCache),
            self.sliding_window,
            xs.dtype(),
            self.cfg.num_attn_heads,
        )?;
        // PagedAttention prompt chunking
        let attention_mask = attention_mask.filter(|_| {
            metadata
                .as_ref()
                .map(|(_, meta)| meta.is_first_prompt_chunk)
                .unwrap_or(true)
        });

        for (i, layer) in self.layers.iter().enumerate() {
            xs = self.mapper.map(xs, i)?;
            xs = layer.forward(
                &xs,
                attention_mask
                    .as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                position_ids,
                &mut cache[i],
                metadata
                    .as_ref()
                    .map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
                flash_params,
            )?
        }
        let xs = xs.to_device(&self.device)?;
        let mut xs = xs.apply(&self.norm)?;
        if let Some(t) = self.lm_head.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        extract_logits(&MatMul.qmethod_matmul(&xs, &*self.lm_head)?, context_lens)
    }
}

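// In-situ quantization (ISQ) support: expose the quantizable linear layers
// (optionally restricted to the MoE experts) and the residual tensors that
// stay unquantized.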
impl IsqModel for Model {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((&mut self.lm_head, None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            tensors.push((&mut layer.self_attn.q_proj, Some(i)));
            tensors.push((&mut layer.self_attn.k_proj, Some(i)));
            tensors.push((&mut layer.self_attn.v_proj, Some(i)));
            tensors.push((&mut layer.self_attn.o_proj, Some(i)));
            for expert in &mut layer.mlp.experts {
                tensors.push((&mut expert.w1, Some(i)));
                tensors.push((&mut expert.w2, Some(i)));
                tensors.push((&mut expert.w3, Some(i)));
            }
        }
        (tensors, &*self.mapper)
    }
    fn get_layers_moe_experts_only(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((&mut self.lm_head, None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            for expert in &mut layer.mlp.experts {
                tensors.push((&mut expert.w1, Some(i)));
                tensors.push((&mut expert.w2, Some(i)));
                tensors.push((&mut expert.w3, Some(i)));
            }
        }
        (tensors, &*self.mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        let uvb_m = uvb.pp("model");
        uvb_m.pp("embed_tokens").add(&self.embed_tokens);
        uvb_m.pp("norm").add(&self.norm);

        for (layer_idx, layer) in self.layers.iter().enumerate() {
            let uvb_l = uvb_m.pp("layers").pp(layer_idx);
            uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
            uvb_l
                .pp("post_attention_layernorm")
                .add(&layer.post_attention_layernorm);
        }

        uvb.to_safetensors()
    }

    fn residual_tensors_moe_experts_only(&self) -> Option<Vec<(String, Tensor)>> {
        let uvb = UnVarBuilder::new();

        let uvb_m = uvb.pp("model");
        uvb_m.pp("embed_tokens").add(&self.embed_tokens);
        uvb_m.pp("norm").add(&self.norm);

        for (layer_idx, layer) in self.layers.iter().enumerate() {
            let uvb_l = uvb_m.pp("layers").pp(layer_idx);
            uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
            uvb_l
                .pp("post_attention_layernorm")
                .add(&layer.post_attention_layernorm);

            let uvb_attn = uvb_l.pp("self_attn");
            uvb_attn.pp("q_proj").add(&layer.self_attn.q_proj);
            uvb_attn.pp("k_proj").add(&layer.self_attn.k_proj);
            uvb_attn.pp("v_proj").add(&layer.self_attn.v_proj);
            uvb_attn.pp("o_proj").add(&layer.self_attn.o_proj);
        }

        Some(uvb.to_safetensors())
    }
}

impl NormalModel for Model {
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.forward(
            input_ids,
            seqlen_offsets,
            &position_ids,
            context_lens,
            metadata,
            flash_params,
        )
    }
    fn xlora_forward(
        &self,
        _input_ids: &Tensor,
        _input_ids_full: &Tensor,
        _seqlen_offsets: &[usize],
        _seqlen_offsets_full: &[usize],
        _no_kv_cache: bool,
        _non_granular_state: &Option<crate::xlora_models::NonGranularState>,
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _flash_params: &FlashParams,
        _flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        unimplemented!()
    }
    fn cache(&self) -> &EitherCache {
        &self.cache
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        &mut self.cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn is_xlora(&self) -> bool {
        false
    }
    fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
}

impl AnyMoeBaseModelMixin for Model {}