// mistralrs_core/models/mixtral.rs

1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3/// Mixtral Model
4/// https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py
5/// https://mistral.ai/news/mixtral-of-experts/
6use candle_core::{DType, Device, Module, Result, Tensor};
7use mistralrs_quant::{
8    ColumnParallelLayer, QuantMethod, QuantizedConfig, ReplicatedLayer, RowParallelLayer,
9    ShardedVarBuilder,
10};
11use serde::{Deserialize, Serialize};
12use std::{collections::HashMap, sync::Arc};
13
14use crate::{
15    amoe::AnyMoeBaseModelMixin,
16    attention::SdpaParams,
17    device_map::DeviceMapper,
18    layers::{self, Activation, CausalMasker, MatMul, RmsNorm, RotaryEmbedding, Sdpa},
19    layers_masker::PastKvLenCache,
20    paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
21    pipeline::{
22        extract_logits,
23        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
24        EitherCache, IsqModel, KvCache, NormalCache, NormalLoadingMetadata, NormalModel,
25    },
26    serde_default_fn,
27    utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
28};
29
// Generates `fn word_emb_default() -> bool { false }`; used as the serde
// default for `Config::tie_word_embeddings` when the key is absent.
serde_default_fn!(bool, word_emb_default, false);
31
/// https://github.com/huggingface/transformers/blob/1a585c1222a56bcaecc070966d558d4a9d862e83/src/transformers/models/mixtral/configuration_mixtral.py#L113
///
/// Deserialized Mixtral `config.json`. Field names must match the upstream
/// HF configuration keys exactly (serde derives rely on them).
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Config {
    pub(crate) vocab_size: usize,
    pub(crate) hidden_size: usize,
    /// Per-expert MLP intermediate size.
    pub(crate) intermediate_size: usize,
    pub(crate) num_hidden_layers: usize,
    pub(crate) num_attention_heads: usize,
    /// GQA: may be smaller than `num_attention_heads`.
    pub(crate) num_key_value_heads: usize,
    pub(crate) hidden_act: Activation,
    pub(crate) max_position_embeddings: usize,
    pub(crate) rms_norm_eps: f64,
    pub(crate) rope_theta: f64,
    /// Optional sliding attention window; also selects the sliding KV cache.
    pub(crate) sliding_window: Option<usize>,
    /// Top-k experts routed per token.
    pub(crate) num_experts_per_tok: usize,
    /// Total experts per MoE block.
    pub(crate) num_local_experts: usize,
    pub(crate) use_flash_attn: bool,
    /// Pre-quantized checkpoint description, if any.
    pub(crate) quantization_config: Option<QuantizedConfig>,
    #[serde(default = "word_emb_default")]
    pub(crate) tie_word_embeddings: bool,
}
53
/// Multi-head self-attention for one decoder layer. Projections are
/// tensor-parallel sharded; `num_heads`/`num_kv_heads` are per-rank counts.
struct Attention {
    q_proj: Arc<dyn QuantMethod>,
    k_proj: Arc<dyn QuantMethod>,
    v_proj: Arc<dyn QuantMethod>,
    o_proj: Arc<dyn QuantMethod>,
    // Heads owned by this rank (already divided by world size).
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    rotary_emb: Arc<RotaryEmbedding>,
    // Present only when PagedAttention is the active attention backend.
    paged_attn: Option<PagedAttention>,
    sdpa_params: SdpaParams,
}
66
impl Attention {
    /// Load the Q/K/V/O projections for one layer from `vb`.
    ///
    /// Q/K/V are column-parallel (heads split across the ranks of `comm`),
    /// O is row-parallel so partial outputs are reduced back to the full
    /// hidden size. The head counts stored on `Self` are per-rank.
    fn new(
        rotary_emb: Arc<RotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        paged_attn: Option<PagedAttention>,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        let head_dim = hidden_sz / num_heads;
        let q_proj = ColumnParallelLayer::new(
            hidden_sz,
            num_heads * head_dim,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("q_proj"),
        )?;
        // K/V can have fewer heads than ranks (GQA); compute an explicit
        // shard so KV heads are distributed/replicated correctly.
        let kv_shard = mistralrs_quant::compute_kv_shard(
            cfg.num_key_value_heads,
            cfg.hidden_size / cfg.num_attention_heads,
            comm,
        );
        let k_proj = ColumnParallelLayer::new_with_shard(
            hidden_sz,
            num_kv_heads * head_dim,
            &cfg.quantization_config,
            false,
            comm,
            kv_shard,
            vb.pp("k_proj"),
        )?;
        let v_proj = ColumnParallelLayer::new_with_shard(
            hidden_sz,
            num_kv_heads * head_dim,
            &cfg.quantization_config,
            false,
            comm,
            kv_shard,
            vb.pp("v_proj"),
        )?;
        let o_proj = RowParallelLayer::new(
            num_heads * head_dim,
            hidden_sz,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("o_proj"),
        )?;
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            // Store per-rank head counts; at least one KV head per rank.
            num_heads: num_heads / comm.world_size(),
            num_kv_heads: (num_kv_heads / comm.world_size()).max(1),
            head_dim,
            rotary_emb,
            paged_attn,
            sdpa_params: SdpaParams {
                n_kv_groups: mistralrs_quant::compute_n_kv_groups(
                    cfg.num_key_value_heads,
                    cfg.num_attention_heads,
                    comm,
                ),
                use_flash_attn: cfg.use_flash_attn,
                softcap: None,
                // Standard 1/sqrt(d) attention scaling.
                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                sliding_window: cfg.sliding_window,
            },
        })
    }

    /// Attention forward pass for a `(b, q_len, hidden)` activation.
    ///
    /// Dispatches to PagedAttention when configured (with a dummy-metadata
    /// fallback used for imatrix collection), otherwise appends to the normal
    /// KV cache and runs SDPA.
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        // Quantized kernels may require a specific activation dtype; cast in,
        // run the matmuls, then cast back to the caller's dtype.
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.q_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut q = MatMul.qmethod_matmul(&xs, &*self.q_proj)?;
        let mut k = MatMul.qmethod_matmul(&xs, &*self.k_proj)?;
        let mut v = MatMul.qmethod_matmul(&xs, &*self.v_proj)?;
        if self.q_proj.quantized_act_type().is_some() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

        let (q, k, v) = if q_len != 1 {
            // Prompt path: reshape to (b, heads, seq, head_dim) via transpose.
            let q = q
                .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            // Decode path (q_len == 1): a direct reshape is equivalent to
            // reshape+transpose since the sequence dim is 1.
            let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            (q, k, v)
        };

        // Apply rotary position embeddings at the given per-sequence offsets.
        let (q, k) = self.rotary_emb.forward(&q, &k, seqlen_offsets)?;

        let mut attn_output = match &self.paged_attn {
            Some(paged_attn) => match metadata {
                Some(((key_cache, value_cache), input_metadata)) => paged_attn.forward(
                    &q,
                    &k,
                    &v,
                    attention_mask,
                    Some(key_cache),
                    Some(value_cache),
                    input_metadata,
                    &self.sdpa_params,
                    Some(flash_params),
                )?,
                None => {
                    // If we don't have metadata, we are most likely generating an imatrix so we don't want to populate that.
                    // Generating the dummy metadata with the assumption that we are not generating text (only processing prompts).
                    let input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
                    // Sanity check.
                    assert!(attention_mask.is_some());
                    paged_attn.forward(
                        &q,
                        &k,
                        &v,
                        attention_mask,
                        None,
                        None,
                        &input_metadata,
                        &self.sdpa_params,
                        Some(flash_params),
                    )?
                }
            },
            None => {
                // Normal path: extend the per-layer KV cache, then SDPA over
                // the full (cached) K/V.
                let (k, v) = kv_cache.append(&k, &v)?;

                Sdpa.run_attention(
                    &q,
                    &k,
                    &v,
                    attention_mask,
                    Some(flash_params),
                    &self.sdpa_params,
                )?
            }
        };

        if let Some(t) = self.q_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        // NOTE(review): masked (prompt) outputs are assumed to be laid out as
        // (b, heads, seq, head_dim) and need a transpose before flattening,
        // while the maskless path is assumed already token-major — this mask
        // based layout convention should be confirmed against the attention
        // backends.
        attn_output = if attention_mask.is_some() {
            attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?
        } else {
            attn_output.reshape((b_sz, q_len, ()))?
        };
        let mut res = MatMul.qmethod_matmul(&attn_output, &*self.o_proj)?;
        if self.q_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}
