mistralrs_core/models/mixtral.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

/// Mixtral Model
/// https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py
/// https://mistral.ai/news/mixtral-of-experts/
use candle_core::{DType, Device, Module, Result, Tensor};
use mistralrs_quant::{
    ColumnParallelLayer, QuantMethod, QuantizedConfig, ReplicatedLayer, RowParallelLayer,
    ShardedVarBuilder,
};
use serde::{Deserialize, Serialize};
use std::{collections::HashMap, sync::Arc};

use crate::{
    amoe::AnyMoeBaseModelMixin,
    attention::SdpaParams,
    device_map::DeviceMapper,
    layers::{self, Activation, CausalMasker, MatMul, RmsNorm, RotaryEmbedding, Sdpa},
    layers_masker::PastKvLenCache,
    paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
    pipeline::{
        extract_logits,
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel, KvCache, NormalCache, NormalLoadingMetadata, NormalModel,
    },
    serde_default_fn,
    utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
};

serde_default_fn!(bool, word_emb_default, false);

/// https://github.com/huggingface/transformers/blob/1a585c1222a56bcaecc070966d558d4a9d862e83/src/transformers/models/mixtral/configuration_mixtral.py#L113
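///
/// Configuration mirroring the Hugging Face `MixtralConfig` fields consumed by
/// this implementation. For the released Mixtral-8x7B checkpoints the MoE
/// settings are typically `num_local_experts = 8` and `num_experts_per_tok = 2`,
/// i.e. eight experts per layer with top-2 routing.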
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Config {
    pub(crate) vocab_size: usize,
    pub(crate) hidden_size: usize,
    pub(crate) intermediate_size: usize,
    pub(crate) num_hidden_layers: usize,
    pub(crate) num_attention_heads: usize,
    pub(crate) num_key_value_heads: usize,
    pub(crate) hidden_act: Activation,
    pub(crate) max_position_embeddings: usize,
    pub(crate) rms_norm_eps: f64,
    pub(crate) rope_theta: f64,
    pub(crate) sliding_window: Option<usize>,
    pub(crate) num_experts_per_tok: usize,
    pub(crate) num_local_experts: usize,
    pub(crate) use_flash_attn: bool,
    pub(crate) quantization_config: Option<QuantizedConfig>,
    #[serde(default = "word_emb_default")]
    pub(crate) tie_word_embeddings: bool,
}

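/// Multi-head self-attention with grouped-query attention (separate `num_heads`
/// and `num_kv_heads`), rotary position embeddings, and optional PagedAttention.
/// The Q/K/V projections are column-parallel and the output projection is
/// row-parallel, so the head counts stored here are per tensor-parallel rank.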
struct Attention {
    q_proj: Arc<dyn QuantMethod>,
    k_proj: Arc<dyn QuantMethod>,
    v_proj: Arc<dyn QuantMethod>,
    o_proj: Arc<dyn QuantMethod>,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    rotary_emb: Arc<RotaryEmbedding>,
    paged_attn: Option<PagedAttention>,
    sdpa_params: SdpaParams,
}

impl Attention {
    fn new(
        rotary_emb: Arc<RotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        paged_attn: Option<PagedAttention>,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        let head_dim = hidden_sz / num_heads;
        let q_proj = ColumnParallelLayer::new(
            hidden_sz,
            num_heads * head_dim,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("q_proj"),
        )?;
        let kv_shard = mistralrs_quant::compute_kv_shard(
            cfg.num_key_value_heads,
            cfg.hidden_size / cfg.num_attention_heads,
            comm,
        );
        let k_proj = ColumnParallelLayer::new_with_shard(
            hidden_sz,
            num_kv_heads * head_dim,
            &cfg.quantization_config,
            false,
            comm,
            kv_shard,
            vb.pp("k_proj"),
        )?;
        let v_proj = ColumnParallelLayer::new_with_shard(
            hidden_sz,
            num_kv_heads * head_dim,
            &cfg.quantization_config,
            false,
            comm,
            kv_shard,
            vb.pp("v_proj"),
        )?;
        let o_proj = RowParallelLayer::new(
            num_heads * head_dim,
            hidden_sz,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("o_proj"),
        )?;
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            num_heads: num_heads / comm.world_size(),
            num_kv_heads: (num_kv_heads / comm.world_size()).max(1),
            head_dim,
            rotary_emb,
            paged_attn,
            sdpa_params: SdpaParams {
                n_kv_groups: mistralrs_quant::compute_n_kv_groups(
                    cfg.num_key_value_heads,
                    cfg.num_attention_heads,
                    comm,
                ),
                use_flash_attn: cfg.use_flash_attn,
                softcap: None,
                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                sliding_window: cfg.sliding_window,
            },
        })
    }

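    /// Runs one attention layer over `xs` of shape `(b_sz, q_len, hidden_size)`.
    /// Q/K/V are projected, reshaped to `(b_sz, heads, q_len, head_dim)`, rotated
    /// with RoPE, and attended either with PagedAttention (when configured) or
    /// with the regular SDPA path backed by the per-layer `KvCache`. The result
    /// is projected back through `o_proj`.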
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.q_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut q = MatMul.qmethod_matmul(&xs, &*self.q_proj)?;
        let mut k = MatMul.qmethod_matmul(&xs, &*self.k_proj)?;
        let mut v = MatMul.qmethod_matmul(&xs, &*self.v_proj)?;
        if self.q_proj.quantized_act_type().is_some() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

        let (q, k, v) = if q_len != 1 {
            let q = q
                .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, k) = self.rotary_emb.forward(&q, &k, seqlen_offsets)?;

        let mut attn_output = match &self.paged_attn {
            Some(paged_attn) => match metadata {
                Some(((key_cache, value_cache), input_metadata)) => paged_attn.forward(
                    &q,
                    &k,
                    &v,
                    attention_mask,
                    Some(key_cache),
                    Some(value_cache),
                    input_metadata,
                    &self.sdpa_params,
                    Some(flash_params),
                )?,
                None => {
                    // If we don't have metadata, we are most likely generating an imatrix, so we don't want to populate the KV cache.
                    // Generate dummy metadata under the assumption that we are not generating text (only processing prompts).
                    let input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
                    // Sanity check.
                    assert!(attention_mask.is_some());
                    paged_attn.forward(
                        &q,
                        &k,
                        &v,
                        attention_mask,
                        None,
                        None,
                        &input_metadata,
                        &self.sdpa_params,
                        Some(flash_params),
                    )?
                }
            },
            None => {
                let (k, v) = kv_cache.append(&k, &v)?;

                Sdpa.run_attention(
                    &q,
                    &k,
                    &v,
                    attention_mask,
                    Some(flash_params),
                    &self.sdpa_params,
                )?
            }
        };

        if let Some(t) = self.q_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        attn_output = if attention_mask.is_some() {
            attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?
        } else {
            attn_output.reshape((b_sz, q_len, ()))?
        };
        let mut res = MatMul.qmethod_matmul(&attn_output, &*self.o_proj)?;
        if self.q_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

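/// A single expert MLP: `w2(act_fn(w1(x)) * w3(x))`, i.e. the gated (SwiGLU-style)
/// feed-forward used by each Mixtral expert. `hidden_act` is SiLU for the
/// released Mixtral checkpoints.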
#[derive(Clone)]
struct BlockSparseTop2MLP {
    w1: Arc<dyn QuantMethod>,
    w2: Arc<dyn QuantMethod>,
    w3: Arc<dyn QuantMethod>,
    act_fn: Activation,
}

impl BlockSparseTop2MLP {
    fn new(cfg: &Config, vb: ShardedVarBuilder, comm: &Arc<mistralrs_quant::Comm>) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let intermediate_sz = cfg.intermediate_size;
        let w1 = ColumnParallelLayer::new(
            hidden_sz,
            intermediate_sz,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("w1"),
        )?;
        let w2 = RowParallelLayer::new(
            intermediate_sz,
            hidden_sz,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("w2"),
        )?;
        let w3 = ColumnParallelLayer::new(
            hidden_sz,
            intermediate_sz,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("w3"),
        )?;
        Ok(Self {
            w1,
            w2,
            w3,
            act_fn: cfg.hidden_act,
        })
    }
}

impl Module for BlockSparseTop2MLP {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.w1.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let lhs = MatMul.qmethod_matmul(&xs, &*self.w1)?.apply(&self.act_fn)?;
        let rhs = MatMul.qmethod_matmul(&xs, &*self.w3)?;
        let mut res = MatMul.qmethod_matmul(&(lhs * rhs)?, &*self.w2)?;
        if self.w1.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

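/// The sparse MoE block: a linear `gate` scores all `num_local_experts` experts
/// per token, the top `num_experts_per_tok` are kept, their softmax weights are
/// renormalized to sum to 1, and the token's output is the weighted sum of the
/// selected experts' outputs.
///
/// For example, with 4 experts and top-2 routing, softmaxed gate weights of
/// `[0.1, 0.4, 0.2, 0.3]` select experts 1 and 3 with renormalized weights
/// `0.4 / 0.7` and `0.3 / 0.7`.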
#[derive(Clone)]
struct SparseMoeBlock {
    gate: Arc<dyn QuantMethod>,
    experts: Vec<BlockSparseTop2MLP>,
    num_experts_per_tok: usize,
}

impl SparseMoeBlock {
    fn new(cfg: &Config, vb: ShardedVarBuilder, comm: &Arc<mistralrs_quant::Comm>) -> Result<Self> {
        let gate = mistralrs_quant::linear_no_bias(
            cfg.hidden_size,
            cfg.num_local_experts,
            &cfg.quantization_config,
            vb.pp("gate"),
        )?;
        let mut experts = Vec::with_capacity(cfg.num_local_experts);
        let vb = vb.pp("experts");
        for idx in 0..cfg.num_local_experts {
            let expert = BlockSparseTop2MLP::new(cfg, vb.pp(idx), comm)?;
            experts.push(expert)
        }
        Ok(SparseMoeBlock {
            gate,
            experts,
            num_experts_per_tok: cfg.num_experts_per_tok,
        })
    }
}

impl Module for SparseMoeBlock {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let (b_size, seq_len, hidden_dim) = xs.dims3()?;
        let xs = xs.reshape(((), hidden_dim))?;

        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.gate.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut router_logits = MatMul.qmethod_matmul(&xs, &*self.gate)?;
        if self.gate.quantized_act_type().is_some() {
            router_logits = router_logits.to_dtype(original_dtype)?;
        }

        let routing_weights = candle_nn::ops::softmax_last_dim(&router_logits)?;

        // To extract the top-k experts, we copy the routing weights out of the tensor and
        // manipulate them directly on the CPU. Maybe we will want to use some custom ops instead at some point.
        let routing_weights = routing_weights.to_dtype(DType::F32)?.to_vec2::<f32>()?;

        // routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
        // top_x contains the row indexes to evaluate for each expert.
        let mut top_x = vec![vec![]; self.experts.len()];
        let mut selected_rws = vec![vec![]; self.experts.len()];
        for (row_idx, rw) in routing_weights.iter().enumerate() {
            let mut dst = (0..rw.len() as u32).collect::<Vec<u32>>();
            dst.sort_by(|&i, &j| rw[j as usize].total_cmp(&rw[i as usize]));
            let mut sum_routing_weights = 0f32;
            for &expert_idx in dst.iter().take(self.num_experts_per_tok) {
                let expert_idx = expert_idx as usize;
                let routing_weight = rw[expert_idx];
                sum_routing_weights += routing_weight;
                top_x[expert_idx].push(row_idx as u32);
            }
            for &expert_idx in dst.iter().take(self.num_experts_per_tok) {
                let expert_idx = expert_idx as usize;
                let routing_weight = rw[expert_idx];
                selected_rws[expert_idx].push(routing_weight / sum_routing_weights)
            }
        }

        // routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        // expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)

        let mut ys = xs.zeros_like()?;
        for (expert_idx, expert_layer) in self.experts.iter().enumerate() {
            let top_x = &top_x[expert_idx];
            if top_x.is_empty() {
                continue;
            }
            let top_x = Tensor::new(top_x.as_slice(), xs.device())?;
            let selected_rws =
                Tensor::new(selected_rws[expert_idx].as_slice(), xs.device())?.reshape(((), 1))?;
            // Index the correct hidden states and compute the expert hidden state for
            // the current expert. We need to make sure to multiply the output hidden
            // states by `routing_weights` on the corresponding tokens (top-1 and top-2)
            let current_state = xs.index_select(&top_x, 0)?.reshape(((), hidden_dim))?;
            // current_hidden_states = expert_layer(current_state, routing_weights[top_x_list, idx_list, None])
            let current_hidden_states = expert_layer.forward(&current_state)?;
            let current_hidden_states = current_hidden_states.broadcast_mul(&selected_rws)?;
            ys = ys.index_add(&top_x, &current_hidden_states, 0)?;
        }

        let ys = ys.reshape((b_size, seq_len, hidden_dim))?;
        Ok(ys)
    }
}

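/// One transformer block with pre-norm residual connections:
/// `x = x + self_attn(input_layernorm(x))` followed by
/// `x = x + block_sparse_moe(post_attention_layernorm(x))`.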
struct DecoderLayer {
    self_attn: Attention,
    block_sparse_moe: SparseMoeBlock,
    input_layernorm: RmsNorm,
    post_attention_layernorm: RmsNorm,
}

impl DecoderLayer {
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<RotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        paged_attn: Option<PagedAttention>,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let self_attn = Attention::new(
            rotary_emb,
            cfg,
            mapper.set_device(layer_idx, vb.pp("self_attn"), loading_isq),
            paged_attn,
            comm,
        )?;
        let block_sparse_moe = SparseMoeBlock::new(
            cfg,
            mapper.set_device(layer_idx, vb.pp("block_sparse_moe"), loading_isq),
            comm,
        )?;
        let input_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let post_attention_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        Ok(Self {
            self_attn,
            block_sparse_moe,
            input_layernorm,
            post_attention_layernorm,
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = self.input_layernorm.forward(xs)?;
        let xs = self.self_attn.forward(
            &xs,
            attention_mask,
            seqlen_offsets,
            kv_cache,
            metadata,
            flash_params,
        )?;
        let xs = (xs + residual)?;
        let residual = &xs;
        let xs = xs
            .apply(&self.post_attention_layernorm)?
            .apply(&self.block_sparse_moe)?
            .to_dtype(residual.dtype())?;
        residual + xs
    }
}

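/// The full Mixtral decoder: token embeddings, `num_hidden_layers` MoE decoder
/// layers (device-mapped via `mapper`), a final RMSNorm, and the LM head, which
/// is tied to the embedding matrix when `tie_word_embeddings` is set.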
pub struct Model {
    embed_tokens: candle_nn::Embedding,
    layers: Vec<DecoderLayer>,
    norm: RmsNorm,
    lm_head: Arc<dyn QuantMethod>,
    sliding_window: Option<usize>,
    device: Device,
    cache: EitherCache,
    max_seq_len: usize,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    cfg: ModelConfigMetadata,
}

impl Model {
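    /// Loads the model from a `ShardedVarBuilder`. A `RotaryEmbedding` is built
    /// once per target device, each decoder layer is loaded on the device chosen
    /// by the device mapper, and the KV cache is created via
    /// `NormalCache::new_sliding`, honoring `cfg.sliding_window`.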
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        if let Some(ref quant_cfg) = &cfg.quantization_config {
            tracing::info!(
                "Using {} quantization: {}.",
                quant_cfg.name(),
                quant_cfg.get_bits_name(&vb)
            );
        }
        let mapper = normal_loading_metadata.mapper;
        let vb_m = vb.pp("model");

        let embed_tokens = layers::embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
            &cfg.quantization_config,
        )?;
        let head_dim = cfg.hidden_size / cfg.num_attention_heads;
        let mut ropes = HashMap::new();
        for layer_idx in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(RotaryEmbedding::new(
                    cfg.rope_theta as f32,
                    head_dim,
                    cfg.max_position_embeddings,
                    device,
                    is_gptx,
                    vb_m.dtype(),
                )?),
            );
        }
        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        let vb_l = vb_m.pp("layers");
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        ) {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            let paged_attn = match &attention_mechanism {
                AttentionImplementation::Eager => None,
                AttentionImplementation::PagedAttention => {
                    Some(PagedAttention::new(head_dim, device, None)?)
                }
            };
            let comm = mapper.get_comm_for(layer_idx)?;
            let layer = DecoderLayer::new(
                rotary_emb.clone(),
                cfg,
                vb_l.pp(layer_idx),
                &*mapper,
                layer_idx,
                normal_loading_metadata.loading_isq,
                paged_attn,
                &comm,
            )?;
            layers.push(layer)
        }
        let norm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb_m.pp("norm"), false),
        )?;
        let lm_head = if !cfg.tie_word_embeddings {
            ReplicatedLayer::new(
                cfg.hidden_size,
                cfg.vocab_size,
                &None,
                false,
                mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
            )?
        } else {
            ReplicatedLayer::from_linear(candle_nn::Linear::new(
                mapper.cast_nm_device(
                    embed_tokens.embeddings(),
                    normal_loading_metadata.loading_isq,
                )?,
                None,
            ))?
        };
        Ok(Self {
            embed_tokens,
            layers,
            norm,
            lm_head,
            sliding_window: cfg.sliding_window,
            device: normal_loading_metadata.real_device,
            cache: EitherCache::Normal(NormalCache::new_sliding(
                cfg.num_hidden_layers,
                cfg.max_position_embeddings,
                cfg.sliding_window,
            )),
            max_seq_len: cfg.max_position_embeddings,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_attn_heads: cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size(),
                num_kv_heads: (cfg.num_key_value_heads / mapper.get_comm_for(0)?.world_size())
                    .max(1),
                sliding_window: cfg.sliding_window,
                k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
                v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            },
            mapper,
        })
    }

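    /// Full forward pass: embed `input_ids`, build the (sliding-window) causal
    /// mask, run every decoder layer with its KV cache (moving activations
    /// between devices via the mapper), apply the final norm and LM head, and
    /// return only the logits selected by `context_lens`.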
    pub fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut xs = self.embed_tokens.forward(input_ids)?;
        let cache = &mut self.cache.normal().0;
        let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
            input_ids,
            metadata
                .as_ref()
                .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
                .unwrap_or(cache as &dyn PastKvLenCache),
            self.sliding_window,
            xs.dtype(),
            self.cfg.num_attn_heads,
        )?;
        // PagedAttention prompt chunking
        let attention_mask = attention_mask.filter(|_| {
            metadata
                .as_ref()
                .map(|(_, meta)| meta.is_first_prompt_chunk)
                .unwrap_or(true)
        });
        for (i, layer) in self.layers.iter().enumerate() {
            xs = self.mapper.map(xs, i)?;
            xs = layer.forward(
                &xs,
                attention_mask
                    .as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                &mut cache[i],
                metadata
                    .as_ref()
                    .map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
                flash_params,
            )?;
        }
        let xs = xs.to_device(&self.device)?;
        let mut xs = xs.apply(&self.norm)?;
        if let Some(t) = self.lm_head.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        extract_logits(&MatMul.qmethod_matmul(&xs, &*self.lm_head)?, context_lens)
    }
}

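// ISQ (in-situ quantization) support: the attention projections, the MoE gate,
// and every expert's `w1`/`w2`/`w3` (plus the LM head) are exposed for
// quantization, while the embeddings and norms are reported as residual
// (non-quantized) tensors.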
impl IsqModel for Model {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((&mut self.lm_head, None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            tensors.push((&mut layer.self_attn.q_proj, Some(i)));
            tensors.push((&mut layer.self_attn.k_proj, Some(i)));
            tensors.push((&mut layer.self_attn.v_proj, Some(i)));
            tensors.push((&mut layer.self_attn.o_proj, Some(i)));
            tensors.push((&mut layer.block_sparse_moe.gate, Some(i)));
            for expert in &mut layer.block_sparse_moe.experts {
                tensors.push((&mut expert.w1, Some(i)));
                tensors.push((&mut expert.w2, Some(i)));
                tensors.push((&mut expert.w3, Some(i)));
            }
        }
        (tensors, &*self.mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        let uvb_m = uvb.pp("model");
        uvb_m.pp("embed_tokens").add(&self.embed_tokens);
        uvb_m.pp("norm").add(&self.norm);

        for (layer_idx, layer) in self.layers.iter().enumerate() {
            let uvb_l = uvb_m.pp("layers").pp(layer_idx);
            uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
            uvb_l
                .pp("post_attention_layernorm")
                .add(&layer.post_attention_layernorm);
        }

        uvb.to_safetensors()
    }
}

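// `NormalModel` is the generic interface the pipelines drive; it simply forwards
// to `Model::forward` above. X-LoRA is not wired up for this model here, so
// `xlora_forward` is left `unimplemented!()`.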
impl NormalModel for Model {
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.forward(
            input_ids,
            seqlen_offsets,
            context_lens,
            metadata,
            flash_params,
        )
    }
    fn xlora_forward(
        &self,
        _input_ids: &Tensor,
        _input_ids_full: &Tensor,
        _seqlen_offsets: &[usize],
        _seqlen_offsets_full: &[usize],
        _no_kv_cache: bool,
        _non_granular_state: &Option<crate::xlora_models::NonGranularState>,
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _flash_params: &FlashParams,
        _flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        unimplemented!()
    }
    fn cache(&self) -> &EitherCache {
        &self.cache
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        &mut self.cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn is_xlora(&self) -> bool {
        false
    }
    fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
}

impl AnyMoeBaseModelMixin for Model {}
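
// A minimal, self-contained illustration of the routing arithmetic in
// `SparseMoeBlock::forward` (top-k selection followed by renormalization of the
// softmaxed gate weights). This test is an editorial sketch, not part of the
// original model code; it mirrors the CPU-side routing loop on plain vectors.
#[cfg(test)]
mod routing_example {
    #[test]
    fn top2_weights_are_renormalized() {
        // Softmaxed router weights for one token over 4 experts.
        let rw = [0.1f32, 0.4, 0.2, 0.3];
        // Sort expert indices by weight, descending (same comparator as above).
        let mut dst: Vec<usize> = (0..rw.len()).collect();
        dst.sort_by(|&i, &j| rw[j].total_cmp(&rw[i]));
        // Keep the top-2 experts and renormalize their weights to sum to 1.
        let top = &dst[..2];
        let sum: f32 = top.iter().map(|&i| rw[i]).sum(); // 0.4 + 0.3 = 0.7
        let weights: Vec<f32> = top.iter().map(|&i| rw[i] / sum).collect();
        assert_eq!(top, &[1usize, 3]);
        assert!((weights[0] - 0.4 / 0.7).abs() < 1e-6);
        assert!((weights[1] - 0.3 / 0.7).abs() < 1e-6);
        assert!((weights.iter().sum::<f32>() - 1.0).abs() < 1e-6);
    }
}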