mistralrs_core/vision_models/mllama/text.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::{collections::HashMap, sync::Arc};

use candle_core::{Device, IndexOp, Result, Tensor};
use candle_nn::{Activation, Embedding, Module};
use mistralrs_quant::{
    ColumnParallelLayer, QuantMethod, ReplicatedLayer, RowParallelLayer, ShardedVarBuilder,
};

use crate::{
    attention::SdpaParams,
    device_map::DeviceMapper,
    layers::{embedding, CausalMasker, Llama3RotaryEmbedding, RmsNorm, Sdpa},
    layers_masker::PastKvLenCache,
    paged_attention::{AttentionImplementation, ModelConfigMetadata},
    pipeline::{
        extract_logits, EitherCache, IsqModel, KvCache, NormalCache, NormalLoadingMetadata,
    },
    utils::unvarbuilder::UnVarBuilder,
};

use super::config::MLlamaTextConfig;

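// Feed-forward block shared by both decoder layer types: the usual SwiGLU-style MLP,
// down_proj(act(gate_proj(x)) * up_proj(x)), with gate/up as column-parallel and down as
// row-parallel layers for tensor parallelism. Inputs are round-tripped through the
// quantized activation dtype when the projections are quantized.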
struct MLlamaTextMlp {
    gate_proj: Arc<dyn QuantMethod>,
    up_proj: Arc<dyn QuantMethod>,
    down_proj: Arc<dyn QuantMethod>,
    act: Activation,
}

impl MLlamaTextMlp {
    fn new(
        cfg: &MLlamaTextConfig,
        vb: ShardedVarBuilder,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        Ok(Self {
            gate_proj: ColumnParallelLayer::new(
                cfg.hidden_size,
                cfg.intermediate_size,
                &cfg.quantization_config,
                false,
                comm,
                vb.pp("gate_proj"),
            )?,
            up_proj: ColumnParallelLayer::new(
                cfg.hidden_size,
                cfg.intermediate_size,
                &cfg.quantization_config,
                false,
                comm,
                vb.pp("up_proj"),
            )?,
            down_proj: RowParallelLayer::new(
                cfg.intermediate_size,
                cfg.hidden_size,
                &cfg.quantization_config,
                false,
                comm,
                vb.pp("down_proj"),
            )?,
            act: cfg.hidden_act,
        })
    }

    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.gate_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut res = self.down_proj.forward(
            &self
                .act
                .forward(&self.gate_proj.forward(&xs)?)?
                .broadcast_mul(&self.up_proj.forward(&xs)?)?,
        )?;
        if self.gate_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

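// Causal self-attention with Llama 3 RoPE and grouped-query attention. The head counts
// are divided across the tensor-parallel communicator, and K/V are appended to this
// layer's `KvCache` before the SDPA call.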
struct MLlamaTextSelfAttention {
    q_proj: Arc<dyn QuantMethod>,
    k_proj: Arc<dyn QuantMethod>,
    v_proj: Arc<dyn QuantMethod>,
    o_proj: Arc<dyn QuantMethod>,
    sdpa_params: SdpaParams,
    rope: Arc<Llama3RotaryEmbedding>,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
}

impl MLlamaTextSelfAttention {
    fn new(
        cfg: &MLlamaTextConfig,
        vb: ShardedVarBuilder,
        rope: Arc<Llama3RotaryEmbedding>,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let head_dim = cfg.hidden_size / cfg.num_attention_heads;

        Ok(Self {
            q_proj: ColumnParallelLayer::new(
                cfg.hidden_size,
                cfg.num_attention_heads * cfg.head_dim(),
                &cfg.quantization_config,
                false,
                comm,
                vb.pp("q_proj"),
            )?,
            k_proj: ColumnParallelLayer::new(
                cfg.hidden_size,
                cfg.num_key_value_heads * cfg.head_dim(),
                &cfg.quantization_config,
                false,
                comm,
                vb.pp("k_proj"),
            )?,
            v_proj: ColumnParallelLayer::new(
                cfg.hidden_size,
                cfg.num_key_value_heads * cfg.head_dim(),
                &cfg.quantization_config,
                false,
                comm,
                vb.pp("v_proj"),
            )?,
            o_proj: RowParallelLayer::new(
                cfg.num_attention_heads * cfg.head_dim(),
                cfg.hidden_size,
                &cfg.quantization_config,
                false,
                comm,
                vb.pp("o_proj"),
            )?,
            sdpa_params: SdpaParams {
                n_kv_groups: cfg.num_attention_heads / cfg.num_key_value_heads,
                softcap: None,
                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                sliding_window: None,
            },
            rope,
            num_heads: cfg.num_attention_heads / comm.world_size(),
            num_kv_heads: (cfg.num_key_value_heads / comm.world_size()).max(1),
            head_dim,
        })
    }

    fn forward(
        &self,
        hidden_states: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut KvCache,
    ) -> Result<Tensor> {
        let (bs, q_len, _) = hidden_states.dims3()?;

        let mut hidden_states = hidden_states.clone();
        let original_dtype = hidden_states.dtype();
        if let Some(t) = self.q_proj.quantized_act_type() {
            hidden_states = hidden_states.to_dtype(t)?;
        }
        let mut q = self.q_proj.forward(&hidden_states)?;
        let mut k = self.k_proj.forward(&hidden_states)?;
        let mut v = self.v_proj.forward(&hidden_states)?;
        if self.q_proj.quantized_act_type().is_some() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

        let (q, k, mut v) = if q_len != 1 {
            let q = q
                .reshape((bs, q_len, self.num_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((bs, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((bs, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((bs, self.num_heads, q_len, self.head_dim))?;
            let k = k.reshape((bs, self.num_kv_heads, q_len, self.head_dim))?;
            let v = v.reshape((bs, self.num_kv_heads, q_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, mut k) = self.rope.forward(&q, &k, seqlen_offsets)?;

        (k, v) = kv_cache.append(&k, &v)?;

        let mut attn_output = Sdpa
            .run_attention(
                &q.contiguous()?,
                &k.contiguous()?,
                &v.contiguous()?,
                attention_mask,
                None,
                &self.sdpa_params,
            )?
            .transpose(1, 2)?
            .contiguous()?
            .reshape((bs, q_len, ()))?
            .to_dtype(q.dtype())?;

        if let Some(t) = self.q_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        let mut res = self.o_proj.forward(&attn_output)?;
        if self.q_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

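// Pre-norm decoder block: RMSNorm -> self-attention -> residual add, then
// RMSNorm -> MLP -> residual add.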
struct MLlamaSelfAttentionDecoderLayer {
    attn: MLlamaTextSelfAttention,
    mlp: MLlamaTextMlp,
    input_layernorm: RmsNorm,
    post_attention_layernorm: RmsNorm,
}

impl MLlamaSelfAttentionDecoderLayer {
    fn new(
        cfg: &MLlamaTextConfig,
        vb: ShardedVarBuilder,
        rope: Arc<Llama3RotaryEmbedding>,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let mlp = MLlamaTextMlp::new(
            cfg,
            mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq),
            comm,
        )?;
        let input_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let post_attention_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        let attn = MLlamaTextSelfAttention::new(
            cfg,
            mapper.set_device(layer_idx, vb.pp("self_attn"), loading_isq),
            rope,
            comm,
        )?;

        Ok(Self {
            attn,
            mlp,
            input_layernorm,
            post_attention_layernorm,
        })
    }

    fn forward(
        &self,
        hidden_states: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut KvCache,
    ) -> Result<Tensor> {
        let residual = hidden_states;

        let mut hidden_states = self.input_layernorm.forward(hidden_states)?;

        hidden_states =
            self.attn
                .forward(&hidden_states, attention_mask, seqlen_offsets, kv_cache)?;
        hidden_states = (residual + hidden_states)?;

        let residual = &hidden_states;
        let mut hidden_states = self.post_attention_layernorm.forward(&hidden_states)?;
        hidden_states = self.mlp.forward(&hidden_states)?;

        residual + hidden_states
    }
}

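// Cross-attention for the designated cross-attention layers: queries come from the text
// hidden states, while keys/values are projected from `cross_attn_states` (the vision
// encoder outputs supplied by the caller). Q and K are RMS-normalized per head, no RoPE
// is applied, and the attention mask is repeated across heads before SDPA.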
struct MLlamaTextCrossAttention {
    q_proj: Arc<dyn QuantMethod>,
    k_proj: Arc<dyn QuantMethod>,
    v_proj: Arc<dyn QuantMethod>,
    o_proj: Arc<dyn QuantMethod>,
    q_norm: RmsNorm,
    k_norm: RmsNorm,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    sdpa_params: SdpaParams,
}

impl MLlamaTextCrossAttention {
    fn new(
        cfg: &MLlamaTextConfig,
        vb: ShardedVarBuilder,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        Ok(Self {
            q_proj: ColumnParallelLayer::new(
                cfg.hidden_size,
                cfg.num_attention_heads * cfg.head_dim(),
                &cfg.quantization_config,
                false,
                comm,
                vb.pp("q_proj"),
            )?,
            k_proj: ColumnParallelLayer::new(
                cfg.hidden_size,
                cfg.num_key_value_heads * cfg.head_dim(),
                &cfg.quantization_config,
                false,
                comm,
                vb.pp("k_proj"),
            )?,
            v_proj: ColumnParallelLayer::new(
                cfg.hidden_size,
                cfg.num_key_value_heads * cfg.head_dim(),
                &cfg.quantization_config,
                false,
                comm,
                vb.pp("v_proj"),
            )?,
            o_proj: RowParallelLayer::new(
                cfg.num_attention_heads * cfg.head_dim(),
                cfg.hidden_size,
                &cfg.quantization_config,
                false,
                comm,
                vb.pp("o_proj"),
            )?,
            q_norm: RmsNorm::new(
                cfg.head_dim(),
                cfg.rms_norm_eps,
                mapper.set_device(layer_idx, vb.pp("q_norm"), false),
            )?,
            k_norm: RmsNorm::new(
                cfg.head_dim(),
                cfg.rms_norm_eps,
                mapper.set_device(layer_idx, vb.pp("k_norm"), false),
            )?,
            num_heads: cfg.num_attention_heads / comm.world_size(),
            num_kv_heads: (cfg.num_key_value_heads / comm.world_size()).max(1),
            head_dim: cfg.head_dim(),
            sdpa_params: SdpaParams {
                n_kv_groups: cfg.num_attention_heads / cfg.num_key_value_heads,
                softcap: None,
                softmax_scale: 1.0 / (cfg.head_dim() as f32).sqrt(),
                sliding_window: None,
            },
        })
    }

    fn forward(
        &self,
        hidden_states: &Tensor,
        cross_attn_states: Option<&Tensor>,
        attention_mask: Option<&Tensor>,
    ) -> Result<Tensor> {
        let (bs, q_len, _) = hidden_states.dims3()?;

        let mut hidden_states = hidden_states.clone();
        let original_dtype = hidden_states.dtype();
        if let Some(t) = self.q_proj.quantized_act_type() {
            hidden_states = hidden_states.to_dtype(t)?;
        }
        let mut q = self.q_proj.forward(&hidden_states)?;
        if self.q_proj.quantized_act_type().is_some() {
            q = q.to_dtype(original_dtype)?;
        }
        q = q
            .reshape((bs, q_len, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;
        q = self.q_norm.forward(&q)?;

        let (k, v) = if let Some(cross_attn_states) = cross_attn_states {
            let mut cross_attn_states = cross_attn_states.clone();
            let original_dtype = cross_attn_states.dtype();
            if let Some(t) = self.k_proj.quantized_act_type() {
                cross_attn_states = cross_attn_states.to_dtype(t)?;
            }
            let mut k = self.k_proj.forward(&cross_attn_states)?;
            k = k
                .reshape((bs, (), self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            if self.q_proj.quantized_act_type().is_some() {
                k = k.to_dtype(original_dtype)?;
            }
            k = self.k_norm.forward(&k)?;

            let mut v = self.v_proj.forward(&cross_attn_states)?;
            if self.q_proj.quantized_act_type().is_some() {
                v = v.to_dtype(original_dtype)?;
            }
            v = v
                .reshape((bs, (), self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;

            (k, v)
        } else {
            candle_core::bail!("Cross attn cannot find k,v cache or cross attn hidden states!")
        };

        let mut attn_output = Sdpa
            .run_attention(
                &q.contiguous()?,
                &k.contiguous()?,
                &v.contiguous()?,
                attention_mask
                    .map(|m| m.repeat((1, self.num_heads, 1, 1)).unwrap())
                    .as_ref(),
                None,
                &self.sdpa_params,
            )?
            .transpose(1, 2)?
            .contiguous()?
            .reshape((bs, q_len, ()))?
            .to_dtype(q.dtype())?;

        if let Some(t) = self.q_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        let mut res = self.o_proj.forward(&attn_output)?;
        if self.q_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

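// Cross-attention decoder block. The attention and MLP outputs are scaled by learned
// tanh gates (`cross_attn_attn_gate` / `cross_attn_mlp_gate`) before the residual adds,
// and `full_text_row_masked_out_mask`, when provided, zeroes the MLP output for text
// rows that are fully masked out of cross-attention.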
struct MLlamaCrossAttentionDecoderLayer {
    attn: MLlamaTextCrossAttention,
    attn_gate: Tensor,
    mlp: MLlamaTextMlp,
    mlp_gate: Tensor,
    input_layernorm: RmsNorm,
    post_attention_layernorm: RmsNorm,
}

impl MLlamaCrossAttentionDecoderLayer {
    fn new(
        cfg: &MLlamaTextConfig,
        vb: ShardedVarBuilder,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let mlp = MLlamaTextMlp::new(
            cfg,
            mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq),
            comm,
        )?;
        let input_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let post_attention_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        let attn = MLlamaTextCrossAttention::new(
            cfg,
            mapper.set_device(layer_idx, vb.pp("cross_attn"), loading_isq),
            mapper,
            layer_idx,
            comm,
        )?;

        Ok(Self {
            attn,
            mlp,
            input_layernorm,
            post_attention_layernorm,
            attn_gate: mapper
                .set_device(layer_idx, vb.clone(), false)
                .get((1,), "cross_attn_attn_gate")?,
            mlp_gate: mapper
                .set_device(layer_idx, vb.clone(), false)
                .get((1,), "cross_attn_mlp_gate")?,
        })
    }

    fn forward(
        &self,
        hidden_states: &Tensor,
        cross_attn_states: Option<&Tensor>,
        attention_mask: Option<&Tensor>,
        full_text_row_masked_out_mask: Option<&Tensor>,
    ) -> Result<Tensor> {
        let residual = hidden_states;

        let mut hidden_states = self.input_layernorm.forward(hidden_states)?;

        hidden_states = self
            .attn
            .forward(&hidden_states, cross_attn_states, attention_mask)?;
        hidden_states = (residual + hidden_states.broadcast_mul(&self.attn_gate.tanh()?)?)?;

        let residual = &hidden_states;
        let mut hidden_states = self.post_attention_layernorm.forward(&hidden_states)?;
        hidden_states = self.mlp.forward(&hidden_states)?;
        if let Some(full_text_row_masked_out_mask) = full_text_row_masked_out_mask {
            hidden_states = full_text_row_masked_out_mask
                .to_dtype(hidden_states.dtype())?
                .i((.., 0))?
                .broadcast_mul(&hidden_states)?;
        }

        residual + hidden_states.broadcast_mul(&self.mlp_gate.tanh()?)?
    }
}

enum MLlamaDecoderLayer {
    CrossAttn(MLlamaCrossAttentionDecoderLayer),
    SelfAttn(MLlamaSelfAttentionDecoderLayer),
}

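// The MLlama text decoder: token embeddings (note the `vocab_size + 8` rows, leaving
// room for the extra special tokens), a stack of decoder layers where the indices in
// `cfg.cross_attention_layers` are cross-attention blocks and all others are
// self-attention blocks, a final RMSNorm, and an optionally tied `lm_head`.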
pub(super) struct MLlamaTextModel {
    embed_tokens: Embedding,
    lm_head: Arc<dyn QuantMethod>,
    norm: RmsNorm,
    layers: Vec<MLlamaDecoderLayer>,
    pub(crate) cfg: ModelConfigMetadata,
    pub(crate) cache: EitherCache,
    pub(crate) device: Device,
    pub(crate) max_position_embeddings: usize,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
}

impl MLlamaTextModel {
    pub(super) fn new(
        cfg: &MLlamaTextConfig,
        vb: ShardedVarBuilder,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        if let Some(ref quant_cfg) = &cfg.quantization_config {
            tracing::info!(
                "Using {} quantization: {}.",
                quant_cfg.name(),
                quant_cfg.get_bits_name(&vb)
            );
        }
        if !matches!(attention_mechanism, AttentionImplementation::Eager) {
            candle_core::bail!("Expected eager attention implementation");
        }
        let mapper = normal_loading_metadata.mapper;

        let embed_tokens = embedding(
            cfg.vocab_size + 8,
            cfg.hidden_size,
            mapper.set_nm_device(vb.pp("model.embed_tokens"), false),
            &cfg.quantization_config,
        )?;

        let lm_head = if !cfg.tie_word_embeddings {
            ReplicatedLayer::new(
                cfg.hidden_size,
                cfg.vocab_size,
                &cfg.quantization_config,
                false,
                mapper.set_nm_device(vb.pp("lm_head"), false),
            )?
        } else {
            ReplicatedLayer::from_linear(candle_nn::Linear::new(
                mapper.cast_nm_device(embed_tokens.embeddings(), false)?,
                None,
            ))?
        };

        let vb = vb.pp("model");

        let norm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb.pp("norm"), false),
        )?;

        let mut ropes = HashMap::new();
        for layer_idx in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(Llama3RotaryEmbedding::new_mllama3(
                    vb.dtype(),
                    cfg,
                    device,
                    is_gptx,
                )?),
            );
        }

        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        for i in 0..cfg.num_hidden_layers {
            let comm = mapper.get_comm_for(i)?;
            if cfg.cross_attention_layers.contains(&i) {
                layers.push(MLlamaDecoderLayer::CrossAttn(
                    MLlamaCrossAttentionDecoderLayer::new(
                        cfg,
                        vb.pp(format!("layers.{i}")),
                        &*mapper,
                        i,
                        false,
                        &comm,
                    )?,
                ))
            } else {
                let device = mapper
                    .device_for(i, false)
                    .unwrap_or(&normal_loading_metadata.real_device);
                layers.push(MLlamaDecoderLayer::SelfAttn(
                    MLlamaSelfAttentionDecoderLayer::new(
                        cfg,
                        vb.pp(format!("layers.{i}")),
                        ropes
                            .get(&device.location())
                            .expect("No RoPE for device location!")
                            .clone(),
                        &*mapper,
                        i,
                        normal_loading_metadata.loading_isq,
                        &comm,
                    )?,
                ))
            }
        }

        Ok(Self {
            embed_tokens,
            layers,
            norm,
            lm_head,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_attn_heads: cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size(),
                num_kv_heads: (cfg.num_key_value_heads / mapper.get_comm_for(0)?.world_size())
                    .max(1),
                sliding_window: None,
                k_head_dim: cfg.head_dim(),
                v_head_dim: cfg.head_dim(),
            },
            cache: EitherCache::Normal(NormalCache::new(
                cfg.num_hidden_layers,
                cfg.max_position_embeddings,
            )),
            device: normal_loading_metadata.real_device,
            max_position_embeddings: cfg.max_position_embeddings,
            mapper,
        })
    }

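    // Runs the decoder over `input_ids` of shape (batch, seq_len). A causal mask is built
    // from the current KV-cache length; self-attention layers use `seqlen_offsets` and
    // their per-layer cache, while cross-attention layers only run when
    // `cross_attn_states` is provided (expected as (batch, num_image_tokens, hidden_size),
    // judging by the reshapes in `MLlamaTextCrossAttention::forward`). Logits are produced
    // only for the positions selected by `context_lens` via `extract_logits`.
    //
    // A text-only decode step from the calling pipeline would look roughly like
    // (hypothetical sketch):
    //   model.forward(&input_ids, None, None, None, &seqlen_offsets, context_lens)?;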
    #[allow(clippy::too_many_arguments)]
    pub(super) fn forward(
        &self,
        input_ids: &Tensor,
        cross_attn_states: Option<&Tensor>,
        cross_attention_mask: Option<&Tensor>,
        full_text_row_masked_out_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
    ) -> Result<Tensor> {
        let mut hidden_states = self.embed_tokens.forward(input_ids)?;

        let cache = &mut self.cache.normal().0;
        let self_mask = CausalMasker.make_causal_mask_matrix(
            input_ids,
            cache as &dyn PastKvLenCache,
            hidden_states.dtype(),
            self.cfg.num_attn_heads,
        )?;

        for (i, layer) in self.layers.iter().enumerate() {
            hidden_states = self.mapper.map(hidden_states, i)?;
            match layer {
                MLlamaDecoderLayer::SelfAttn(attn) => {
                    hidden_states = attn.forward(
                        &hidden_states,
                        self_mask
                            .as_ref()
                            .map(|m| m.to_device(hidden_states.device()).unwrap())
                            .as_ref(),
                        seqlen_offsets,
                        &mut cache[i],
                    )?;
                }
                MLlamaDecoderLayer::CrossAttn(attn) => {
                    // Cross-attention layers only run when cross-attention states are
                    // available; on the text-only path (no image features and no cached
                    // cross-attention states) they are skipped entirely.
                    if cross_attn_states.is_none() {
                        continue;
                    }
                    hidden_states = attn.forward(
                        &hidden_states,
                        cross_attn_states
                            .as_ref()
                            .map(|x| x.to_device(hidden_states.device()).unwrap())
                            .as_ref(),
                        cross_attention_mask
                            .as_ref()
                            .map(|m| m.to_device(hidden_states.device()).unwrap())
                            .as_ref(),
                        full_text_row_masked_out_mask
                            .as_ref()
                            .map(|m| m.to_device(hidden_states.device()).unwrap())
                            .as_ref(),
                    )?;
                }
            }
        }

        hidden_states = hidden_states.to_device(&self.device)?;
        hidden_states = self.norm.forward(&hidden_states)?;

        hidden_states = self
            .lm_head
            .forward(&extract_logits(&hidden_states, context_lens)?)?;

        Ok(hidden_states)
    }
}

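// ISQ integration: only the self-attention projections are exposed for in-situ
// quantization (the cross-attention projections are left out, matching the commented-out
// pushes below), while `residual_tensors` collects the unquantized tensors, including
// the full cross-attention layers, under their original weight names.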
impl IsqModel for MLlamaTextModel {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        for (i, layer) in self.layers.iter_mut().enumerate() {
            match layer {
                MLlamaDecoderLayer::CrossAttn(_cross) => {
                    // tensors.push((&mut cross.attn.q_proj, Some(i)));
                    // tensors.push((&mut cross.attn.k_proj, Some(i)));
                    // tensors.push((&mut cross.attn.v_proj, Some(i)));
                    // tensors.push((&mut cross.attn.o_proj, Some(i)));
                    // tensors.push((&mut cross.mlp.gate_proj, Some(i)));
                    // tensors.push((&mut cross.mlp.up_proj, Some(i)));
                    // tensors.push((&mut cross.mlp.down_proj, Some(i)));
                }
                MLlamaDecoderLayer::SelfAttn(self_attn) => {
                    tensors.push((&mut self_attn.attn.q_proj, Some(i)));
                    tensors.push((&mut self_attn.attn.k_proj, Some(i)));
                    tensors.push((&mut self_attn.attn.v_proj, Some(i)));
                    tensors.push((&mut self_attn.attn.o_proj, Some(i)));
                    tensors.push((&mut self_attn.mlp.gate_proj, Some(i)));
                    tensors.push((&mut self_attn.mlp.up_proj, Some(i)));
                    tensors.push((&mut self_attn.mlp.down_proj, Some(i)));
                }
            }
        }
        (tensors, &*self.mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("model.embed_tokens").add(&self.embed_tokens);
        uvb.pp("lm_head").add(&self.lm_head);

        let uvb = uvb.pp("model");

        uvb.pp("norm").add(&self.norm);

        for (i, layer) in self.layers.iter().enumerate() {
            let uvb_l = uvb.pp("layers").pp(i);
            match layer {
                MLlamaDecoderLayer::CrossAttn(crossattn) => {
                    // Cross attention layers are not quantized
                    uvb_l
                        .pp("post_attention_layernorm")
                        .add(&crossattn.post_attention_layernorm);
                    uvb_l.pp("input_layernorm").add(&crossattn.input_layernorm);
                    uvb_l.add_tensor("cross_attn_attn_gate", crossattn.attn_gate.clone());
                    uvb_l.add_tensor("cross_attn_mlp_gate", crossattn.mlp_gate.clone());

                    let uvb_attn = uvb_l.pp("cross_attn");
                    uvb_attn.pp("q_proj").add(&crossattn.attn.q_proj);
                    uvb_attn.pp("k_proj").add(&crossattn.attn.k_proj);
                    uvb_attn.pp("v_proj").add(&crossattn.attn.v_proj);
                    uvb_attn.pp("o_proj").add(&crossattn.attn.o_proj);
                    uvb_attn.pp("q_norm").add(&crossattn.attn.q_norm);
                    uvb_attn.pp("k_norm").add(&crossattn.attn.k_norm);

                    let uvb_mlp = uvb_l.pp("mlp");
                    uvb_mlp.pp("gate_proj").add(&crossattn.mlp.gate_proj);
                    uvb_mlp.pp("up_proj").add(&crossattn.mlp.up_proj);
                    uvb_mlp.pp("down_proj").add(&crossattn.mlp.down_proj);
                }
                MLlamaDecoderLayer::SelfAttn(selfattn) => {
                    uvb_l
                        .pp("post_attention_layernorm")
                        .add(&selfattn.post_attention_layernorm);
                    uvb_l.pp("input_layernorm").add(&selfattn.input_layernorm);
                }
            }
        }

        uvb.to_safetensors()
    }
}