mistralrs_core/models/mixtral.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

/// Mixtral Model
/// https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py
/// https://mistral.ai/news/mixtral-of-experts/
use candle_core::{DType, Device, Module, Result, Tensor};
use mistralrs_quant::{
    ColumnParallelLayer, QuantMethod, QuantizedConfig, ReplicatedLayer, RowParallelLayer,
    ShardedVarBuilder,
};
use serde::{Deserialize, Serialize};
use std::{collections::HashMap, sync::Arc};

use crate::{
    amoe::AnyMoeBaseModelMixin,
    attention::SdpaParams,
    device_map::DeviceMapper,
    layers::{self, Activation, CausalMasker, MatMul, RmsNorm, RotaryEmbedding, Sdpa},
    layers_masker::PastKvLenCache,
    paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
    pipeline::{
        extract_logits,
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel, KvCache, NormalCache, NormalLoadingMetadata, NormalModel,
    },
    serde_default_fn,
    utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
};

serde_default_fn!(bool, word_emb_default, false);

/// https://github.com/huggingface/transformers/blob/1a585c1222a56bcaecc070966d558d4a9d862e83/src/transformers/models/mixtral/configuration_mixtral.py#L113
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Config {
    pub(crate) vocab_size: usize,
    pub(crate) hidden_size: usize,
    pub(crate) intermediate_size: usize,
    pub(crate) num_hidden_layers: usize,
    pub(crate) num_attention_heads: usize,
    pub(crate) num_key_value_heads: usize,
    pub(crate) hidden_act: Activation,
    pub(crate) max_position_embeddings: usize,
    pub(crate) rms_norm_eps: f64,
    pub(crate) rope_theta: f64,
    pub(crate) sliding_window: Option<usize>,
    pub(crate) num_experts_per_tok: usize,
    pub(crate) num_local_experts: usize,
    pub(crate) quantization_config: Option<QuantizedConfig>,
    #[serde(default = "word_emb_default")]
    pub(crate) tie_word_embeddings: bool,
}

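// Multi-head self-attention with grouped-query KV heads. Q/K/V are column-parallel
// projections (K/V with an explicit KV shard) and the output projection is row-parallel,
// so heads are split across the tensor-parallel communicator. `paged_attn` is populated
// when the PagedAttention backend is selected.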
struct Attention {
    q_proj: Arc<dyn QuantMethod>,
    k_proj: Arc<dyn QuantMethod>,
    v_proj: Arc<dyn QuantMethod>,
    o_proj: Arc<dyn QuantMethod>,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    rotary_emb: Arc<RotaryEmbedding>,
    paged_attn: Option<PagedAttention>,
    sdpa_params: SdpaParams,
}

impl Attention {
    fn new(
        rotary_emb: Arc<RotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        paged_attn: Option<PagedAttention>,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        let head_dim = hidden_sz / num_heads;
        let q_proj = ColumnParallelLayer::new(
            hidden_sz,
            num_heads * head_dim,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("q_proj"),
        )?;
        let kv_shard = mistralrs_quant::compute_kv_shard(
            cfg.num_key_value_heads,
            cfg.hidden_size / cfg.num_attention_heads,
            comm,
        );
        let k_proj = ColumnParallelLayer::new_with_shard(
            hidden_sz,
            num_kv_heads * head_dim,
            &cfg.quantization_config,
            false,
            comm,
            kv_shard,
            vb.pp("k_proj"),
        )?;
        let v_proj = ColumnParallelLayer::new_with_shard(
            hidden_sz,
            num_kv_heads * head_dim,
            &cfg.quantization_config,
            false,
            comm,
            kv_shard,
            vb.pp("v_proj"),
        )?;
        let o_proj = RowParallelLayer::new(
            num_heads * head_dim,
            hidden_sz,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("o_proj"),
        )?;
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            num_heads: num_heads / comm.world_size(),
            num_kv_heads: (num_kv_heads / comm.world_size()).max(1),
            head_dim,
            rotary_emb,
            paged_attn,
            sdpa_params: SdpaParams {
                n_kv_groups: mistralrs_quant::compute_n_kv_groups(
                    cfg.num_key_value_heads,
                    cfg.num_attention_heads,
                    comm,
                ),
                softcap: None,
                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                sliding_window: cfg.sliding_window,
            },
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.q_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut q = MatMul.qmethod_matmul(&xs, &*self.q_proj)?;
        let mut k = MatMul.qmethod_matmul(&xs, &*self.k_proj)?;
        let mut v = MatMul.qmethod_matmul(&xs, &*self.v_proj)?;
        if self.q_proj.quantized_act_type().is_some() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

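        // For the single-token decode step (q_len == 1) the head dimension can be placed by a
        // plain reshape; the multi-token prompt path needs a reshape plus transpose.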
        let (q, k, v) = if q_len != 1 {
            let q = q
                .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, k) = self.rotary_emb.forward(&q, &k, seqlen_offsets)?;

        let mut attn_output = match &self.paged_attn {
            Some(paged_attn) => match metadata {
                Some(((key_cache, value_cache), input_metadata)) => paged_attn.forward(
                    &q,
                    &k,
                    &v,
                    attention_mask,
                    Some(key_cache),
                    Some(value_cache),
                    input_metadata,
                    &self.sdpa_params,
                    Some(flash_params),
                )?,
                None => {
                    // If there is no metadata, we are most likely generating an imatrix, so we
                    // don't want to populate the KV cache. Generate dummy metadata, assuming we
                    // are not generating text (only processing prompts).
                    let input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
                    // Sanity check.
                    assert!(attention_mask.is_some());
                    paged_attn.forward(
                        &q,
                        &k,
                        &v,
                        attention_mask,
                        None,
                        None,
                        &input_metadata,
                        &self.sdpa_params,
                        Some(flash_params),
                    )?
                }
            },
            None => {
                let (k, v) = kv_cache.append(&k, &v)?;

                Sdpa.run_attention(
                    &q,
                    &k,
                    &v,
                    attention_mask,
                    Some(flash_params),
                    &self.sdpa_params,
                )?
            }
        };

        if let Some(t) = self.q_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        attn_output = if attention_mask.is_some() {
            attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?
        } else {
            attn_output.reshape((b_sz, q_len, ()))?
        };
        let mut res = MatMul.qmethod_matmul(&attn_output, &*self.o_proj)?;
        if self.q_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

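// One expert: a gated feed-forward computing w2(act(w1(x)) * w3(x)), with w1/w3 as
// column-parallel projections and w2 as the row-parallel down projection.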
#[derive(Clone)]
struct BlockSparseTop2MLP {
    w1: Arc<dyn QuantMethod>,
    w2: Arc<dyn QuantMethod>,
    w3: Arc<dyn QuantMethod>,
    act_fn: Activation,
}

impl BlockSparseTop2MLP {
    fn new(cfg: &Config, vb: ShardedVarBuilder, comm: &Arc<mistralrs_quant::Comm>) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let intermediate_sz = cfg.intermediate_size;
        let w1 = ColumnParallelLayer::new(
            hidden_sz,
            intermediate_sz,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("w1"),
        )?;
        let w2 = RowParallelLayer::new(
            intermediate_sz,
            hidden_sz,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("w2"),
        )?;
        let w3 = ColumnParallelLayer::new(
            hidden_sz,
            intermediate_sz,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("w3"),
        )?;
        Ok(Self {
            w1,
            w2,
            w3,
            act_fn: cfg.hidden_act,
        })
    }
}

impl Module for BlockSparseTop2MLP {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.w1.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let lhs = MatMul.qmethod_matmul(&xs, &*self.w1)?.apply(&self.act_fn)?;
        let rhs = MatMul.qmethod_matmul(&xs, &*self.w3)?;
        let mut res = MatMul.qmethod_matmul(&(lhs * rhs)?, &*self.w2)?;
        if self.w1.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

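// The Mixtral sparse MoE block: a linear router (`gate`) scores all `num_local_experts`
// experts per token, the top `num_experts_per_tok` (typically 2 for Mixtral) are run, and
// their outputs are combined using the renormalized routing weights.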
#[derive(Clone)]
struct SparseMoeBlock {
    gate: Arc<dyn QuantMethod>,
    experts: Vec<BlockSparseTop2MLP>,
    num_experts_per_tok: usize,
}

impl SparseMoeBlock {
    fn new(cfg: &Config, vb: ShardedVarBuilder, comm: &Arc<mistralrs_quant::Comm>) -> Result<Self> {
        let gate = mistralrs_quant::linear_no_bias(
            cfg.hidden_size,
            cfg.num_local_experts,
            &cfg.quantization_config,
            vb.pp("gate"),
        )?;
        let mut experts = Vec::with_capacity(cfg.num_local_experts);
        let vb = vb.pp("experts");
        for idx in 0..cfg.num_local_experts {
            let expert = BlockSparseTop2MLP::new(cfg, vb.pp(idx), comm)?;
            experts.push(expert)
        }
        Ok(SparseMoeBlock {
            gate,
            experts,
            num_experts_per_tok: cfg.num_experts_per_tok,
        })
    }
}

impl Module for SparseMoeBlock {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let (b_size, seq_len, hidden_dim) = xs.dims3()?;
        let xs = xs.reshape(((), hidden_dim))?;

        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.gate.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut router_logits = MatMul.qmethod_matmul(&xs, &*self.gate)?;
        if self.gate.quantized_act_type().is_some() {
            router_logits = router_logits.to_dtype(original_dtype)?;
        }

        let routing_weights = candle_nn::ops::softmax_last_dim(&router_logits)?;

        // To extract the top-k experts, we pull the routing weights out of the tensor and
        // manipulate them directly on the host. Maybe we will want to use custom ops instead at some point.
        let routing_weights = routing_weights.to_dtype(DType::F32)?.to_vec2::<f32>()?;

        // routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
        // top_x contains the row indexes to evaluate for each expert.
        let mut top_x = vec![vec![]; self.experts.len()];
        let mut selected_rws = vec![vec![]; self.experts.len()];
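        // Per row (token): sort expert indices by descending routing weight, keep the top
        // `num_experts_per_tok`, and renormalize those weights to sum to 1. For example, with
        // top-2 and weights [0.1, 0.5, 0.3, 0.1], experts 1 and 2 are kept with weights
        // 0.5 / 0.8 = 0.625 and 0.3 / 0.8 = 0.375.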
        for (row_idx, rw) in routing_weights.iter().enumerate() {
            let mut dst = (0..rw.len() as u32).collect::<Vec<u32>>();
            dst.sort_by(|&i, &j| rw[j as usize].total_cmp(&rw[i as usize]));
            let mut sum_routing_weights = 0f32;
            for &expert_idx in dst.iter().take(self.num_experts_per_tok) {
                let expert_idx = expert_idx as usize;
                let routing_weight = rw[expert_idx];
                sum_routing_weights += routing_weight;
                top_x[expert_idx].push(row_idx as u32);
            }
            for &expert_idx in dst.iter().take(self.num_experts_per_tok) {
                let expert_idx = expert_idx as usize;
                let routing_weight = rw[expert_idx];
                selected_rws[expert_idx].push(routing_weight / sum_routing_weights)
            }
        }

        // routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        // expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)

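        // Dispatch: for each expert, gather the rows routed to it, run the expert MLP on that
        // sub-batch, scale by the renormalized routing weights, and scatter-add the results
        // back into `ys` at the original row positions.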
        let mut ys = xs.zeros_like()?;
        for (expert_idx, expert_layer) in self.experts.iter().enumerate() {
            let top_x = &top_x[expert_idx];
            if top_x.is_empty() {
                continue;
            }
            let top_x = Tensor::new(top_x.as_slice(), xs.device())?;
            let selected_rws =
                Tensor::new(selected_rws[expert_idx].as_slice(), xs.device())?.reshape(((), 1))?;
            // Index the correct hidden states and compute the expert hidden state for
            // the current expert. We need to make sure to multiply the output hidden
            // states by `routing_weights` on the corresponding tokens (top-1 and top-2)
            let current_state = xs.index_select(&top_x, 0)?.reshape(((), hidden_dim))?;
            // current_hidden_states = expert_layer(current_state, routing_weights[top_x_list, idx_list, None])
            let current_hidden_states = expert_layer.forward(&current_state)?;
            let current_hidden_states = current_hidden_states.broadcast_mul(&selected_rws)?;
            ys = ys.index_add(&top_x, &current_hidden_states, 0)?;
        }

        let ys = ys.reshape((b_size, seq_len, hidden_dim))?;
        Ok(ys)
    }
}

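// A single decoder block: pre-norm self-attention followed by the pre-norm sparse MoE
// block, each wrapped in a residual connection.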
struct DecoderLayer {
    self_attn: Attention,
    block_sparse_moe: SparseMoeBlock,
    input_layernorm: RmsNorm,
    post_attention_layernorm: RmsNorm,
}

impl DecoderLayer {
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<RotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        paged_attn: Option<PagedAttention>,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let self_attn = Attention::new(
            rotary_emb,
            cfg,
            mapper.set_device(layer_idx, vb.pp("self_attn"), loading_isq),
            paged_attn,
            comm,
        )?;
        let block_sparse_moe = SparseMoeBlock::new(
            cfg,
            mapper.set_device(layer_idx, vb.pp("block_sparse_moe"), loading_isq),
            comm,
        )?;
        let input_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let post_attention_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        Ok(Self {
            self_attn,
            block_sparse_moe,
            input_layernorm,
            post_attention_layernorm,
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = self.input_layernorm.forward(xs)?;
        let xs = self.self_attn.forward(
            &xs,
            attention_mask,
            seqlen_offsets,
            kv_cache,
            metadata,
            flash_params,
        )?;
        let xs = (xs + residual)?;
        let residual = &xs;
        let xs = xs
            .apply(&self.post_attention_layernorm)?
            .apply(&self.block_sparse_moe)?
            .to_dtype(residual.dtype())?;
        residual + xs
    }
}

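/// The Mixtral decoder-only model: token embeddings, a stack of MoE decoder layers, a
/// final RmsNorm, and the (optionally tied) `lm_head` projection to vocabulary logits.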
pub struct Model {
    embed_tokens: candle_nn::Embedding,
    layers: Vec<DecoderLayer>,
    norm: RmsNorm,
    lm_head: Arc<dyn QuantMethod>,
    sliding_window: Option<usize>,
    device: Device,
    cache: EitherCache,
    max_seq_len: usize,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    cfg: ModelConfigMetadata,
}

impl Model {
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        if let Some(ref quant_cfg) = &cfg.quantization_config {
            tracing::info!(
                "Using {} quantization: {}.",
                quant_cfg.name(),
                quant_cfg.get_bits_name(&vb)
            );
        }
        let mapper = normal_loading_metadata.mapper;
        let vb_m = vb.pp("model");

        let embed_tokens = layers::embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
            &cfg.quantization_config,
        )?;
        let head_dim = cfg.hidden_size / cfg.num_attention_heads;
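        // Create one RotaryEmbedding per device location so that all layers mapped to the
        // same device share a single set of precomputed cos/sin tables.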
        let mut ropes = HashMap::new();
        for layer_idx in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(RotaryEmbedding::new(
                    cfg.rope_theta as f32,
                    head_dim,
                    cfg.max_position_embeddings,
                    device,
                    is_gptx,
                    vb_m.dtype(),
                )?),
            );
        }
        let vb_l = vb_m.pp("layers");
        let layers: Vec<DecoderLayer> = NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        )
        .par_iter_if_isq(|layer_idx| {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            let paged_attn = match &attention_mechanism {
                AttentionImplementation::Eager => None,
                AttentionImplementation::PagedAttention => {
                    Some(PagedAttention::new(head_dim, device, None)?)
                }
            };
            let comm = mapper.get_comm_for(layer_idx)?;
            DecoderLayer::new(
                rotary_emb.clone(),
                cfg,
                vb_l.pp(layer_idx),
                &*mapper,
                layer_idx,
                normal_loading_metadata.loading_isq,
                paged_attn,
                &comm,
            )
        })?;
        let norm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb_m.pp("norm"), false),
        )?;
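        // With `tie_word_embeddings`, reuse the embedding matrix as the output projection
        // instead of loading a separate `lm_head` weight.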
        let lm_head = if !cfg.tie_word_embeddings {
            ReplicatedLayer::new(
                cfg.hidden_size,
                cfg.vocab_size,
                &cfg.quantization_config,
                false,
                mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
            )?
        } else {
            ReplicatedLayer::from_linear(candle_nn::Linear::new(
                mapper.cast_nm_device(
                    embed_tokens.embeddings(),
                    normal_loading_metadata.loading_isq,
                )?,
                None,
            ))?
        };
        Ok(Self {
            embed_tokens,
            layers,
            norm,
            lm_head,
            sliding_window: cfg.sliding_window,
            device: normal_loading_metadata.real_device,
            cache: EitherCache::Normal(NormalCache::new_sliding(
                cfg.num_hidden_layers,
                cfg.max_position_embeddings,
                cfg.sliding_window,
            )),
            max_seq_len: cfg.max_position_embeddings,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_attn_heads: cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size(),
                num_kv_heads: (cfg.num_key_value_heads / mapper.get_comm_for(0)?.world_size())
                    .max(1),
                sliding_window: cfg.sliding_window,
                k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
                v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            },
            mapper,
        })
    }

    pub fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut xs = self.embed_tokens.forward(input_ids)?;
        let cache = &mut self.cache.normal().0;
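        // Build the causal (optionally sliding-window) mask. With PagedAttention metadata,
        // the seqlen offsets serve as the past-KV length source instead of the normal cache.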
        let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
            input_ids,
            metadata
                .as_ref()
                .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
                .unwrap_or(cache as &dyn PastKvLenCache),
            self.sliding_window,
            xs.dtype(),
            self.cfg.num_attn_heads,
        )?;
        // PagedAttention prompt chunking
        let attention_mask = attention_mask.filter(|_| {
            metadata
                .as_ref()
                .map(|(_, meta)| meta.is_first_prompt_chunk)
                .unwrap_or(true)
        });
        for (i, layer) in self.layers.iter().enumerate() {
            xs = self.mapper.map(xs, i)?;
            xs = layer.forward(
                &xs,
                attention_mask
                    .as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                &mut cache[i],
                metadata
                    .as_ref()
                    .map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
                flash_params,
            )?;
        }
        let xs = xs.to_device(&self.device)?;
        let mut xs = xs.apply(&self.norm)?;
        if let Some(t) = self.lm_head.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        extract_logits(&MatMul.qmethod_matmul(&xs, &*self.lm_head)?, context_lens)
    }
}

impl IsqModel for Model {
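    // Expose every quantizable linear (attention projections, MoE gate, and expert weights)
    // together with its owning layer index (None for the lm_head) for in-situ quantization.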
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((&mut self.lm_head, None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            tensors.push((&mut layer.self_attn.q_proj, Some(i)));
            tensors.push((&mut layer.self_attn.k_proj, Some(i)));
            tensors.push((&mut layer.self_attn.v_proj, Some(i)));
            tensors.push((&mut layer.self_attn.o_proj, Some(i)));
            tensors.push((&mut layer.block_sparse_moe.gate, Some(i)));
            for expert in &mut layer.block_sparse_moe.experts {
                tensors.push((&mut expert.w1, Some(i)));
                tensors.push((&mut expert.w2, Some(i)));
                tensors.push((&mut expert.w3, Some(i)));
            }
        }
        (tensors, &*self.mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        let uvb_m = uvb.pp("model");
        uvb_m.pp("embed_tokens").add(&self.embed_tokens);
        uvb_m.pp("norm").add(&self.norm);

        for (layer_idx, layer) in self.layers.iter().enumerate() {
            let uvb_l = uvb_m.pp("layers").pp(layer_idx);
            uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
            uvb_l
                .pp("post_attention_layernorm")
                .add(&layer.post_attention_layernorm);
        }

        uvb.to_safetensors()
    }
}

impl NormalModel for Model {
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.forward(
            input_ids,
            seqlen_offsets,
            context_lens,
            metadata,
            flash_params,
        )
    }
    fn xlora_forward(
        &self,
        _input_ids: &Tensor,
        _input_ids_full: &Tensor,
        _seqlen_offsets: &[usize],
        _seqlen_offsets_full: &[usize],
        _no_kv_cache: bool,
        _non_granular_state: &Option<crate::xlora_models::NonGranularState>,
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _flash_params: &FlashParams,
        _flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        unimplemented!()
    }
    fn cache(&self) -> &EitherCache {
        &self.cache
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        &mut self.cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn is_xlora(&self) -> bool {
        false
    }
    fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
}

impl AnyMoeBaseModelMixin for Model {}