mistralrs_core/vision_models/llama4/text.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use candle_core::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{Embedding, Module};
use mistralrs_quant::{
    distributed::layers::PackedExperts, linear_no_bias, ColumnParallelLayer, QuantMethod,
    QuantizedConfig, ReplicatedLayer, RowParallelLayer, ShardedVarBuilder, SumAllReduce,
};
use std::{collections::HashMap, sync::Arc};

use crate::{
    amoe::AnyMoeBaseModelMixin,
    attention::SdpaParams,
    device_map::DeviceMapper,
    layers::{embedding, Activation, CausalMasker, Llama3RotaryEmbedding, RmsNorm, Sdpa},
    layers_masker::PastKvLenCache,
    ops::{TopKLastDimOp, TopKOutput},
    paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
    pipeline::{
        extract_logits,
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel, KvCache, NormalCache, NormalLoadingMetadata, NormalModel,
    },
    utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
};

use super::config::TextConfig;

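// Self-attention for the Llama 4 text decoder. Q/K/V/O projections are
// column/row-parallel for tensor parallelism. RoPE is applied on all layers
// except every fourth one ((layer_idx + 1) % 4 == 0); those NoPE layers can
// instead scale the queries via the attention temperature tuning in `forward`.
// An optional RmsNorm on Q and K is used on RoPE layers when `use_qk_norm` is set.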
struct CausalSelfAttention {
    q_proj: Arc<dyn QuantMethod>,
    k_proj: Arc<dyn QuantMethod>,
    v_proj: Arc<dyn QuantMethod>,
    o_proj: Arc<dyn QuantMethod>,
    num_attention_heads: usize,
    num_key_value_heads: usize,
    head_dim: usize,
    rotary_emb: Arc<Llama3RotaryEmbedding>,
    max_seq_len: usize,
    paged_attn: Option<PagedAttention>,
    sdpa_params: SdpaParams,
    norm: Option<RmsNorm>,
    use_rope: bool,
    floor_scale: Option<f32>,
    attn_scale: Option<f32>,
    attn_temperature_tuning: Option<f32>,
}

impl CausalSelfAttention {
    #[allow(clippy::too_many_arguments)]
    fn new(
        vb: ShardedVarBuilder,
        cfg: &TextConfig,
        layer_idx: usize,
        loading_isq: bool,
        mapper: &dyn DeviceMapper,
        rope: Arc<Llama3RotaryEmbedding>,
        paged_attn: Option<PagedAttention>,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let size_in = cfg.hidden_size;
        let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
        let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
        let q_proj = ColumnParallelLayer::new(
            size_in,
            size_q,
            &cfg.quantization_config,
            false,
            comm,
            mapper.set_device(layer_idx, vb.pp("q_proj"), loading_isq),
        )?;
        let kv_shard = mistralrs_quant::compute_kv_shard(
            cfg.num_key_value_heads,
            cfg.hidden_size / cfg.num_attention_heads,
            comm,
        );
        let k_proj = ColumnParallelLayer::new_with_shard(
            size_in,
            size_kv,
            &cfg.quantization_config,
            false,
            comm,
            kv_shard,
            mapper.set_device(layer_idx, vb.pp("k_proj"), loading_isq),
        )?;
        let v_proj = ColumnParallelLayer::new_with_shard(
            size_in,
            size_kv,
            &cfg.quantization_config,
            false,
            comm,
            kv_shard,
            mapper.set_device(layer_idx, vb.pp("v_proj"), loading_isq),
        )?;
        let o_proj = RowParallelLayer::new(
            size_q,
            size_in,
            &cfg.quantization_config,
            false,
            comm,
            mapper.set_device(layer_idx, vb.pp("o_proj"), loading_isq),
        )?;
        let use_rope = (layer_idx + 1) % 4 != 0;
        let head_dim = cfg.hidden_size / cfg.num_attention_heads;
        let norm = if cfg.use_qk_norm && use_rope {
            let vb = mapper.set_device(layer_idx, vb, false);
            Some(RmsNorm::from_w(
                Tensor::ones(head_dim, vb.dtype(), vb.device())?,
                1e-6,
            )?)
        } else {
            None
        };

        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            num_attention_heads: cfg.num_attention_heads / comm.world_size(),
            num_key_value_heads: (cfg.num_key_value_heads / comm.world_size()).max(1),
            head_dim,
            rotary_emb: rope,
            max_seq_len: cfg.max_position_embeddings,
            paged_attn,
            sdpa_params: SdpaParams {
                n_kv_groups: mistralrs_quant::compute_n_kv_groups(
                    cfg.num_key_value_heads,
                    cfg.num_attention_heads,
                    comm,
                ),
                use_flash_attn: cfg.use_flash_attn,
                softcap: None,
                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                sliding_window: None,
            },
            norm,
            use_rope,
            floor_scale: cfg.floor_scale,
            attn_scale: cfg.attn_scale,
            attn_temperature_tuning: cfg.attn_temperature_tuning,
        })
    }

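    // Attention forward pass: project to Q/K/V, apply RoPE (on RoPE layers) and
    // the optional QK norm, scale queries on NoPE layers when temperature tuning
    // is enabled, then run either paged attention or SDPA over the (optionally
    // cached) K/V before the output projection.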
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        x: &Tensor,
        position_ids: &Tensor,
        attention_mask: &Option<Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, seq_len, _) = x.dims3()?;

        let mut q = self.q_proj.forward_autocast(x)?;
        let mut k = self.k_proj.forward_autocast(x)?;
        let mut v = self.v_proj.forward_autocast(x)?;

        q = q
            .reshape((b_sz, seq_len, self.num_attention_heads, self.head_dim))?
            .transpose(1, 2)?;
        k = k
            .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
            .transpose(1, 2)?;
        v = v
            .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
            .transpose(1, 2)?;

        if self.use_rope {
            (q, k) = self.rotary_emb.forward(&q, &k, seqlen_offsets)?;
        }

        if let Some(qk_norm) = &self.norm {
            q = qk_norm.forward(&q)?;
            k = qk_norm.forward(&k)?;
        }

        if self.attn_temperature_tuning.is_some() && !self.use_rope {
            let floor_scale = self.floor_scale.unwrap() as f64;
            let attn_scale = self.attn_scale.unwrap() as f64;
            let floor = ((position_ids.to_dtype(DType::F32)? + 1.)? / floor_scale)?.floor()?;
            let attn_scales = (((floor + 1.0)?.log()? * attn_scale)? + 1.0)?;

            q = q
                .to_dtype(DType::F32)?
                .broadcast_mul(&attn_scales.unsqueeze(D::Minus1)?)?
                .to_dtype(q.dtype())?;
        }

        let mut y = match &self.paged_attn {
            Some(paged_attn) => match metadata {
                Some(((key_cache, value_cache), input_metadata)) => paged_attn.forward(
                    &q,
                    &k,
                    &v,
                    attention_mask.clone().as_ref(),
                    Some(key_cache),
                    Some(value_cache),
                    input_metadata,
                    &self.sdpa_params,
                    Some(flash_params),
                )?,
                None => {
                    // If we don't have metadata, we are most likely generating an imatrix, so we don't want to populate the KV cache.
                    // Generate dummy metadata under the assumption that we are not generating text (only processing prompts).
                    let input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
                    // Sanity check.
                    assert!(attention_mask.is_some());
                    paged_attn.forward(
                        &q,
                        &k,
                        &v,
                        attention_mask.clone().as_ref(),
                        None,
                        None,
                        &input_metadata,
                        &self.sdpa_params,
                        Some(flash_params),
                    )?
                }
            },
            None => {
                let (k, v) = kv_cache.append(&k, &v)?;

                Sdpa.run_attention(
                    &q.contiguous()?,
                    &k.contiguous()?,
                    &v.contiguous()?,
                    attention_mask.clone().as_ref(),
                    Some(flash_params),
                    &self.sdpa_params,
                )?
            }
        };

        y = if attention_mask.is_some() {
            y.transpose(1, 2)?.reshape((b_sz, seq_len, ()))?
        } else {
            y.reshape((b_sz, seq_len, ()))?
        };
        self.o_proj.forward_autocast(&y)
    }
}

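// Gated feed-forward block: the gate and up projections are combined through the
// activation and then projected back down. Gate/up are column-parallel and down
// is row-parallel, so the block shards cleanly under tensor parallelism.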
struct Mlp {
    gate: Arc<dyn QuantMethod>,
    up: Arc<dyn QuantMethod>,
    down: Arc<dyn QuantMethod>,
    act: Activation,
}

impl Mlp {
    fn new(
        vb: ShardedVarBuilder,
        hidden_size: usize,
        intermediate_size: usize,
        quantization_config: &Option<QuantizedConfig>,
        hidden_act: Activation,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        Ok(Self {
            gate: ColumnParallelLayer::new(
                hidden_size,
                intermediate_size,
                quantization_config,
                false,
                comm,
                vb.pp("gate_proj"),
            )?,
            up: ColumnParallelLayer::new(
                hidden_size,
                intermediate_size,
                quantization_config,
                false,
                comm,
                vb.pp("up_proj"),
            )?,
            down: RowParallelLayer::new(
                intermediate_size,
                hidden_size,
                quantization_config,
                false,
                comm,
                vb.pp("down_proj"),
            )?,
            act: hidden_act,
        })
    }

    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let lhs = self.gate.forward_autocast(xs)?;
        let rhs = self.up.forward_autocast(xs)?;

        self.down.forward_autocast(&candle_nn::ops::mul_and_act(
            &lhs,
            &rhs,
            self.act.try_into()?,
        )?)
    }
}

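// Routed experts loaded via `PackedExperts`. When the experts are packed into a
// single set of weights (len == 1), tokens are dispatched with a fused gather
// forward; otherwise each token is run through its selected expert one at a
// time. Partial results are combined across ranks with a sum all-reduce.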
struct TextExperts {
    gate_proj: Vec<Arc<dyn QuantMethod>>,
    up_proj: Vec<Arc<dyn QuantMethod>>,
    down_proj: Vec<Arc<dyn QuantMethod>>,
    act: Activation,
    hidden_size: usize,
    sum_all_reduce: SumAllReduce,
}

impl TextExperts {
    fn new(
        vb: ShardedVarBuilder,
        cfg: &TextConfig,
        quantization_config: &Option<QuantizedConfig>,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let PackedExperts {
            gate_proj,
            up_proj,
            down_proj,
        } = PackedExperts::new(
            cfg.num_local_experts,
            cfg.hidden_size,
            cfg.intermediate_size,
            quantization_config,
            false,
            comm,
            vb,
        )?;
        Ok(Self {
            gate_proj,
            up_proj,
            down_proj,
            act: cfg.hidden_act,
            hidden_size: cfg.hidden_size,
            sum_all_reduce: SumAllReduce::new(comm),
        })
    }

    // xs: (bs * seq_len, hidden_size)
    // expert indices: (bs * seq_len)
    fn forward(&self, xs: &Tensor, indices: &Tensor) -> Result<Tensor> {
        let xs = xs.unsqueeze(1)?;

        if self.gate_proj.len() == 1 {
            let gate = self.gate_proj[0].gather_forward_autocast(&xs, indices)?;
            let up = self.up_proj[0].gather_forward_autocast(&xs, indices)?;
            let mut xs = self.down_proj[0]
                .gather_forward_autocast(&(up * gate.apply(&self.act)?)?, indices)?;
            xs = self.sum_all_reduce.sum_all_reduce(&xs)?;
            xs.reshape(((), self.hidden_size))
        } else {
            let indices = indices.to_vec1::<u32>()?;
            let mut results = Vec::new();
            for (tok, id) in indices.into_iter().enumerate() {
                let xs = xs.i(tok)?.reshape((1, self.hidden_size))?;

                let res = {
                    let gate = self.gate_proj[id as usize].forward_autocast(&xs)?;
                    let up = self.up_proj[id as usize].forward_autocast(&xs)?;
                    self.down_proj[id as usize].forward_autocast(&(up * gate.apply(&self.act)?)?)?
                };
                results.push(res);
            }
            let mut xs = Tensor::cat(&results, 0)?;
            xs = self.sum_all_reduce.sum_all_reduce(&xs)?;
            xs.reshape(((), self.hidden_size))
        }
    }
}

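// Mixture-of-experts feed-forward: a linear router selects the top-k experts per
// token, the hidden states are scaled by the sigmoid of the router scores before
// being sent to the routed experts, and a shared-expert MLP output is added to
// the routed output. As written, the score broadcast against the hidden states
// effectively assumes top-1 routing (`num_experts_per_tok == 1`).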
struct TextMoe {
    experts: TextExperts,
    shared_expert: Mlp,
    router: Arc<dyn QuantMethod>,
    topk: usize,
}

impl TextMoe {
    fn new(
        vb: ShardedVarBuilder,
        cfg: &TextConfig,
        quantization_config: &Option<QuantizedConfig>,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let experts = TextExperts::new(vb.pp("experts"), cfg, quantization_config, comm)?;
        let router = linear_no_bias(
            cfg.hidden_size,
            cfg.num_local_experts,
            quantization_config,
            vb.pp("router"),
        )?;
        let shared_expert = Mlp::new(
            vb.pp("shared_expert"),
            cfg.hidden_size,
            cfg.intermediate_size,
            quantization_config,
            cfg.hidden_act,
            comm,
        )?;
        Ok(Self {
            experts,
            shared_expert,
            router,
            topk: cfg.num_experts_per_tok,
        })
    }

    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let (bs, seq_len, hidden_dim) = xs.dims3()?;
        let xs = xs.reshape(((), hidden_dim))?;
        let router_logits = self.router.forward_autocast(&xs)?;

        let TopKOutput {
            values: router_top_value,
            indices: router_indices,
        } = router_logits.topk(self.topk)?;

        let router_scores = candle_nn::ops::sigmoid(&router_top_value.to_dtype(DType::F32)?)?
            .to_dtype(router_top_value.dtype())?;

        let routed_in = xs.broadcast_mul(&router_scores)?;
        let routed_out = self
            .experts
            .forward(&routed_in, &router_indices.squeeze(D::Minus1)?)?
            .reshape((bs, seq_len, hidden_dim))?;
        let out = self
            .shared_expert
            .forward(&xs.reshape((bs, seq_len, hidden_dim))?)?;

        out + routed_out
    }
}

enum MoeOrMlp {
    Mlp(Mlp),
    Moe(TextMoe),
}

impl MoeOrMlp {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        match self {
            Self::Mlp(l) => l.forward(xs),
            Self::Moe(l) => l.forward(xs),
        }
    }
}

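// A single decoder block: RMSNorm -> self-attention -> residual add, then
// RMSNorm -> feed-forward (dense MLP or MoE, depending on the layer) -> residual
// add. RoPE layers use the chunked attention mask; every fourth (NoPE) layer
// uses the full causal mask.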
struct Block {
    rms_1: RmsNorm,
    attn: CausalSelfAttention,
    rms_2: RmsNorm,
    ff: MoeOrMlp,
    use_chunked_attention: bool,
}

impl Block {
    #[allow(clippy::too_many_arguments)]
    fn new(
        vb: ShardedVarBuilder,
        cfg: &TextConfig,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        rope: Arc<Llama3RotaryEmbedding>,
        paged_attn: Option<PagedAttention>,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let use_chunked_attention = (layer_idx + 1) % 4 != 0;
        let attn = CausalSelfAttention::new(
            vb.pp("self_attn"),
            cfg,
            layer_idx,
            loading_isq,
            mapper,
            rope,
            paged_attn,
            comm,
        )?;
        let is_moe_layer = cfg.moe_layers().contains(&layer_idx);
        let ff = if is_moe_layer {
            let moe = TextMoe::new(
                mapper.set_device(layer_idx, vb.pp("feed_forward"), loading_isq),
                cfg,
                &cfg.quantization_config,
                comm,
            )?;
            MoeOrMlp::Moe(moe)
        } else {
            let mlp = Mlp::new(
                mapper.set_device(layer_idx, vb.pp("feed_forward"), loading_isq),
                cfg.hidden_size,
                cfg.intermediate_size_mlp,
                &cfg.quantization_config,
                cfg.hidden_act,
                comm,
            )?;
            MoeOrMlp::Mlp(mlp)
        };
        let rms_1 = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let rms_2 = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        Ok(Self {
            rms_1,
            attn,
            rms_2,
            ff,
            use_chunked_attention,
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        x: &Tensor,
        position_ids: &Tensor,
        attention_mask: &Option<Tensor>,
        chunked_mask: &Option<Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = x;
        let x = self.rms_1.forward(x)?;
        let mask = if self.use_chunked_attention {
            chunked_mask
        } else {
            attention_mask
        };
        let x = (self.attn.forward(
            &x,
            position_ids,
            mask,
            seqlen_offsets,
            kv_cache,
            metadata,
            flash_params,
        )? + residual)?;
        let residual = &x;
        let x = (self.ff.forward(&self.rms_2.forward(&x)?)? + residual)?;
        Ok(x)
    }
}

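// The Llama 4 text decoder: token embeddings, the stack of decoder blocks
// (placed across devices via `DeviceMapper`), a final RMSNorm, and the LM head
// (tied to the embeddings when `tie_word_embeddings` is set).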
pub struct TextModel {
    wte: Embedding,
    blocks: Vec<Block>,
    ln_f: RmsNorm,
    lm_head: Arc<dyn QuantMethod>,
    kv_cache: crate::pipeline::EitherCache,
    device: Device,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    cfg: ModelConfigMetadata,
    attention_chunk_size: usize,
}

impl TextModel {
    pub fn new(
        cfg: &TextConfig,
        vb: ShardedVarBuilder,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        let vb_m = vb.pp("model");
        let vb_lm_head = vb.pp("lm_head");
        Self::new_inner(
            cfg,
            vb_m,
            vb_lm_head,
            is_gptx,
            normal_loading_metadata,
            attention_mechanism,
        )
    }

    pub fn new_inner(
        cfg: &TextConfig,
        vb_m: ShardedVarBuilder,
        vb_lm_head: ShardedVarBuilder,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        if let Some(ref quant_cfg) = &cfg.quantization_config {
            tracing::info!(
                "Using {} quantization: {}.",
                quant_cfg.name(),
                quant_cfg.get_bits_name(&vb_m)
            );
        }
        let mapper = normal_loading_metadata.mapper;

        let wte = embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
            &cfg.quantization_config,
        )?;
        let lm_head = if !cfg.tie_word_embeddings {
            ReplicatedLayer::new(
                cfg.hidden_size,
                cfg.vocab_size,
                &None,
                false,
                mapper.set_nm_device(vb_lm_head, normal_loading_metadata.loading_isq),
            )?
        } else {
            ReplicatedLayer::from_linear(candle_nn::Linear::new(
                mapper.cast_nm_device(wte.embeddings(), normal_loading_metadata.loading_isq)?,
                None,
            ))?
        };
        let ln_f = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb_m.pp("norm"), false),
        )?;
        let head_dim = cfg.hidden_size / cfg.num_attention_heads;
        let mut ropes = HashMap::new();
        for i in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(i, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(Llama3RotaryEmbedding::new_llama4(
                    vb_m.dtype(),
                    cfg,
                    device,
                    is_gptx,
                )?),
            );
        }
        let blocks: Vec<_> = NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading text repeating layers",
            &normal_loading_metadata.multi_progress,
        )
        .into_iter()
        .map(|i| {
            let device = mapper
                .device_for(i, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            let paged_attn = match &attention_mechanism {
                AttentionImplementation::Eager => None,
                AttentionImplementation::PagedAttention => Some(
                    PagedAttention::new(head_dim, device, None)
                        .expect("Failed to create PagedAttention"),
                ),
            };
            let comm = mapper.get_comm_for(i).unwrap();
            Block::new(
                vb_m.pp(format!("layers.{i}")),
                cfg,
                &*mapper,
                i,
                normal_loading_metadata.loading_isq,
                rotary_emb,
                paged_attn,
                &comm,
            )
            .expect("Failed to load block.")
        })
        .collect();

        Ok(Self {
            wte,
            blocks,
            ln_f,
            lm_head,
            kv_cache: EitherCache::Normal(NormalCache::new(
                cfg.num_hidden_layers,
                cfg.max_position_embeddings,
            )),
            device: normal_loading_metadata.real_device,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_kv_heads: (cfg.num_key_value_heads / mapper.get_comm_for(0)?.world_size())
                    .max(1),
                num_attn_heads: cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size(),
                sliding_window: None,
                k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
                v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            },
            mapper,
            attention_chunk_size: cfg.attention_chunk_size,
        })
    }

    pub fn get_input_embeddings(&self, input_ids: &Tensor) -> Result<Tensor> {
        self.wte.forward(input_ids)
    }

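    // Runs the decoder over precomputed input embeddings. Both the full causal
    // mask and the chunked-attention mask (sized by `attention_chunk_size`) are
    // built once and each block selects the one it needs; with PagedAttention the
    // masks are only applied on the first prompt chunk.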
    #[allow(clippy::too_many_arguments)]
    pub fn forward_embeds(
        &self,
        input_ids: &Tensor,
        input_embeds: Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut x = input_embeds;
        let cache = &mut self.kv_cache.normal().0;
        let cache_for_mask = metadata
            .as_ref()
            .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
            .unwrap_or(cache as &dyn PastKvLenCache);

        let position_ids = Tensor::new(
            seqlen_offsets.iter().map(|o| *o as i32).collect::<Vec<_>>(),
            input_ids.device(),
        )?;

        let mask = CausalMasker.make_causal_mask_matrix(
            input_ids,
            cache_for_mask,
            x.dtype(),
            self.blocks[0].attn.num_attention_heads,
        )?;
        let chunked_mask = CausalMasker.make_chunked_mask_matrix(
            input_ids,
            self.attention_chunk_size,
            cache_for_mask,
            x.dtype(),
            self.blocks[0].attn.num_attention_heads,
        )?;
        // PagedAttention prompt chunking
        let mask = mask.filter(|_| {
            metadata
                .as_ref()
                .map(|(_, meta)| meta.is_first_prompt_chunk)
                .unwrap_or(true)
        });
        // PagedAttention prompt chunking
        let chunked_mask = chunked_mask.filter(|_| {
            metadata
                .as_ref()
                .map(|(_, meta)| meta.is_first_prompt_chunk)
                .unwrap_or(true)
        });
        for (block_idx, block) in self.blocks.iter().enumerate() {
            x = self.mapper.map(x, block_idx)?;
            x = block.forward(
                &x,
                &position_ids.to_device(x.device())?,
                &mask.clone().map(|m| m.to_device(x.device()).unwrap()),
                &chunked_mask
                    .clone()
                    .map(|m| m.to_device(x.device()).unwrap()),
                seqlen_offsets,
                &mut cache[block_idx],
                metadata
                    .as_ref()
                    .map(|(kv_cache, metadata)| (kv_cache[block_idx].clone(), *metadata)),
                flash_params,
            )?;
        }
        let mut x = x.to_device(&self.device)?;
        x = self.ln_f.forward(&x)?;
        x = self.lm_head.forward_autocast(&x)?;
        extract_logits(&x, context_lens)
    }

    pub fn residual_tensors_m(&self, uvb_m: UnVarBuilder) -> Vec<(String, Tensor)> {
        uvb_m.pp("embed_tokens").add(&self.wte);
        uvb_m.pp("norm").add(&self.ln_f);

        for (layer_idx, layer) in self.blocks.iter().enumerate() {
            let uvb_l = uvb_m.pp("layers").pp(layer_idx);
            uvb_l.pp("input_layernorm").add(&layer.rms_1);
            uvb_l.pp("post_attention_layernorm").add(&layer.rms_2);
        }

        uvb_m.to_safetensors()
    }
}

impl IsqModel for TextModel {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((&mut self.lm_head, None));
        for (i, layer) in self.blocks.iter_mut().enumerate() {
            tensors.push((&mut layer.attn.q_proj, Some(i)));
            tensors.push((&mut layer.attn.k_proj, Some(i)));
            tensors.push((&mut layer.attn.v_proj, Some(i)));
            tensors.push((&mut layer.attn.o_proj, Some(i)));
            match &mut layer.ff {
                MoeOrMlp::Mlp(x) => {
                    tensors.push((&mut x.gate, Some(i)));
                    tensors.push((&mut x.up, Some(i)));
                    tensors.push((&mut x.down, Some(i)));
                }
                MoeOrMlp::Moe(x) => {
                    tensors.push((&mut x.router, Some(i)));
                    for g in &mut x.experts.gate_proj {
                        tensors.push((g, Some(i)));
                    }
                    for u in &mut x.experts.up_proj {
                        tensors.push((u, Some(i)));
                    }
                    for d in &mut x.experts.down_proj {
                        tensors.push((d, Some(i)));
                    }
                    tensors.push((&mut x.shared_expert.gate, Some(i)));
                    tensors.push((&mut x.shared_expert.up, Some(i)));
                    tensors.push((&mut x.shared_expert.down, Some(i)));
                }
            }
        }
        (tensors, &*self.mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();
        self.residual_tensors_m(uvb.pp("model"))
    }
}

impl NormalModel for TextModel {
    fn forward(
        &self,
        _input_ids: &Tensor,
        _seqlen_offsets: &[usize],
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        _flash_params: &FlashParams,
    ) -> Result<Tensor> {
        unreachable!()
    }
    fn xlora_forward(
        &self,
        _input_ids: &Tensor,
        _input_ids_full: &Tensor,
        _seqlen_offsets: &[usize],
        _seqlen_offsets_full: &[usize],
        _no_kv_cache: bool,
        _non_granular_state: &Option<crate::xlora_models::NonGranularState>,
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _flash_params: &FlashParams,
        _flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        unimplemented!()
    }
    fn cache(&self) -> &crate::pipeline::EitherCache {
        &self.kv_cache
    }
    fn cache_mut(&mut self) -> &mut crate::pipeline::EitherCache {
        &mut self.kv_cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn is_xlora(&self) -> bool {
        false
    }
    fn max_seq_len(&self) -> usize {
        self.blocks[0].attn.max_seq_len
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
}

impl AnyMoeBaseModelMixin for TextModel {}