mistralrs_core/vision_models/mllama/text.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::{collections::HashMap, sync::Arc};

use candle_core::{Device, IndexOp, Result, Tensor};
use candle_nn::{Activation, Embedding, Module};
use mistralrs_quant::{
    ColumnParallelLayer, QuantMethod, ReplicatedLayer, RowParallelLayer, ShardedVarBuilder,
};

use crate::{
    attention::SdpaParams,
    device_map::DeviceMapper,
    layers::{embedding, CausalMasker, Llama3RotaryEmbedding, RmsNorm, Sdpa},
    layers_masker::PastKvLenCache,
    paged_attention::{AttentionImplementation, ModelConfigMetadata},
    pipeline::{
        extract_logits, EitherCache, IsqModel, KvCache, NormalCache, NormalLoadingMetadata,
    },
    utils::unvarbuilder::UnVarBuilder,
};

use super::config::MLlamaTextConfig;

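/// Gated feed-forward block: `down_proj(act(gate_proj(x)) * up_proj(x))`.
/// `gate_proj`/`up_proj` are column-parallel and `down_proj` is row-parallel so the
/// MLP shards across tensor-parallel ranks; activations are cast to the quantized
/// activation dtype when one is required.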
struct MLlamaTextMlp {
    gate_proj: Arc<dyn QuantMethod>,
    up_proj: Arc<dyn QuantMethod>,
    down_proj: Arc<dyn QuantMethod>,
    act: Activation,
}

impl MLlamaTextMlp {
    fn new(
        cfg: &MLlamaTextConfig,
        vb: ShardedVarBuilder,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        Ok(Self {
            gate_proj: ColumnParallelLayer::new(
                cfg.hidden_size,
                cfg.intermediate_size,
                &cfg.quantization_config,
                false,
                comm,
                vb.pp("gate_proj"),
            )?,
            up_proj: ColumnParallelLayer::new(
                cfg.hidden_size,
                cfg.intermediate_size,
                &cfg.quantization_config,
                false,
                comm,
                vb.pp("up_proj"),
            )?,
            down_proj: RowParallelLayer::new(
                cfg.intermediate_size,
                cfg.hidden_size,
                &cfg.quantization_config,
                false,
                comm,
                vb.pp("down_proj"),
            )?,
            act: cfg.hidden_act,
        })
    }

    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.gate_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut res = self.down_proj.forward(
            &self
                .act
                .forward(&self.gate_proj.forward(&xs)?)?
                .broadcast_mul(&self.up_proj.forward(&xs)?)?,
        )?;
        if self.gate_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

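/// Causal self-attention with Llama 3 RoPE and grouped-query attention
/// (`n_kv_groups` = attention heads / KV heads). Per-rank head counts are divided
/// by the tensor-parallel world size.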
struct MLlamaTextSelfAttention {
    q_proj: Arc<dyn QuantMethod>,
    k_proj: Arc<dyn QuantMethod>,
    v_proj: Arc<dyn QuantMethod>,
    o_proj: Arc<dyn QuantMethod>,
    sdpa_params: SdpaParams,
    rope: Arc<Llama3RotaryEmbedding>,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
}

impl MLlamaTextSelfAttention {
    fn new(
        cfg: &MLlamaTextConfig,
        vb: ShardedVarBuilder,
        rope: Arc<Llama3RotaryEmbedding>,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let head_dim = cfg.hidden_size / cfg.num_attention_heads;

        Ok(Self {
            q_proj: ColumnParallelLayer::new(
                cfg.hidden_size,
                cfg.num_attention_heads * cfg.head_dim(),
                &cfg.quantization_config,
                false,
                comm,
                vb.pp("q_proj"),
            )?,
            k_proj: ColumnParallelLayer::new(
                cfg.hidden_size,
                cfg.num_key_value_heads * cfg.head_dim(),
                &cfg.quantization_config,
                false,
                comm,
                vb.pp("k_proj"),
            )?,
            v_proj: ColumnParallelLayer::new(
                cfg.hidden_size,
                cfg.num_key_value_heads * cfg.head_dim(),
                &cfg.quantization_config,
                false,
                comm,
                vb.pp("v_proj"),
            )?,
            o_proj: RowParallelLayer::new(
                cfg.num_attention_heads * cfg.head_dim(),
                cfg.hidden_size,
                &cfg.quantization_config,
                false,
                comm,
                vb.pp("o_proj"),
            )?,
            sdpa_params: SdpaParams {
                n_kv_groups: cfg.num_attention_heads / cfg.num_key_value_heads,
                use_flash_attn: false,
                softcap: None,
                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                sliding_window: None,
            },
            rope,
            num_heads: cfg.num_attention_heads / comm.world_size(),
            num_kv_heads: (cfg.num_key_value_heads / comm.world_size()).max(1),
            head_dim,
        })
    }

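    /// Projects Q/K/V, applies RoPE at the given sequence offsets, appends K/V to
    /// this layer's cache, then runs SDPA and the output projection. The
    /// `q_len == 1` decode path reshapes directly into `(bs, heads, 1, head_dim)`
    /// so no transpose is needed.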
    fn forward(
        &self,
        hidden_states: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut KvCache,
    ) -> Result<Tensor> {
        let (bs, q_len, _) = hidden_states.dims3()?;

        let mut hidden_states = hidden_states.clone();
        let original_dtype = hidden_states.dtype();
        if let Some(t) = self.q_proj.quantized_act_type() {
            hidden_states = hidden_states.to_dtype(t)?;
        }
        let mut q = self.q_proj.forward(&hidden_states)?;
        let mut k = self.k_proj.forward(&hidden_states)?;
        let mut v = self.v_proj.forward(&hidden_states)?;
        if self.q_proj.quantized_act_type().is_some() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

        let (q, k, mut v) = if q_len != 1 {
            let q = q
                .reshape((bs, q_len, self.num_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((bs, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((bs, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((bs, self.num_heads, q_len, self.head_dim))?;
            let k = k.reshape((bs, self.num_kv_heads, q_len, self.head_dim))?;
            let v = v.reshape((bs, self.num_kv_heads, q_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, mut k) = self.rope.forward(&q, &k, seqlen_offsets)?;

        (k, v) = kv_cache.append(&k, &v)?;

        let mut attn_output = Sdpa
            .run_attention(
                &q.contiguous()?,
                &k.contiguous()?,
                &v.contiguous()?,
                attention_mask,
                None,
                &self.sdpa_params,
            )?
            .transpose(1, 2)?
            .contiguous()?
            .reshape((bs, q_len, ()))?
            .to_dtype(q.dtype())?;

        if let Some(t) = self.q_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        let mut res = self.o_proj.forward(&attn_output)?;
        if self.q_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

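/// Pre-norm self-attention decoder layer:
/// RMSNorm -> self-attention -> residual add, then RMSNorm -> MLP -> residual add.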
struct MLlamaSelfAttentionDecoderLayer {
    attn: MLlamaTextSelfAttention,
    mlp: MLlamaTextMlp,
    input_layernorm: RmsNorm,
    post_attention_layernorm: RmsNorm,
}

impl MLlamaSelfAttentionDecoderLayer {
    fn new(
        cfg: &MLlamaTextConfig,
        vb: ShardedVarBuilder,
        rope: Arc<Llama3RotaryEmbedding>,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let mlp = MLlamaTextMlp::new(
            cfg,
            mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq),
            comm,
        )?;
        let input_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let post_attention_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        let attn = MLlamaTextSelfAttention::new(
            cfg,
            mapper.set_device(layer_idx, vb.pp("self_attn"), loading_isq),
            rope,
            comm,
        )?;

        Ok(Self {
            attn,
            mlp,
            input_layernorm,
            post_attention_layernorm,
        })
    }

    fn forward(
        &self,
        hidden_states: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut KvCache,
    ) -> Result<Tensor> {
        let residual = hidden_states;

        let mut hidden_states = self.input_layernorm.forward(hidden_states)?;

        hidden_states =
            self.attn
                .forward(&hidden_states, attention_mask, seqlen_offsets, kv_cache)?;
        hidden_states = (residual + hidden_states)?;

        let residual = &hidden_states;
        let mut hidden_states = self.post_attention_layernorm.forward(&hidden_states)?;
        hidden_states = self.mlp.forward(&hidden_states)?;

        residual + hidden_states
    }
}

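/// Cross-attention from the text stream onto the vision encoder outputs. Queries
/// come from the text hidden states; keys and values are projected from
/// `cross_attn_states`. Q and K are RMS-normalized per head and no rotary
/// embedding is applied.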
struct MLlamaTextCrossAttention {
    q_proj: Arc<dyn QuantMethod>,
    k_proj: Arc<dyn QuantMethod>,
    v_proj: Arc<dyn QuantMethod>,
    o_proj: Arc<dyn QuantMethod>,
    q_norm: RmsNorm,
    k_norm: RmsNorm,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    sdpa_params: SdpaParams,
}

impl MLlamaTextCrossAttention {
    fn new(
        cfg: &MLlamaTextConfig,
        vb: ShardedVarBuilder,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        Ok(Self {
            q_proj: ColumnParallelLayer::new(
                cfg.hidden_size,
                cfg.num_attention_heads * cfg.head_dim(),
                &cfg.quantization_config,
                false,
                comm,
                vb.pp("q_proj"),
            )?,
            k_proj: ColumnParallelLayer::new(
                cfg.hidden_size,
                cfg.num_key_value_heads * cfg.head_dim(),
                &cfg.quantization_config,
                false,
                comm,
                vb.pp("k_proj"),
            )?,
            v_proj: ColumnParallelLayer::new(
                cfg.hidden_size,
                cfg.num_key_value_heads * cfg.head_dim(),
                &cfg.quantization_config,
                false,
                comm,
                vb.pp("v_proj"),
            )?,
            o_proj: RowParallelLayer::new(
                cfg.num_attention_heads * cfg.head_dim(),
                cfg.hidden_size,
                &cfg.quantization_config,
                false,
                comm,
                vb.pp("o_proj"),
            )?,
            q_norm: RmsNorm::new(
                cfg.head_dim(),
                cfg.rms_norm_eps,
                mapper.set_device(layer_idx, vb.pp("q_norm"), false),
            )?,
            k_norm: RmsNorm::new(
                cfg.head_dim(),
                cfg.rms_norm_eps,
                mapper.set_device(layer_idx, vb.pp("k_norm"), false),
            )?,
            num_heads: cfg.num_attention_heads / comm.world_size(),
            num_kv_heads: (cfg.num_key_value_heads / comm.world_size()).max(1),
            head_dim: cfg.head_dim(),
            sdpa_params: SdpaParams {
                n_kv_groups: cfg.num_attention_heads / cfg.num_key_value_heads,
                use_flash_attn: false,
                softcap: None,
                softmax_scale: 1.0 / (cfg.head_dim() as f32).sqrt(),
                sliding_window: None,
            },
        })
    }

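    /// Bails if `cross_attn_states` is `None`; callers skip cross-attention layers
    /// entirely when no vision states are available. The attention mask is repeated
    /// across the per-rank head dimension before SDPA.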
    fn forward(
        &self,
        hidden_states: &Tensor,
        cross_attn_states: Option<&Tensor>,
        attention_mask: Option<&Tensor>,
    ) -> Result<Tensor> {
        let (bs, q_len, _) = hidden_states.dims3()?;

        let mut hidden_states = hidden_states.clone();
        let original_dtype = hidden_states.dtype();
        if let Some(t) = self.q_proj.quantized_act_type() {
            hidden_states = hidden_states.to_dtype(t)?;
        }
        let mut q = self.q_proj.forward(&hidden_states)?;
        if self.q_proj.quantized_act_type().is_some() {
            q = q.to_dtype(original_dtype)?;
        }
        q = q
            .reshape((bs, q_len, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;
        q = self.q_norm.forward(&q)?;

        let (k, v) = if let Some(cross_attn_states) = cross_attn_states {
            let mut cross_attn_states = cross_attn_states.clone();
            let original_dtype = cross_attn_states.dtype();
            if let Some(t) = self.k_proj.quantized_act_type() {
                cross_attn_states = cross_attn_states.to_dtype(t)?;
            }
            let mut k = self.k_proj.forward(&cross_attn_states)?;
            k = k
                .reshape((bs, (), self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            if self.q_proj.quantized_act_type().is_some() {
                k = k.to_dtype(original_dtype)?;
            }
            k = self.k_norm.forward(&k)?;

            let mut v = self.v_proj.forward(&cross_attn_states)?;
            if self.q_proj.quantized_act_type().is_some() {
                v = v.to_dtype(original_dtype)?;
            }
            v = v
                .reshape((bs, (), self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;

            (k, v)
        } else {
            candle_core::bail!("Cross attn cannot find k,v cache or cross attn hidden states!")
        };

        let mut attn_output = Sdpa
            .run_attention(
                &q.contiguous()?,
                &k.contiguous()?,
                &v.contiguous()?,
                attention_mask
                    .map(|m| m.repeat((1, self.num_heads, 1, 1)).unwrap())
                    .as_ref(),
                None,
                &self.sdpa_params,
            )?
            .transpose(1, 2)?
            .contiguous()?
            .reshape((bs, q_len, ()))?
            .to_dtype(q.dtype())?;

        if let Some(t) = self.q_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        let mut res = self.o_proj.forward(&attn_output)?;
        if self.q_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

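/// Cross-attention decoder layer. The attention and MLP outputs are scaled by the
/// learned `tanh` gates (`cross_attn_attn_gate`, `cross_attn_mlp_gate`) before the
/// residual adds, and `full_text_row_masked_out_mask`, when present, zeroes the MLP
/// output for fully masked-out text rows.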
struct MLlamaCrossAttentionDecoderLayer {
    attn: MLlamaTextCrossAttention,
    attn_gate: Tensor,
    mlp: MLlamaTextMlp,
    mlp_gate: Tensor,
    input_layernorm: RmsNorm,
    post_attention_layernorm: RmsNorm,
}

impl MLlamaCrossAttentionDecoderLayer {
    fn new(
        cfg: &MLlamaTextConfig,
        vb: ShardedVarBuilder,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let mlp = MLlamaTextMlp::new(
            cfg,
            mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq),
            comm,
        )?;
        let input_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let post_attention_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        let attn = MLlamaTextCrossAttention::new(
            cfg,
            mapper.set_device(layer_idx, vb.pp("cross_attn"), loading_isq),
            mapper,
            layer_idx,
            comm,
        )?;

        Ok(Self {
            attn,
            mlp,
            input_layernorm,
            post_attention_layernorm,
            attn_gate: mapper
                .set_device(layer_idx, vb.clone(), false)
                .get((1,), "cross_attn_attn_gate")?,
            mlp_gate: mapper
                .set_device(layer_idx, vb.clone(), false)
                .get((1,), "cross_attn_mlp_gate")?,
        })
    }

    fn forward(
        &self,
        hidden_states: &Tensor,
        cross_attn_states: Option<&Tensor>,
        attention_mask: Option<&Tensor>,
        full_text_row_masked_out_mask: Option<&Tensor>,
    ) -> Result<Tensor> {
        let residual = hidden_states;

        let mut hidden_states = self.input_layernorm.forward(hidden_states)?;

        hidden_states = self
            .attn
            .forward(&hidden_states, cross_attn_states, attention_mask)?;
        hidden_states = (residual + hidden_states.broadcast_mul(&self.attn_gate.tanh()?)?)?;

        let residual = &hidden_states;
        let mut hidden_states = self.post_attention_layernorm.forward(&hidden_states)?;
        hidden_states = self.mlp.forward(&hidden_states)?;
        if let Some(full_text_row_masked_out_mask) = full_text_row_masked_out_mask {
            hidden_states = full_text_row_masked_out_mask
                .to_dtype(hidden_states.dtype())?
                .i((.., 0))?
                .broadcast_mul(&hidden_states)?;
        }

        residual + hidden_states.broadcast_mul(&self.mlp_gate.tanh()?)?
    }
}

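/// A decoder layer is either a standard self-attention layer or one of the
/// interleaved cross-attention layers listed in `cfg.cross_attention_layers`.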
enum MLlamaDecoderLayer {
    CrossAttn(MLlamaCrossAttentionDecoderLayer),
    SelfAttn(MLlamaSelfAttentionDecoderLayer),
}

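/// The MLlama text decoder: token embeddings (sized `vocab_size + 8`), a stack of
/// self- and cross-attention decoder layers, a final RMSNorm, and the LM head,
/// which is tied to the embeddings when `tie_word_embeddings` is set.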
pub(super) struct MLlamaTextModel {
    embed_tokens: Embedding,
    lm_head: Arc<dyn QuantMethod>,
    norm: RmsNorm,
    layers: Vec<MLlamaDecoderLayer>,
    pub(crate) cfg: ModelConfigMetadata,
    pub(crate) cache: EitherCache,
    pub(crate) device: Device,
    pub(crate) max_position_embeddings: usize,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
}

impl MLlamaTextModel {
    pub(super) fn new(
        cfg: &MLlamaTextConfig,
        vb: ShardedVarBuilder,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        if let Some(ref quant_cfg) = &cfg.quantization_config {
            tracing::info!(
                "Using {} quantization: {}.",
                quant_cfg.name(),
                quant_cfg.get_bits_name(&vb)
            );
        }
        if !matches!(attention_mechanism, AttentionImplementation::Eager) {
            candle_core::bail!("Expected eager attention implementation");
        }
        let mapper = normal_loading_metadata.mapper;

        let embed_tokens = embedding(
            cfg.vocab_size + 8,
            cfg.hidden_size,
            mapper.set_nm_device(vb.pp("model.embed_tokens"), false),
            &cfg.quantization_config,
        )?;

        let lm_head = if !cfg.tie_word_embeddings {
            ReplicatedLayer::new(
                cfg.hidden_size,
                cfg.vocab_size,
                &None,
                false,
                mapper.set_nm_device(vb.pp("lm_head"), false),
            )?
        } else {
            ReplicatedLayer::from_linear(candle_nn::Linear::new(
                mapper.cast_nm_device(embed_tokens.embeddings(), false)?,
                None,
            ))?
        };

        let vb = vb.pp("model");

        let norm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb.pp("norm"), false),
        )?;

        let mut ropes = HashMap::new();
        for layer_idx in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(Llama3RotaryEmbedding::new_mllama3(
                    vb.dtype(),
                    cfg,
                    device,
                    is_gptx,
                )?),
            );
        }

        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        for i in 0..cfg.num_hidden_layers {
            let comm = mapper.get_comm_for(i)?;
            if cfg.cross_attention_layers.contains(&i) {
                layers.push(MLlamaDecoderLayer::CrossAttn(
                    MLlamaCrossAttentionDecoderLayer::new(
                        cfg,
                        vb.pp(format!("layers.{i}")),
                        &*mapper,
                        i,
                        false,
                        &comm,
                    )?,
                ))
            } else {
                let device = mapper
                    .device_for(i, false)
                    .unwrap_or(&normal_loading_metadata.real_device);
                layers.push(MLlamaDecoderLayer::SelfAttn(
                    MLlamaSelfAttentionDecoderLayer::new(
                        cfg,
                        vb.pp(format!("layers.{i}")),
                        ropes
                            .get(&device.location())
                            .expect("No RoPE for device location!")
                            .clone(),
                        &*mapper,
                        i,
                        normal_loading_metadata.loading_isq,
                        &comm,
                    )?,
                ))
            }
        }

        Ok(Self {
            embed_tokens,
            layers,
            norm,
            lm_head,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_attn_heads: cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size(),
                num_kv_heads: (cfg.num_key_value_heads / mapper.get_comm_for(0)?.world_size())
                    .max(1),
                sliding_window: None,
                k_head_dim: cfg.head_dim(),
                v_head_dim: cfg.head_dim(),
            },
            cache: EitherCache::Normal(NormalCache::new(
                cfg.num_hidden_layers,
                cfg.max_position_embeddings,
            )),
            device: normal_loading_metadata.real_device,
            max_position_embeddings: cfg.max_position_embeddings,
            mapper,
        })
    }

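    /// Runs the decoder stack. Self-attention layers see the causal mask and the
    /// per-layer KV cache; cross-attention layers are skipped whenever
    /// `cross_attn_states` is `None` (text-only steps). Logits are produced only
    /// for the positions selected by `context_lens`.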
    #[allow(clippy::too_many_arguments)]
    pub(super) fn forward(
        &self,
        input_ids: &Tensor,
        cross_attn_states: Option<&Tensor>,
        cross_attention_mask: Option<&Tensor>,
        full_text_row_masked_out_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
    ) -> Result<Tensor> {
        let mut hidden_states = self.embed_tokens.forward(input_ids)?;

        let cache = &mut self.cache.normal().0;
        let self_mask = CausalMasker.make_causal_mask_matrix(
            input_ids,
            cache as &dyn PastKvLenCache,
            hidden_states.dtype(),
            self.cfg.num_attn_heads,
        )?;

        for (i, layer) in self.layers.iter().enumerate() {
            hidden_states = self.mapper.map(hidden_states, i)?;
            match layer {
                MLlamaDecoderLayer::SelfAttn(attn) => {
                    hidden_states = attn.forward(
                        &hidden_states,
                        self_mask
                            .as_ref()
                            .map(|m| m.to_device(hidden_states.device()).unwrap())
                            .as_ref(),
                        seqlen_offsets,
                        &mut cache[i],
                    )?;
                }
                MLlamaDecoderLayer::CrossAttn(attn) => {
                    // On the text-only path, cross-attention layers must be skipped:
                    // they can only run when cross-attention states (or cached
                    // cross-attention states) are available.
                    if cross_attn_states.is_none() {
                        continue;
                    }
                    hidden_states = attn.forward(
                        &hidden_states,
                        cross_attn_states
                            .as_ref()
                            .map(|x| x.to_device(hidden_states.device()).unwrap())
                            .as_ref(),
                        cross_attention_mask
                            .as_ref()
                            .map(|m| m.to_device(hidden_states.device()).unwrap())
                            .as_ref(),
                        full_text_row_masked_out_mask
                            .as_ref()
                            .map(|m| m.to_device(hidden_states.device()).unwrap())
                            .as_ref(),
                    )?;
                }
            }
        }

        hidden_states = hidden_states.to_device(&self.device)?;
        hidden_states = self.norm.forward(&hidden_states)?;

        hidden_states = self
            .lm_head
            .forward(&extract_logits(&hidden_states, context_lens)?)?;

        Ok(hidden_states)
    }
}

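// ISQ: only self-attention projections are exposed for in-situ quantization.
// Cross-attention layers are kept unquantized and are serialized through
// `residual_tensors` instead.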
impl IsqModel for MLlamaTextModel {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        for (i, layer) in self.layers.iter_mut().enumerate() {
            match layer {
                MLlamaDecoderLayer::CrossAttn(_cross) => {
                    // tensors.push((&mut cross.attn.q_proj, Some(i)));
                    // tensors.push((&mut cross.attn.k_proj, Some(i)));
                    // tensors.push((&mut cross.attn.v_proj, Some(i)));
                    // tensors.push((&mut cross.attn.o_proj, Some(i)));
                    // tensors.push((&mut cross.mlp.gate_proj, Some(i)));
                    // tensors.push((&mut cross.mlp.up_proj, Some(i)));
                    // tensors.push((&mut cross.mlp.down_proj, Some(i)));
                }
                MLlamaDecoderLayer::SelfAttn(self_attn) => {
                    tensors.push((&mut self_attn.attn.q_proj, Some(i)));
                    tensors.push((&mut self_attn.attn.k_proj, Some(i)));
                    tensors.push((&mut self_attn.attn.v_proj, Some(i)));
                    tensors.push((&mut self_attn.attn.o_proj, Some(i)));
                    tensors.push((&mut self_attn.mlp.gate_proj, Some(i)));
                    tensors.push((&mut self_attn.mlp.up_proj, Some(i)));
                    tensors.push((&mut self_attn.mlp.down_proj, Some(i)));
                }
            }
        }
        (tensors, &*self.mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("model.embed_tokens").add(&self.embed_tokens);
        uvb.pp("lm_head").add(&self.lm_head);

        let uvb = uvb.pp("model");

        uvb.pp("norm").add(&self.norm);

        for (i, layer) in self.layers.iter().enumerate() {
            let uvb_l = uvb.pp("layers").pp(i);
            match layer {
                MLlamaDecoderLayer::CrossAttn(crossattn) => {
                    // Cross attention layers are not quantized
                    uvb_l
                        .pp("post_attention_layernorm")
                        .add(&crossattn.post_attention_layernorm);
                    uvb_l.pp("input_layernorm").add(&crossattn.input_layernorm);
                    uvb_l.add_tensor("cross_attn_attn_gate", crossattn.attn_gate.clone());
                    uvb_l.add_tensor("cross_attn_mlp_gate", crossattn.mlp_gate.clone());

                    let uvb_attn = uvb_l.pp("cross_attn");
                    uvb_attn.pp("q_proj").add(&crossattn.attn.q_proj);
                    uvb_attn.pp("k_proj").add(&crossattn.attn.k_proj);
                    uvb_attn.pp("v_proj").add(&crossattn.attn.v_proj);
                    uvb_attn.pp("o_proj").add(&crossattn.attn.o_proj);
                    uvb_attn.pp("q_norm").add(&crossattn.attn.q_norm);
                    uvb_attn.pp("k_norm").add(&crossattn.attn.k_norm);

                    let uvb_mlp = uvb_l.pp("mlp");
                    uvb_mlp.pp("gate_proj").add(&crossattn.mlp.gate_proj);
                    uvb_mlp.pp("up_proj").add(&crossattn.mlp.up_proj);
                    uvb_mlp.pp("down_proj").add(&crossattn.mlp.down_proj);
                }
                MLlamaDecoderLayer::SelfAttn(selfattn) => {
                    uvb_l
                        .pp("post_attention_layernorm")
                        .add(&selfattn.post_attention_layernorm);
                    uvb_l.pp("input_layernorm").add(&selfattn.input_layernorm);
                }
            }
        }

        uvb.to_safetensors()
    }
}