249
/// One MoE expert: a gated MLP computing `w2(act(w1(x)) * w3(x))`.
#[derive(Clone)]
struct BlockSparseTop2MLP {
    w1: Arc<dyn QuantMethod>,
    w2: Arc<dyn QuantMethod>,
    w3: Arc<dyn QuantMethod>,
    act_fn: Activation,
}
257
258impl BlockSparseTop2MLP {
259    fn new(cfg: &Config, vb: ShardedVarBuilder, comm: &Arc<mistralrs_quant::Comm>) -> Result<Self> {
260        let hidden_sz = cfg.hidden_size;
261        let intermediate_sz = cfg.intermediate_size;
262        let w1 = ColumnParallelLayer::new(
263            hidden_sz,
264            intermediate_sz,
265            &cfg.quantization_config,
266            false,
267            comm,
268            vb.pp("w1"),
269        )?;
270        let w2 = RowParallelLayer::new(
271            intermediate_sz,
272            hidden_sz,
273            &cfg.quantization_config,
274            false,
275            comm,
276            vb.pp("w2"),
277        )?;
278        let w3 = ColumnParallelLayer::new(
279            hidden_sz,
280            intermediate_sz,
281            &cfg.quantization_config,
282            false,
283            comm,
284            vb.pp("w3"),
285        )?;
286        Ok(Self {
287            w1,
288            w2,
289            w3,
290            act_fn: cfg.hidden_act,
291        })
292    }
293}
294
295impl Module for BlockSparseTop2MLP {
296    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
297        let original_dtype = xs.dtype();
298        let mut xs = xs.clone();
299        if let Some(t) = self.w1.quantized_act_type() {
300            xs = xs.to_dtype(t)?;
301        }
302        let lhs = MatMul.qmethod_matmul(&xs, &*self.w1)?.apply(&self.act_fn)?;
303        let rhs = MatMul.qmethod_matmul(&xs, &*self.w3)?;
304        let mut res = MatMul.qmethod_matmul(&(lhs * rhs)?, &*self.w2)?;
305        if self.w1.quantized_act_type().is_some() {
306            res = res.to_dtype(original_dtype)?;
307        }
308        Ok(res)
309    }
310}
311
/// Mixture-of-experts block: a router (`gate`) that picks the top-k experts
/// per token plus the full set of expert MLPs.
#[derive(Clone)]
struct SparseMoeBlock {
    gate: Arc<dyn QuantMethod>,
    experts: Vec<BlockSparseTop2MLP>,
    // k = cfg.num_experts_per_tok (2 for standard Mixtral).
    num_experts_per_tok: usize,
}
318
319impl SparseMoeBlock {
320    fn new(cfg: &Config, vb: ShardedVarBuilder, comm: &Arc<mistralrs_quant::Comm>) -> Result<Self> {
321        let gate = mistralrs_quant::linear_no_bias(
322            cfg.hidden_size,
323            cfg.num_local_experts,
324            &cfg.quantization_config,
325            vb.pp("gate"),
326        )?;
327        let mut experts = Vec::with_capacity(cfg.num_local_experts);
328        let vb = vb.pp("experts");
329        for idx in 0..cfg.num_local_experts {
330            let expert = BlockSparseTop2MLP::new(cfg, vb.pp(idx), comm)?;
331            experts.push(expert)
332        }
333        Ok(SparseMoeBlock {
334            gate,
335            experts,
336            num_experts_per_tok: cfg.num_experts_per_tok,
337        })
338    }
339}
340
impl Module for SparseMoeBlock {
    /// MoE forward pass: route each token to its top-k experts, run each
    /// expert only on the tokens routed to it, and scatter-add the weighted
    /// expert outputs back into a `(b, seq, hidden)` tensor.
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let (b_size, seq_len, hidden_dim) = xs.dims3()?;
        // Flatten batch and sequence into one "token rows" dim.
        let xs = xs.reshape(((), hidden_dim))?;

