mistralrs_core/models/phi3_5_moe.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use candle_core::{Device, IndexOp, Module, Result, Tensor, D};
use candle_nn::LayerNorm;
use mistralrs_quant::{
    ColumnParallelLayer, NonZeroOp, QuantMethod, QuantizedConfig, ReplicatedLayer,
    RowParallelLayer, ShardedVarBuilder,
};
use std::{collections::HashMap, sync::Arc};

use crate::{
    amoe::AnyMoeBaseModelMixin,
    attention::SdpaParams,
    device_map::DeviceMapper,
    layers::{
        self, layer_norm, Activation, CausalMasker, MatMul, PhiRopeConfig, PhiRopeScalingConfig,
        PhiRotaryEmbedding, Sdpa,
    },
    layers_masker::{masked_fill, PastKvLenCache},
    paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
    pipeline::{
        extract_logits,
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel, KvCache, NormalCache, NormalLoadingMetadata, NormalModel,
    },
    serde_default_fn,
    utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
};

serde_default_fn!(bool, word_emb_default, false);

// https://huggingface.co/microsoft/Phi-3.5-MoE-instruct/blob/main/config.json
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize, Default)]
pub struct Config {
    pub(crate) vocab_size: usize,
    pub(crate) hidden_act: Activation,
    pub(crate) hidden_size: usize,
    pub(crate) intermediate_size: usize,
    pub(crate) num_hidden_layers: usize,
    pub(crate) num_attention_heads: usize,
    pub(crate) num_key_value_heads: usize,
    pub(crate) rms_norm_eps: f64,
    pub(crate) rope_theta: f64,
    pub(crate) rope_scaling: Option<PhiRopeScalingConfig>,
    pub(crate) max_position_embeddings: usize,
    pub(crate) sliding_window: Option<usize>,
    pub(crate) original_max_position_embeddings: usize,

    pub(crate) quantization_config: Option<QuantizedConfig>,
    pub(crate) lm_head_bias: bool,
    pub(crate) attention_bias: bool,
    pub(crate) num_local_experts: usize,
    pub(crate) router_jitter_noise: f64,
    #[serde(default = "word_emb_default")]
    pub(crate) tie_word_embeddings: bool,
}

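// The rotary embedding only needs a slice of the full config; any long-context
// scaling described by `rope_scaling` / `original_max_position_embeddings` is
// consumed by `PhiRotaryEmbedding` via this conversion.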
impl From<Config> for PhiRopeConfig {
    fn from(val: Config) -> Self {
        PhiRopeConfig {
            rope_scaling: val.rope_scaling,
            max_position_embeddings: val.max_position_embeddings,
            original_max_position_embeddings: val.original_max_position_embeddings,
            rope_theta: val.rope_theta,
            head_dim: val.hidden_size / val.num_attention_heads,
            partial_rotary_factor: None,
        }
    }
}

impl Config {
    pub fn head_dim(&self) -> usize {
        self.hidden_size / self.num_attention_heads
    }
}

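// Grouped-query attention with tensor parallelism: the Q/K/V projections are
// column-parallel and the output projection is row-parallel, so each rank ends
// up with `num_heads / world_size` query heads and at least one KV head.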
struct Attention {
    q_proj: Arc<dyn QuantMethod>,
    k_proj: Arc<dyn QuantMethod>,
    v_proj: Arc<dyn QuantMethod>,
    o_proj: Arc<dyn QuantMethod>,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    rotary_emb: Arc<PhiRotaryEmbedding>,
    paged_attn: Option<PagedAttention>,
    sdpa_params: SdpaParams,
}

impl Attention {
    fn new(
        rotary_emb: Arc<PhiRotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        paged_attn: Option<PagedAttention>,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        let head_dim = cfg.head_dim();

        let q_proj = ColumnParallelLayer::new(
            cfg.hidden_size,
            num_heads * head_dim,
            &cfg.quantization_config,
            cfg.attention_bias,
            comm,
            vb.pp("q_proj"),
        )?;
        let kv_shard = mistralrs_quant::compute_kv_shard(
            cfg.num_key_value_heads,
            cfg.hidden_size / cfg.num_attention_heads,
            comm,
        );
        let k_proj = ColumnParallelLayer::new_with_shard(
            cfg.hidden_size,
            num_kv_heads * head_dim,
            &cfg.quantization_config,
            cfg.attention_bias,
            comm,
            kv_shard,
            vb.pp("k_proj"),
        )?;
        let v_proj = ColumnParallelLayer::new_with_shard(
            cfg.hidden_size,
            num_kv_heads * head_dim,
            &cfg.quantization_config,
            cfg.attention_bias,
            comm,
            kv_shard,
            vb.pp("v_proj"),
        )?;
        let o_proj = RowParallelLayer::new(
            num_heads * head_dim,
            cfg.hidden_size,
            &cfg.quantization_config,
            cfg.attention_bias,
            comm,
            vb.pp("o_proj"),
        )?;

        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            rotary_emb,
            num_heads: num_heads / comm.world_size(),
            num_kv_heads: (num_kv_heads / comm.world_size()).max(1),
            head_dim,
            paged_attn,
            sdpa_params: SdpaParams {
                n_kv_groups: mistralrs_quant::compute_n_kv_groups(
                    cfg.num_key_value_heads,
                    cfg.num_attention_heads,
                    comm,
                ),
                softcap: None,
                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                sliding_window: cfg.sliding_window,
            },
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.q_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut q = MatMul.qmethod_matmul(&xs, &*self.q_proj)?;
        let mut k = MatMul.qmethod_matmul(&xs, &*self.k_proj)?;
        let mut v = MatMul.qmethod_matmul(&xs, &*self.v_proj)?;
        if self.q_proj.quantized_act_type().is_some() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

        let (q, k, v) = if q_len != 1 {
            let q = q
                .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, k) = self
            .rotary_emb
            .forward(&q, &k, seqlen_offsets, position_ids)?;

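        // Three execution paths: paged attention against the provided KV cache
        // blocks, paged attention with dummy metadata (prompt-only processing,
        // e.g. imatrix collection), or plain SDPA over the in-memory KV cache.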
        let mut attn_output = match &self.paged_attn {
            Some(paged_attn) => match metadata {
                Some(((key_cache, value_cache), input_metadata)) => paged_attn.forward(
                    &q,
                    &k,
                    &v,
                    attention_mask,
                    Some(key_cache),
                    Some(value_cache),
                    input_metadata,
                    &self.sdpa_params,
                    Some(flash_params),
                )?,
                None => {
                    // Without metadata we are most likely generating an imatrix, so we don't want to populate the KV cache.
                    // Generate dummy metadata under the assumption that we are only processing prompts, not generating text.
                    let input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
                    // Sanity check.
                    assert!(attention_mask.is_some());
                    paged_attn.forward(
                        &q,
                        &k,
                        &v,
                        attention_mask,
                        None,
                        None,
                        &input_metadata,
                        &self.sdpa_params,
                        Some(flash_params),
                    )?
                }
            },
            _ => {
                let (k, v) = kv_cache.append(&k, &v)?;

                Sdpa.run_attention(
                    &q,
                    &k,
                    &v,
                    attention_mask,
                    Some(flash_params),
                    &self.sdpa_params,
                )?
            }
        };

        if let Some(t) = self.q_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        attn_output = if attention_mask.is_some() {
            attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?
        } else {
            attn_output.reshape((b_sz, q_len, ()))?
        };
        let mut res = MatMul.qmethod_matmul(&attn_output, &*self.o_proj)?;
        if self.q_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

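// Gated feed-forward network used for each expert: the `w1` (gate) and `w3`
// (up) projections feed a fused multiply-and-activation, and `w2` projects the
// result back down to `hidden_size`.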
#[derive(Clone)]
struct Mlp {
    w1: Arc<dyn QuantMethod>,
    w2: Arc<dyn QuantMethod>,
    w3: Arc<dyn QuantMethod>,
    act_fn: Activation,
}

impl Mlp {
    fn new(cfg: &Config, vb: ShardedVarBuilder, comm: &Arc<mistralrs_quant::Comm>) -> Result<Self> {
        let hidden_size = cfg.hidden_size;
        let i_size = cfg.intermediate_size;

        let w1 = ColumnParallelLayer::new(
            hidden_size,
            i_size,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("w1"),
        )?;
        let w2 = RowParallelLayer::new(
            i_size,
            hidden_size,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("w2"),
        )?;
        let w3 = ColumnParallelLayer::new(
            hidden_size,
            i_size,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("w3"),
        )?;

        Ok(Self {
            w1,
            w2,
            w3,
            act_fn: cfg.hidden_act,
        })
    }

    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.w1.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let w1_out = MatMul.qmethod_matmul(&xs, &*self.w1)?;
        let w3_out = MatMul.qmethod_matmul(&xs, &*self.w3)?;
        let current_hidden_states = crate::ops::mul_and_act(&w1_out, &w3_out, self.act_fn)?;
        let mut res = MatMul.qmethod_matmul(&current_hidden_states, &*self.w2)?;
        if self.w1.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

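// Sparse mixture-of-experts block: a linear router (`gate`) scores every one of
// the `num_local_experts` experts per token, `sparsemixer` selects the top two,
// and their outputs are combined using the returned routing weights.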
struct MoeMlp {
    gate: candle_nn::Linear,
    experts: Vec<Mlp>,
    router_jitter_noise: f64,
    num_experts: usize,
}

impl MoeMlp {
    fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        layer_device: Device,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let num_experts = cfg.num_local_experts;
        let gate = layers::linear_no_bias(
            cfg.hidden_size,
            num_experts,
            vb.pp("gate").set_device(layer_device),
        )?;

        let experts_vb = vb.pp("experts");
        let mut experts = Vec::with_capacity(num_experts);
        for i in 0..num_experts {
            experts.push(Mlp::new(cfg, experts_vb.pp(i), comm)?);
        }

        Ok(Self {
            gate,
            experts,
            router_jitter_noise: cfg.router_jitter_noise,
            num_experts,
        })
    }

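    // Top-2 routing in the style of the PhiMoE "sparsemixer": pick the argmax
    // expert, mask out every logit whose relative gap to the max exceeds
    // `2 * jitter_eps`, softmax the survivors to get the first routing weight,
    // then set the chosen expert's logit to -inf and repeat once to obtain the
    // second expert. Returns `(routing_weights, selected_experts)`, each with
    // two entries per token.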
    fn sparsemixer(&self, scores: &Tensor, jitter_eps: f64) -> Result<(Tensor, Tensor)> {
        // Compute mask for sparsity
        let selected_experts = scores.argmax_keepdim(D::Minus1)?;
        let mask_logits_threshold = scores.gather(&selected_experts, D::Minus1)?;
        let factor = scores.abs()?.broadcast_minimum(&mask_logits_threshold)?;
        let mask_logits_threshold = mask_logits_threshold
            .broadcast_sub(scores)?
            .broadcast_div(&factor)?
            .gt(2. * jitter_eps)?;

        // Apply mask
        let masked_gates = masked_fill(scores, &mask_logits_threshold, f64::NEG_INFINITY)?;

        // Compute scores
        let masked_gates = candle_nn::ops::softmax_last_dim(&masked_gates)?;
        let multiplier = masked_gates.gather(&selected_experts, D::Minus1)?;

        // Mask out first expert
        let masked_scores = scores.scatter_add(
            &selected_experts
                .broadcast_as(scores.shape())?
                .contiguous()?,
            &(scores.ones_like()? * f64::NEG_INFINITY)?,
            D::Minus1,
        )?;

        // Compute mask for sparsity
        let selected_experts_top2 = masked_scores.argmax_keepdim(D::Minus1)?;
        let mask_logits_threshold = masked_scores.gather(&selected_experts_top2, D::Minus1)?;
        let factor = scores.abs()?.broadcast_minimum(&mask_logits_threshold)?;
        let mask_logits_threshold = mask_logits_threshold
            .broadcast_sub(scores)?
            .broadcast_div(&factor)?
            .gt(2. * jitter_eps)?;

        // Apply mask
        let masked_gates_top2 =
            masked_fill(&masked_scores, &mask_logits_threshold, f64::NEG_INFINITY)?;
        let masked_gates_top2 = candle_nn::ops::softmax_last_dim(&masked_gates_top2)?;
        let multiplier_top2 = masked_gates_top2.gather(&selected_experts_top2, D::Minus1)?;

        let multiplier = Tensor::cat(&[multiplier, multiplier_top2], D::Minus1)?;
        let selected_experts = Tensor::cat(&[selected_experts, selected_experts_top2], D::Minus1)?;

        Ok((multiplier, selected_experts))
    }

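    // Forward pass of the MoE block: route on the CPU, then run each expert
    // only on the tokens assigned to it and accumulate the weighted outputs
    // into a flat `(bs * seq, hidden)` buffer on the original device.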
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let (bs, seq, hidden) = xs.dims3()?;
        let xs = xs.reshape(((), hidden))?;
        let xs_dev = xs.device();
        let xs = xs.to_device(&Device::Cpu)?;

        // Routing runs on the CPU, while the gate and expert MLP weights are untouched (possibly on a GPU);
        // per-expert outputs are moved back to the original device for accumulation.

        let router_logits = self
            .gate
            .forward(&xs.to_device(xs_dev)?)?
            .to_device(&Device::Cpu)?;
        let (routing_weights, selected_experts) = self.sparsemixer(
            &router_logits.to_device(&Device::Cpu)?,
            self.router_jitter_noise,
        )?;

        let mut final_hidden_states = Tensor::zeros((bs * seq, hidden), xs.dtype(), xs_dev)?;

        // One-hot encode the selected experts to create an expert mask;
        // this will be used to easily index which expert to activate.
        let experts_mask =
            candle_nn::encoding::one_hot(selected_experts, self.num_experts, 1u8, 0u8)?
                .permute((2, 1, 0))?;

        // Loop over all available experts in the model and perform the computation on each expert
        for expert_idx in 0..self.num_experts {
            let expert = &self.experts[expert_idx];
            let expert_mask = experts_mask.i(expert_idx)?;
            assert_eq!(expert_mask.rank(), 2);
            let nonzero_mask = expert_mask.contiguous()?.nonzero()?;
            let idx = nonzero_mask.i((.., 0))?;
            let top_x = nonzero_mask.i((.., 1))?;

            if top_x.dim(0)? == 0 {
                continue;
            }

            // Index the correct hidden states and compute the expert hidden state for
            // the current expert; we need to make sure to multiply the output hidden
            // states by `routing_weights` on the corresponding tokens (top-1, top-2).
            let current_state = xs.index_select(&top_x, 0)?.reshape((1, (), hidden))?;
            let current_routing_weights = routing_weights
                .index_select(&top_x, 0)?
                .gather(&idx.unsqueeze(1)?.contiguous()?, 1)?;
            let exp_out = expert
                .forward(&current_state.to_device(xs_dev)?)?
                .to_device(&Device::Cpu)?;

            let current_hidden_states = exp_out.broadcast_mul(&current_routing_weights)?;

            final_hidden_states = final_hidden_states.index_add(
                &top_x.contiguous()?.to_device(xs_dev)?,
                &current_hidden_states
                    .squeeze(0)?
                    .to_dtype(xs.dtype())?
                    .to_device(xs_dev)?,
                0,
            )?;
        }

        final_hidden_states
            .reshape((bs, seq, hidden))?
            .to_device(xs_dev)
    }
}

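// One decoder block: pre-norm self-attention followed by the pre-norm sparse
// MoE MLP, each wrapped in a residual connection.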
struct DecoderLayer {
    self_attn: Attention,
    mlp: MoeMlp,
    input_layernorm: LayerNorm,
    post_attention_layernorm: LayerNorm,
}

impl DecoderLayer {
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<PhiRotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        paged_attn: Option<PagedAttention>,
        real_device: Device,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let self_attn = Attention::new(
            rotary_emb,
            cfg,
            mapper.set_device(layer_idx, vb.pp("self_attn"), loading_isq),
            paged_attn,
            comm,
        )?;
        let mlp = MoeMlp::new(
            cfg,
            mapper.set_device(layer_idx, vb.pp("block_sparse_moe"), loading_isq),
            mapper
                .device_for(layer_idx, false)
                .cloned()
                .unwrap_or(real_device),
            comm,
        )?;
        let input_layernorm = layer_norm(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let post_attention_layernorm = layer_norm(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        Ok(Self {
            self_attn,
            mlp,
            input_layernorm,
            post_attention_layernorm,
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = self.input_layernorm.forward(xs)?;
        let xs = self.self_attn.forward(
            &xs,
            attention_mask,
            seqlen_offsets,
            position_ids,
            kv_cache,
            metadata,
            flash_params,
        )?;
        let xs = (xs + residual)?;
        let residual = &xs;
        let xs = self
            .mlp
            .forward(&xs.apply(&self.post_attention_layernorm)?)?;
        residual + xs
    }
}

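// The full Phi-3.5-MoE causal LM: token embedding, the stack of device-mapped
// decoder layers, a final norm, and the LM head.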
pub struct Model {
    embed_tokens: candle_nn::Embedding,
    layers: Vec<DecoderLayer>,
    norm: LayerNorm,
    lm_head: Arc<dyn QuantMethod>,
    device: Device,
    cache: EitherCache,
    max_seq_len: usize,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    sliding_window: Option<usize>,
    cfg: ModelConfigMetadata,
}

impl Model {
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        _is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        if let Some(ref quant_cfg) = &cfg.quantization_config {
            tracing::info!(
                "Using {} quantization: {}.",
                quant_cfg.name(),
                quant_cfg.get_bits_name(&vb)
            );
        }
        let mapper = normal_loading_metadata.mapper;
        let vb_m = vb.pp("model");

        let embed_tokens = layers::embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
            &cfg.quantization_config,
        )?;
        let mut ropes = HashMap::new();
        for layer_idx in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(PhiRotaryEmbedding::new(vb.dtype(), cfg.clone(), device)?),
            );
        }
        let vb_l = vb_m.pp("layers");
        let layers: Vec<DecoderLayer> = NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        )
        .par_iter_if_isq(|layer_idx| {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            let paged_attn = match &attention_mechanism {
                AttentionImplementation::Eager => None,
                AttentionImplementation::PagedAttention => {
                    Some(PagedAttention::new(cfg.head_dim(), device, None)?)
                }
            };
            let comm = mapper.get_comm_for(layer_idx)?;
            DecoderLayer::new(
                rotary_emb.clone(),
                cfg,
                vb_l.pp(layer_idx),
                &*mapper,
                layer_idx,
                normal_loading_metadata.loading_isq,
                paged_attn,
                normal_loading_metadata.real_device.clone(),
                &comm,
            )
        })?;
        let norm = layer_norm(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb_m.pp("norm"), false),
        )?;
        let lm_head = if !cfg.tie_word_embeddings {
            ReplicatedLayer::new(
                cfg.hidden_size,
                cfg.vocab_size,
                &cfg.quantization_config,
                cfg.lm_head_bias,
                mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
            )?
        } else {
            unreachable!()
        };
        Ok(Self {
            embed_tokens,
            layers,
            norm,
            lm_head,
            device: normal_loading_metadata.real_device,
            cache: EitherCache::Normal(NormalCache::new_sliding(
                cfg.num_hidden_layers,
                cfg.max_position_embeddings,
                cfg.sliding_window,
            )),
            max_seq_len: cfg.max_position_embeddings,
            sliding_window: cfg.sliding_window,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_attn_heads: cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size(),
                num_kv_heads: (cfg.num_key_value_heads / mapper.get_comm_for(0)?.world_size())
                    .max(1),
                sliding_window: cfg.sliding_window,
                k_head_dim: cfg.head_dim(),
                v_head_dim: cfg.head_dim(),
            },
            mapper,
        })
    }

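    // Forward pass: build a (sliding-window) causal mask, run every decoder
    // layer on its mapped device, then apply the final norm and the LM head to
    // the positions selected by `context_lens`.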
    pub fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
        context_lens: Vec<(usize, usize)>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut xs = self.embed_tokens.forward(input_ids)?;
        let cache = &mut self.cache.normal().0;
        let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
            input_ids,
            metadata
                .as_ref()
                .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
                .unwrap_or(cache as &dyn PastKvLenCache),
            self.sliding_window,
            xs.dtype(),
            self.cfg.num_attn_heads,
        )?;
        // PagedAttention prompt chunking
        let attention_mask = attention_mask.filter(|_| {
            metadata
                .as_ref()
                .map(|(_, meta)| meta.is_first_prompt_chunk)
                .unwrap_or(true)
        });

        for (i, layer) in self.layers.iter().enumerate() {
            xs = self.mapper.map(xs, i)?;
            xs = layer.forward(
                &xs,
                attention_mask
                    .as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                position_ids,
                &mut cache[i],
                metadata
                    .as_ref()
                    .map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
                flash_params,
            )?
        }
        let xs = xs.to_device(&self.device)?;
        let mut xs = xs.apply(&self.norm)?;
        if let Some(t) = self.lm_head.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        extract_logits(&MatMul.qmethod_matmul(&xs, &*self.lm_head)?, context_lens)
    }
}

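// In-situ quantization (ISQ) support: expose the quantizable linear layers
// (optionally only the MoE expert weights) together with the residual,
// non-quantized tensors under their original safetensors names.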
impl IsqModel for Model {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((&mut self.lm_head, None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            tensors.push((&mut layer.self_attn.q_proj, Some(i)));
            tensors.push((&mut layer.self_attn.k_proj, Some(i)));
            tensors.push((&mut layer.self_attn.v_proj, Some(i)));
            tensors.push((&mut layer.self_attn.o_proj, Some(i)));
            for expert in &mut layer.mlp.experts {
                tensors.push((&mut expert.w1, Some(i)));
                tensors.push((&mut expert.w2, Some(i)));
                tensors.push((&mut expert.w3, Some(i)));
            }
        }
        (tensors, &*self.mapper)
    }
    fn get_layers_moe_experts_only(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((&mut self.lm_head, None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            for expert in &mut layer.mlp.experts {
                tensors.push((&mut expert.w1, Some(i)));
                tensors.push((&mut expert.w2, Some(i)));
                tensors.push((&mut expert.w3, Some(i)));
            }
        }
        (tensors, &*self.mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        let uvb_m = uvb.pp("model");
        uvb_m.pp("embed_tokens").add(&self.embed_tokens);
        uvb_m.pp("norm").add(&self.norm);

        for (layer_idx, layer) in self.layers.iter().enumerate() {
            let uvb_l = uvb_m.pp("layers").pp(layer_idx);
            uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
            uvb_l
                .pp("post_attention_layernorm")
                .add(&layer.post_attention_layernorm);
        }

        uvb.to_safetensors()
    }

    fn residual_tensors_moe_experts_only(&self) -> Option<Vec<(String, Tensor)>> {
        let uvb = UnVarBuilder::new();

        let uvb_m = uvb.pp("model");
        uvb_m.pp("embed_tokens").add(&self.embed_tokens);
        uvb_m.pp("norm").add(&self.norm);

        for (layer_idx, layer) in self.layers.iter().enumerate() {
            let uvb_l = uvb_m.pp("layers").pp(layer_idx);
            uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
            uvb_l
                .pp("post_attention_layernorm")
                .add(&layer.post_attention_layernorm);

            let uvb_attn = uvb_l.pp("self_attn");
            uvb_attn.pp("q_proj").add(&layer.self_attn.q_proj);
            uvb_attn.pp("k_proj").add(&layer.self_attn.k_proj);
            uvb_attn.pp("v_proj").add(&layer.self_attn.v_proj);
            uvb_attn.pp("o_proj").add(&layer.self_attn.o_proj);
        }

        Some(uvb.to_safetensors())
    }
}

impl NormalModel for Model {
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.forward(
            input_ids,
            seqlen_offsets,
            &position_ids,
            context_lens,
            metadata,
            flash_params,
        )
    }
    fn xlora_forward(
        &self,
        _input_ids: &Tensor,
        _input_ids_full: &Tensor,
        _seqlen_offsets: &[usize],
        _seqlen_offsets_full: &[usize],
        _no_kv_cache: bool,
        _non_granular_state: &Option<crate::xlora_models::NonGranularState>,
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _flash_params: &FlashParams,
        _flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        unimplemented!()
    }
    fn cache(&self) -> &EitherCache {
        &self.cache
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        &mut self.cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn is_xlora(&self) -> bool {
        false
    }
    fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
}

impl AnyMoeBaseModelMixin for Model {}