        // Router matmul, with dtype round-trip for quantized gates.
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.gate.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut router_logits = MatMul.qmethod_matmul(&xs, &*self.gate)?;
        if self.gate.quantized_act_type().is_some() {
            router_logits = router_logits.to_dtype(original_dtype)?;
        }

        let routing_weights = candle_nn::ops::softmax_last_dim(&router_logits)?;

        // In order to extract topk, we extract the data from the tensor and manipulate it
        // directly. Maybe we will want to use some custom ops instead at some point.
        let routing_weights = routing_weights.to_dtype(DType::F32)?.to_vec2::<f32>()?;

        // routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
        // top_x contains the row indexes to evaluate for each expert.
        // selected_rws holds the matching (renormalized) routing weights.
        let mut top_x = vec![vec![]; self.experts.len()];
        let mut selected_rws = vec![vec![]; self.experts.len()];
        for (row_idx, rw) in routing_weights.iter().enumerate() {
            // Sort expert indices by descending routing weight (manual topk).
            let mut dst = (0..rw.len() as u32).collect::<Vec<u32>>();
            dst.sort_by(|&i, &j| rw[j as usize].total_cmp(&rw[i as usize]));
            // First pass: record this row for each selected expert and
            // accumulate the normalizer over the top-k weights only.
            let mut sum_routing_weights = 0f32;
            for &expert_idx in dst.iter().take(self.num_experts_per_tok) {
                let expert_idx = expert_idx as usize;
                let routing_weight = rw[expert_idx];
                sum_routing_weights += routing_weight;
                top_x[expert_idx].push(row_idx as u32);
            }
            // Second pass: store weights renormalized so the top-k sum to 1.
            for &expert_idx in dst.iter().take(self.num_experts_per_tok) {
                let expert_idx = expert_idx as usize;
                let routing_weight = rw[expert_idx];
                selected_rws[expert_idx].push(routing_weight / sum_routing_weights)
            }
        }

        // routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        // expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)

        let mut ys = xs.zeros_like()?;
        for (expert_idx, expert_layer) in self.experts.iter().enumerate() {
            let top_x = &top_x[expert_idx];
            // Skip experts that received no tokens this step.
            if top_x.is_empty() {
                continue;
            }
            let top_x = Tensor::new(top_x.as_slice(), xs.device())?;
            let selected_rws =
                Tensor::new(selected_rws[expert_idx].as_slice(), xs.device())?.reshape(((), 1))?;
            // Index the correct hidden states and compute the expert hidden state for
            // the current expert. We need to make sure to multiply the output hidden
            // states by `routing_weights` on the corresponding tokens (top-1 and top-2)
            let current_state = xs.index_select(&top_x, 0)?.reshape(((), hidden_dim))?;
            // current_hidden_states = expert_layer(current_state, routing_weights[top_x_list, idx_list, None])
            let current_hidden_states = expert_layer.forward(&current_state)?;
            let current_hidden_states = current_hidden_states.broadcast_mul(&selected_rws)?;
            // Scatter-add the weighted expert output back to the token rows.
            ys = ys.index_add(&top_x, &current_hidden_states, 0)?;
        }

        let ys = ys.reshape((b_size, seq_len, hidden_dim))?;
        Ok(ys)
    }
}
409
/// One Mixtral transformer block: pre-norm attention followed by a pre-norm
/// sparse MoE MLP, each with a residual connection.
struct DecoderLayer {
    self_attn: Attention,
    block_sparse_moe: SparseMoeBlock,
    input_layernorm: RmsNorm,
    post_attention_layernorm: RmsNorm,
}
416
impl DecoderLayer {
    /// Load one decoder layer, placing attention and MoE weights on the
    /// device the `mapper` assigns to `layer_idx` (norms are never ISQ'd,
    /// hence `loading_isq` is not forwarded for them).
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<RotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        paged_attn: Option<PagedAttention>,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let self_attn = Attention::new(
            rotary_emb,
            cfg,
            mapper.set_device(layer_idx, vb.pp("self_attn"), loading_isq),
            paged_attn,
            comm,
        )?;
        let block_sparse_moe = SparseMoeBlock::new(
            cfg,
            mapper.set_device(layer_idx, vb.pp("block_sparse_moe"), loading_isq),
            comm,
        )?;
        let input_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let post_attention_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        Ok(Self {
            self_attn,
            block_sparse_moe,
            input_layernorm,
            post_attention_layernorm,
        })
    }

    /// Pre-norm residual block:
    /// `x = x + attn(norm1(x)); x = x + moe(norm2(x))`.
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = self.input_layernorm.forward(xs)?;
        let xs = self.self_attn.forward(
            &xs,
            attention_mask,
            seqlen_offsets,
            kv_cache,
            metadata,
            flash_params,
        )?;
        let xs = (xs + residual)?;
        let residual = &xs;
        // Cast back to the residual dtype in case the MoE ran in a different
        // (quantized-activation) dtype.
        let xs = xs
            .apply(&self.post_attention_layernorm)?
            .apply(&self.block_sparse_moe)?
            .to_dtype(residual.dtype())?;
        residual + xs
    }
}
488
/// Full Mixtral causal-LM: embeddings, decoder stack, final norm and LM head.
pub struct Model {
    embed_tokens: candle_nn::Embedding,
    layers: Vec<DecoderLayer>,
    norm: RmsNorm,
    // Output projection; shares weights with `embed_tokens` when tied.
    lm_head: Arc<dyn QuantMethod>,
    sliding_window: Option<usize>,
    // Device the final norm / lm_head run on (activations are moved here).
    device: Device,
    cache: EitherCache,
    max_seq_len: usize,
    // Maps layers (and their activations) across devices for pipeline split.
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    cfg: ModelConfigMetadata,
}
501
impl Model {
    /// Load the full model from `vb` according to `cfg`, honoring the device
    /// map, optional in-situ quantization and the chosen attention backend.
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        if let Some(ref quant_cfg) = &cfg.quantization_config {
            tracing::info!(
                "Using {} quantization: {}.",
                quant_cfg.quant_method.to_string(),
                quant_cfg.get_bits_name(&vb)
            );
        }
        let mapper = normal_loading_metadata.mapper;
        let vb_m = vb.pp("model");

        let embed_tokens = layers::embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
        )?;
        let head_dim = cfg.hidden_size / cfg.num_attention_heads;
        // Build one RoPE table per distinct device so each mapped layer uses
        // rotary tables resident on its own device.
        let mut ropes = HashMap::new();
        for layer_idx in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(RotaryEmbedding::new(
                    cfg.rope_theta as f32,
                    head_dim,
                    cfg.max_position_embeddings,
                    device,
                    is_gptx,
                    vb_m.dtype(),
                )?),
            );
        }
        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        let vb_l = vb_m.pp("layers");
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        ) {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            let paged_attn = match &attention_mechanism {
                AttentionImplementation::Eager => None,
                AttentionImplementation::PagedAttention => {
                    Some(PagedAttention::new(head_dim, device, None)?)
                }
            };
            // Per-layer communicator for tensor parallelism.
            let comm = mapper.get_comm_for(layer_idx)?;
            let layer = DecoderLayer::new(
                rotary_emb.clone(),
                cfg,
                vb_l.pp(layer_idx),
                &*mapper,
                layer_idx,
                normal_loading_metadata.loading_isq,
                paged_attn,
                &comm,
            )?;
            layers.push(layer)
        }
        let norm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb_m.pp("norm"), false),
        )?;
        // Untied: load a dedicated lm_head (quantization config deliberately
        // `None`). Tied: reuse the embedding matrix as the output projection.
        let lm_head = if !cfg.tie_word_embeddings {
            ReplicatedLayer::new(
                cfg.hidden_size,
                cfg.vocab_size,
                &None,
                false,
                mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
            )?
        } else {
            ReplicatedLayer::from_linear(candle_nn::Linear::new(
                mapper.cast_nm_device(
                    embed_tokens.embeddings(),
                    normal_loading_metadata.loading_isq,
                )?,
                None,
            ))?
        };
        Ok(Self {
            embed_tokens,
            layers,
            norm,
            lm_head,
            sliding_window: cfg.sliding_window,
            device: normal_loading_metadata.real_device,
            // Sliding-window-aware KV cache (falls back to full window when
            // `sliding_window` is None).
            cache: EitherCache::Normal(NormalCache::new_sliding(
                cfg.num_hidden_layers,
                cfg.max_position_embeddings,
                cfg.sliding_window,
            )),
            max_seq_len: cfg.max_position_embeddings,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                // Metadata reports per-rank head counts (divided by TP size).
                num_attn_heads: cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size(),
                num_kv_heads: (cfg.num_key_value_heads / mapper.get_comm_for(0)?.world_size())
                    .max(1),
                sliding_window: cfg.sliding_window,
                k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
                v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            },
            mapper,
        })
    }

    /// Run the decoder stack over `input_ids` and return logits for the
    /// positions selected by `context_lens`.
    pub fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut xs = self.embed_tokens.forward(input_ids)?;
        let cache = &mut self.cache.normal().0;
        // Build the (sliding-window) causal mask. With paged attention the
        // past length comes from `seqlen_offsets`; otherwise from the cache.
        let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
            input_ids,
            metadata
                .as_ref()
                .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
                .unwrap_or(cache as &dyn PastKvLenCache),
            self.sliding_window,
            xs.dtype(),
            self.cfg.num_attn_heads,
        )?;
        // PagedAttention prompt chunking
        let attention_mask = attention_mask.filter(|_| {
            metadata
                .as_ref()
                .map(|(_, meta)| meta.is_first_prompt_chunk)
                .unwrap_or(true)
        });
        for (i, layer) in self.layers.iter().enumerate() {
            // Move activations to this layer's device (pipeline mapping).
            xs = self.mapper.map(xs, i)?;
            xs = layer.forward(
                &xs,
                attention_mask
                    .as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                &mut cache[i],
                metadata
                    .as_ref()
                    .map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
                flash_params,
            )?;
        }
        let xs = xs.to_device(&self.device)?;
        let mut xs = xs.apply(&self.norm)?;
        if let Some(t) = self.lm_head.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        extract_logits(&MatMul.qmethod_matmul(&xs, &*self.lm_head)?, context_lens)
    }
}
677
678impl IsqModel for Model {
679    fn get_layers(
680        &mut self,
681    ) -> (
682        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
683        &dyn DeviceMapper,
684    ) {
685        let mut tensors = Vec::new();
686        tensors.push((&mut self.lm_head, None));
687        for (i, layer) in self.layers.iter_mut().enumerate() {
688            tensors.push((&mut layer.self_attn.q_proj, Some(i)));
689            tensors.push((&mut layer.self_attn.k_proj, Some(i)));
690            tensors.push((&mut layer.self_attn.v_proj, Some(i)));
691            tensors.push((&mut layer.self_attn.o_proj, Some(i)));
692            tensors.push((&mut layer.block_sparse_moe.gate, Some(i)));
693            for expert in &mut layer.block_sparse_moe.experts {
694                tensors.push((&mut expert.w1, Some(i)));
695                tensors.push((&mut expert.w2, Some(i)));
696                tensors.push((&mut expert.w3, Some(i)));
697            }
698        }
699        (tensors, &*self.mapper)
700    }
701
702    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
703        let uvb = UnVarBuilder::new();
704
705        let uvb_m = uvb.pp("model");
706        uvb_m.pp("embed_tokens").add(&self.embed_tokens);
707        uvb_m.pp("norm").add(&self.norm);
708
709        for (layer_idx, layer) in self.layers.iter().enumerate() {
710            let uvb_l = uvb_m.pp("layers").pp(layer_idx);
711            uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
712            uvb_l
713                .pp("post_attention_layernorm")
714                .add(&layer.post_attention_layernorm);
715        }
716
717        uvb.to_safetensors()
718    }
719}
720
721impl NormalModel for Model {
722    fn forward(
723        &self,
724        input_ids: &Tensor,
725        seqlen_offsets: &[usize],
726        context_lens: Vec<(usize, usize)>,
727        _position_ids: Vec<usize>,
728        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
729        flash_params: &FlashParams,
730    ) -> Result<Tensor> {
731        self.forward(
732            input_ids,
733            seqlen_offsets,
734            context_lens,
735            metadata,
736            flash_params,
737        )
738    }
739    fn xlora_forward(
740        &self,
741        _input_ids: &Tensor,
742        _input_ids_full: &Tensor,
743        _seqlen_offsets: &[usize],
744        _seqlen_offsets_full: &[usize],
745        _no_kv_cache: bool,
746        _non_granular_state: &Option<crate::xlora_models::NonGranularState>,
747        _context_lens: Vec<(usize, usize)>,
748        _position_ids: Vec<usize>,
749        _flash_params: &FlashParams,
750        _flash_params_full: &FlashParams,
751    ) -> Result<Tensor> {
752        unimplemented!()
753    }
754    fn cache(&self) -> &EitherCache {
755        &self.cache
756    }
757    fn cache_mut(&mut self) -> &mut EitherCache {
758        &mut self.cache
759    }
760    fn device(&self) -> &Device {
761        &self.device
762    }
763    fn is_xlora(&self) -> bool {
764        false
765    }
766    fn max_seq_len(&self) -> usize {
767        self.max_seq_len
768    }
769    fn config(&self) -> &ModelConfigMetadata {
770        &self.cfg
771    }
772}
773
// Mixtral's MoE layers are native to the architecture, so the trait's default
// implementations are used unchanged.
impl AnyMoeBaseModelMixin for Model {}