#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::{collections::HashMap, sync::Arc};

use candle_core::{Device, IndexOp, Result, Tensor};
use candle_nn::{Activation, Embedding, Module};
use mistralrs_quant::{
    ColumnParallelLayer, QuantMethod, ReplicatedLayer, RowParallelLayer, ShardedVarBuilder,
};

use crate::{
    attention::SdpaParams,
    device_map::DeviceMapper,
    layers::{embedding, CausalMasker, Llama3RotaryEmbedding, RmsNorm, Sdpa},
    layers_masker::PastKvLenCache,
    paged_attention::{AttentionImplementation, ModelConfigMetadata},
    pipeline::{
        extract_logits, EitherCache, IsqModel, KvCache, NormalCache, NormalLoadingMetadata,
    },
    utils::unvarbuilder::UnVarBuilder,
};

use super::config::MLlamaTextConfig;

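/// Feed-forward block of the MLlama text decoder: a gated MLP computing
/// `down_proj(act(gate_proj(x)) * up_proj(x))`, built from column/row-parallel,
/// optionally quantized projections so it can be sharded over the tensor-parallel comm.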
struct MLlamaTextMlp {
    gate_proj: Arc<dyn QuantMethod>,
    up_proj: Arc<dyn QuantMethod>,
    down_proj: Arc<dyn QuantMethod>,
    act: Activation,
}

impl MLlamaTextMlp {
    fn new(
        cfg: &MLlamaTextConfig,
        vb: ShardedVarBuilder,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        Ok(Self {
            gate_proj: ColumnParallelLayer::new(
                cfg.hidden_size,
                cfg.intermediate_size,
                &cfg.quantization_config,
                false,
                comm,
                vb.pp("gate_proj"),
            )?,
            up_proj: ColumnParallelLayer::new(
                cfg.hidden_size,
                cfg.intermediate_size,
                &cfg.quantization_config,
                false,
                comm,
                vb.pp("up_proj"),
            )?,
            down_proj: RowParallelLayer::new(
                cfg.intermediate_size,
                cfg.hidden_size,
                &cfg.quantization_config,
                false,
                comm,
                vb.pp("down_proj"),
            )?,
            act: cfg.hidden_act,
        })
    }

    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.gate_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut res = self.down_proj.forward(
            &self
                .act
                .forward(&self.gate_proj.forward(&xs)?)?
                .broadcast_mul(&self.up_proj.forward(&xs)?)?,
        )?;
        if self.gate_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

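/// Causal self-attention for the regular (non-cross-attention) decoder layers.
/// Grouped-query attention with Llama 3 rotary embeddings applied to q/k and a
/// per-layer KV cache; the projections are tensor-parallel and may be quantized.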
struct MLlamaTextSelfAttention {
    q_proj: Arc<dyn QuantMethod>,
    k_proj: Arc<dyn QuantMethod>,
    v_proj: Arc<dyn QuantMethod>,
    o_proj: Arc<dyn QuantMethod>,
    sdpa_params: SdpaParams,
    rope: Arc<Llama3RotaryEmbedding>,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
}

impl MLlamaTextSelfAttention {
    fn new(
        cfg: &MLlamaTextConfig,
        vb: ShardedVarBuilder,
        rope: Arc<Llama3RotaryEmbedding>,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let head_dim = cfg.hidden_size / cfg.num_attention_heads;

        Ok(Self {
            q_proj: ColumnParallelLayer::new(
                cfg.hidden_size,
                cfg.num_attention_heads * cfg.head_dim(),
                &cfg.quantization_config,
                false,
                comm,
                vb.pp("q_proj"),
            )?,
            k_proj: ColumnParallelLayer::new(
                cfg.hidden_size,
                cfg.num_key_value_heads * cfg.head_dim(),
                &cfg.quantization_config,
                false,
                comm,
                vb.pp("k_proj"),
            )?,
            v_proj: ColumnParallelLayer::new(
                cfg.hidden_size,
                cfg.num_key_value_heads * cfg.head_dim(),
                &cfg.quantization_config,
                false,
                comm,
                vb.pp("v_proj"),
            )?,
            o_proj: RowParallelLayer::new(
                cfg.num_attention_heads * cfg.head_dim(),
                cfg.hidden_size,
                &cfg.quantization_config,
                false,
                comm,
                vb.pp("o_proj"),
            )?,
            sdpa_params: SdpaParams {
                n_kv_groups: cfg.num_attention_heads / cfg.num_key_value_heads,
                softcap: None,
                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                sliding_window: None,
            },
            rope,
            num_heads: cfg.num_attention_heads / comm.world_size(),
            num_kv_heads: (cfg.num_key_value_heads / comm.world_size()).max(1),
            head_dim,
        })
    }

    fn forward(
        &self,
        hidden_states: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut KvCache,
    ) -> Result<Tensor> {
        let (bs, q_len, _) = hidden_states.dims3()?;

        let mut hidden_states = hidden_states.clone();
        let original_dtype = hidden_states.dtype();
        if let Some(t) = self.q_proj.quantized_act_type() {
            hidden_states = hidden_states.to_dtype(t)?;
        }
        let mut q = self.q_proj.forward(&hidden_states)?;
        let mut k = self.k_proj.forward(&hidden_states)?;
        let mut v = self.v_proj.forward(&hidden_states)?;
        if self.q_proj.quantized_act_type().is_some() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

        let (q, k, mut v) = if q_len != 1 {
            let q = q
                .reshape((bs, q_len, self.num_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((bs, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((bs, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((bs, self.num_heads, q_len, self.head_dim))?;
            let k = k.reshape((bs, self.num_kv_heads, q_len, self.head_dim))?;
            let v = v.reshape((bs, self.num_kv_heads, q_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, mut k) = self.rope.forward(&q, &k, seqlen_offsets)?;

        (k, v) = kv_cache.append(&k, &v)?;

        let mut attn_output = Sdpa
            .run_attention(
                &q.contiguous()?,
                &k.contiguous()?,
                &v.contiguous()?,
                attention_mask,
                None,
                &self.sdpa_params,
            )?
            .transpose(1, 2)?
            .contiguous()?
            .reshape((bs, q_len, ()))?
            .to_dtype(q.dtype())?;

        if let Some(t) = self.q_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        let mut res = self.o_proj.forward(&attn_output)?;
        if self.q_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

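/// Standard pre-norm decoder block: RMSNorm -> self-attention -> residual add,
/// then RMSNorm -> MLP -> residual add.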
struct MLlamaSelfAttentionDecoderLayer {
    attn: MLlamaTextSelfAttention,
    mlp: MLlamaTextMlp,
    input_layernorm: RmsNorm,
    post_attention_layernorm: RmsNorm,
}

impl MLlamaSelfAttentionDecoderLayer {
    fn new(
        cfg: &MLlamaTextConfig,
        vb: ShardedVarBuilder,
        rope: Arc<Llama3RotaryEmbedding>,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let mlp = MLlamaTextMlp::new(
            cfg,
            mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq),
            comm,
        )?;
        let input_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let post_attention_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        let attn = MLlamaTextSelfAttention::new(
            cfg,
            mapper.set_device(layer_idx, vb.pp("self_attn"), loading_isq),
            rope,
            comm,
        )?;

        Ok(Self {
            attn,
            mlp,
            input_layernorm,
            post_attention_layernorm,
        })
    }

    fn forward(
        &self,
        hidden_states: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut KvCache,
    ) -> Result<Tensor> {
        let residual = hidden_states;

        let mut hidden_states = self.input_layernorm.forward(hidden_states)?;

        hidden_states =
            self.attn
                .forward(&hidden_states, attention_mask, seqlen_offsets, kv_cache)?;
        hidden_states = (residual + hidden_states)?;

        let residual = &hidden_states;
        let mut hidden_states = self.post_attention_layernorm.forward(&hidden_states)?;
        hidden_states = self.mlp.forward(&hidden_states)?;

        residual + hidden_states
    }
}

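/// Cross-attention over the cross-attention (vision) states: queries come from the text
/// hidden states, keys/values from `cross_attn_states`. Both q and k are RMS-normalized;
/// no rotary embedding or KV cache is used here.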
struct MLlamaTextCrossAttention {
    q_proj: Arc<dyn QuantMethod>,
    k_proj: Arc<dyn QuantMethod>,
    v_proj: Arc<dyn QuantMethod>,
    o_proj: Arc<dyn QuantMethod>,
    q_norm: RmsNorm,
    k_norm: RmsNorm,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    sdpa_params: SdpaParams,
}

impl MLlamaTextCrossAttention {
    fn new(
        cfg: &MLlamaTextConfig,
        vb: ShardedVarBuilder,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        Ok(Self {
            q_proj: ColumnParallelLayer::new(
                cfg.hidden_size,
                cfg.num_attention_heads * cfg.head_dim(),
                &cfg.quantization_config,
                false,
                comm,
                vb.pp("q_proj"),
            )?,
            k_proj: ColumnParallelLayer::new(
                cfg.hidden_size,
                cfg.num_key_value_heads * cfg.head_dim(),
                &cfg.quantization_config,
                false,
                comm,
                vb.pp("k_proj"),
            )?,
            v_proj: ColumnParallelLayer::new(
                cfg.hidden_size,
                cfg.num_key_value_heads * cfg.head_dim(),
                &cfg.quantization_config,
                false,
                comm,
                vb.pp("v_proj"),
            )?,
            o_proj: RowParallelLayer::new(
                cfg.num_attention_heads * cfg.head_dim(),
                cfg.hidden_size,
                &cfg.quantization_config,
                false,
                comm,
                vb.pp("o_proj"),
            )?,
            q_norm: RmsNorm::new(
                cfg.head_dim(),
                cfg.rms_norm_eps,
                mapper.set_device(layer_idx, vb.pp("q_norm"), false),
            )?,
            k_norm: RmsNorm::new(
                cfg.head_dim(),
                cfg.rms_norm_eps,
                mapper.set_device(layer_idx, vb.pp("k_norm"), false),
            )?,
            num_heads: cfg.num_attention_heads / comm.world_size(),
            num_kv_heads: (cfg.num_key_value_heads / comm.world_size()).max(1),
            head_dim: cfg.head_dim(),
            sdpa_params: SdpaParams {
                n_kv_groups: cfg.num_attention_heads / cfg.num_key_value_heads,
                softcap: None,
                softmax_scale: 1.0 / (cfg.head_dim() as f32).sqrt(),
                sliding_window: None,
            },
        })
    }

    fn forward(
        &self,
        hidden_states: &Tensor,
        cross_attn_states: Option<&Tensor>,
        attention_mask: Option<&Tensor>,
    ) -> Result<Tensor> {
        let (bs, q_len, _) = hidden_states.dims3()?;

        let mut hidden_states = hidden_states.clone();
        let original_dtype = hidden_states.dtype();
        if let Some(t) = self.q_proj.quantized_act_type() {
            hidden_states = hidden_states.to_dtype(t)?;
        }
        let mut q = self.q_proj.forward(&hidden_states)?;
        if self.q_proj.quantized_act_type().is_some() {
            q = q.to_dtype(original_dtype)?;
        }
        q = q
            .reshape((bs, q_len, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;
        q = self.q_norm.forward(&q)?;

        let (k, v) = if let Some(cross_attn_states) = cross_attn_states {
            let mut cross_attn_states = cross_attn_states.clone();
            let original_dtype = cross_attn_states.dtype();
            if let Some(t) = self.k_proj.quantized_act_type() {
                cross_attn_states = cross_attn_states.to_dtype(t)?;
            }
            let mut k = self.k_proj.forward(&cross_attn_states)?;
            k = k
                .reshape((bs, (), self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            if self.q_proj.quantized_act_type().is_some() {
                k = k.to_dtype(original_dtype)?;
            }
            k = self.k_norm.forward(&k)?;

            let mut v = self.v_proj.forward(&cross_attn_states)?;
            if self.q_proj.quantized_act_type().is_some() {
                v = v.to_dtype(original_dtype)?;
            }
            v = v
                .reshape((bs, (), self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;

            (k, v)
        } else {
            candle_core::bail!("Cross attn cannot find k,v cache or cross attn hidden states!")
        };

        let mut attn_output = Sdpa
            .run_attention(
                &q.contiguous()?,
                &k.contiguous()?,
                &v.contiguous()?,
                attention_mask
                    .map(|m| m.repeat((1, self.num_heads, 1, 1)).unwrap())
                    .as_ref(),
                None,
                &self.sdpa_params,
            )?
            .transpose(1, 2)?
            .contiguous()?
            .reshape((bs, q_len, ()))?
            .to_dtype(q.dtype())?;

        if let Some(t) = self.q_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        let mut res = self.o_proj.forward(&attn_output)?;
        if self.q_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

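/// Gated cross-attention decoder layer: the cross-attention and MLP outputs are scaled by
/// learned `tanh` gates (`cross_attn_attn_gate`, `cross_attn_mlp_gate`) before being added
/// to the residual stream, and the MLP output is multiplied by
/// `full_text_row_masked_out_mask` when that mask is provided.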
struct MLlamaCrossAttentionDecoderLayer {
    attn: MLlamaTextCrossAttention,
    attn_gate: Tensor,
    mlp: MLlamaTextMlp,
    mlp_gate: Tensor,
    input_layernorm: RmsNorm,
    post_attention_layernorm: RmsNorm,
}

impl MLlamaCrossAttentionDecoderLayer {
    fn new(
        cfg: &MLlamaTextConfig,
        vb: ShardedVarBuilder,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let mlp = MLlamaTextMlp::new(
            cfg,
            mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq),
            comm,
        )?;
        let input_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let post_attention_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        let attn = MLlamaTextCrossAttention::new(
            cfg,
            mapper.set_device(layer_idx, vb.pp("cross_attn"), loading_isq),
            mapper,
            layer_idx,
            comm,
        )?;

        Ok(Self {
            attn,
            mlp,
            input_layernorm,
            post_attention_layernorm,
            attn_gate: mapper
                .set_device(layer_idx, vb.clone(), false)
                .get((1,), "cross_attn_attn_gate")?,
            mlp_gate: mapper
                .set_device(layer_idx, vb.clone(), false)
                .get((1,), "cross_attn_mlp_gate")?,
        })
    }

    fn forward(
        &self,
        hidden_states: &Tensor,
        cross_attn_states: Option<&Tensor>,
        attention_mask: Option<&Tensor>,
        full_text_row_masked_out_mask: Option<&Tensor>,
    ) -> Result<Tensor> {
        let residual = hidden_states;

        let mut hidden_states = self.input_layernorm.forward(hidden_states)?;

        hidden_states = self
            .attn
            .forward(&hidden_states, cross_attn_states, attention_mask)?;
        hidden_states = (residual + hidden_states.broadcast_mul(&self.attn_gate.tanh()?)?)?;

        let residual = &hidden_states;
        let mut hidden_states = self.post_attention_layernorm.forward(&hidden_states)?;
        hidden_states = self.mlp.forward(&hidden_states)?;
        if let Some(full_text_row_masked_out_mask) = full_text_row_masked_out_mask {
            hidden_states = full_text_row_masked_out_mask
                .to_dtype(hidden_states.dtype())?
                .i((.., 0))?
                .broadcast_mul(&hidden_states)?;
        }

        residual + hidden_states.broadcast_mul(&self.mlp_gate.tanh()?)?
    }
}

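/// A decoder layer is either a regular self-attention layer or one of the interleaved
/// cross-attention layers listed in `cfg.cross_attention_layers`.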
enum MLlamaDecoderLayer {
    CrossAttn(MLlamaCrossAttentionDecoderLayer),
    SelfAttn(MLlamaSelfAttentionDecoderLayer),
}

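/// Text decoder of the MLlama model: token embeddings, a stack of self-attention and
/// cross-attention decoder layers, a final RMSNorm, and the LM head.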
pub(super) struct MLlamaTextModel {
    embed_tokens: Embedding,
    lm_head: Arc<dyn QuantMethod>,
    norm: RmsNorm,
    layers: Vec<MLlamaDecoderLayer>,
    pub(crate) cfg: ModelConfigMetadata,
    pub(crate) cache: EitherCache,
    pub(crate) device: Device,
    pub(crate) max_position_embeddings: usize,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
}

impl MLlamaTextModel {
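    /// Load the text model from the given var builder. Only eager attention is supported;
    /// a RoPE instance is created per target device, and layers listed in
    /// `cfg.cross_attention_layers` are built as cross-attention layers.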
    pub(super) fn new(
        cfg: &MLlamaTextConfig,
        vb: ShardedVarBuilder,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        if let Some(ref quant_cfg) = &cfg.quantization_config {
            tracing::info!(
                "Using {} quantization: {}.",
                quant_cfg.name(),
                quant_cfg.get_bits_name(&vb)
            );
        }
        if !matches!(attention_mechanism, AttentionImplementation::Eager) {
            candle_core::bail!("Expected eager attention implementation");
        }
        let mapper = normal_loading_metadata.mapper;

        let embed_tokens = embedding(
            cfg.vocab_size + 8,
            cfg.hidden_size,
            mapper.set_nm_device(vb.pp("model.embed_tokens"), false),
            &cfg.quantization_config,
        )?;

        let lm_head = if !cfg.tie_word_embeddings {
            ReplicatedLayer::new(
                cfg.hidden_size,
                cfg.vocab_size,
                &cfg.quantization_config,
                false,
                mapper.set_nm_device(vb.pp("lm_head"), false),
            )?
        } else {
            ReplicatedLayer::from_linear(candle_nn::Linear::new(
                mapper.cast_nm_device(embed_tokens.embeddings(), false)?,
                None,
            ))?
        };

        let vb = vb.pp("model");

        let norm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb.pp("norm"), false),
        )?;

        let mut ropes = HashMap::new();
        for layer_idx in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(Llama3RotaryEmbedding::new_mllama3(
                    vb.dtype(),
                    cfg,
                    device,
                    is_gptx,
                )?),
            );
        }

        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        for i in 0..cfg.num_hidden_layers {
            let comm = mapper.get_comm_for(i)?;
            if cfg.cross_attention_layers.contains(&i) {
                layers.push(MLlamaDecoderLayer::CrossAttn(
                    MLlamaCrossAttentionDecoderLayer::new(
                        cfg,
                        vb.pp(format!("layers.{i}")),
                        &*mapper,
                        i,
                        false,
                        &comm,
                    )?,
                ))
            } else {
                let device = mapper
                    .device_for(i, false)
                    .unwrap_or(&normal_loading_metadata.real_device);
                layers.push(MLlamaDecoderLayer::SelfAttn(
                    MLlamaSelfAttentionDecoderLayer::new(
                        cfg,
                        vb.pp(format!("layers.{i}")),
                        ropes
                            .get(&device.location())
                            .expect("No RoPE for device location!")
                            .clone(),
                        &*mapper,
                        i,
                        normal_loading_metadata.loading_isq,
                        &comm,
                    )?,
                ))
            }
        }

        Ok(Self {
            embed_tokens,
            layers,
            norm,
            lm_head,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_attn_heads: cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size(),
                num_kv_heads: (cfg.num_key_value_heads / mapper.get_comm_for(0)?.world_size())
                    .max(1),
                sliding_window: None,
                k_head_dim: cfg.head_dim(),
                v_head_dim: cfg.head_dim(),
            },
            cache: EitherCache::Normal(NormalCache::new(
                cfg.num_hidden_layers,
                cfg.max_position_embeddings,
            )),
            device: normal_loading_metadata.real_device,
            max_position_embeddings: cfg.max_position_embeddings,
            mapper,
        })
    }

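    /// Run the decoder over `input_ids`. `cross_attn_states` carries the cross-attention
    /// (vision) states; when it is `None` (text-only steps), cross-attention layers are
    /// skipped. Returns logits for the positions selected by `context_lens`.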
    #[allow(clippy::too_many_arguments)]
    pub(super) fn forward(
        &self,
        input_ids: &Tensor,
        cross_attn_states: Option<&Tensor>,
        cross_attention_mask: Option<&Tensor>,
        full_text_row_masked_out_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
    ) -> Result<Tensor> {
        let mut hidden_states = self.embed_tokens.forward(input_ids)?;

        let cache = &mut self.cache.normal().0;
        let self_mask = CausalMasker.make_causal_mask_matrix(
            input_ids,
            cache as &dyn PastKvLenCache,
            hidden_states.dtype(),
            self.cfg.num_attn_heads,
        )?;

        for (i, layer) in self.layers.iter().enumerate() {
            hidden_states = self.mapper.map(hidden_states, i)?;
            match layer {
                MLlamaDecoderLayer::SelfAttn(attn) => {
                    hidden_states = attn.forward(
                        &hidden_states,
                        self_mask
                            .as_ref()
                            .map(|m| m.to_device(hidden_states.device()).unwrap())
                            .as_ref(),
                        seqlen_offsets,
                        &mut cache[i],
                    )?;
                }
                MLlamaDecoderLayer::CrossAttn(attn) => {
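                    // Text-only step: without cross-attention states this layer is a no-op,
                    // so pass the hidden states through unchanged.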
                    if cross_attn_states.is_none() {
                        continue;
                    }
                    hidden_states = attn.forward(
                        &hidden_states,
                        cross_attn_states
                            .as_ref()
                            .map(|x| x.to_device(hidden_states.device()).unwrap())
                            .as_ref(),
                        cross_attention_mask
                            .as_ref()
                            .map(|m| m.to_device(hidden_states.device()).unwrap())
                            .as_ref(),
                        full_text_row_masked_out_mask
                            .as_ref()
                            .map(|m| m.to_device(hidden_states.device()).unwrap())
                            .as_ref(),
                    )?;
                }
            }
        }

        hidden_states = hidden_states.to_device(&self.device)?;
        hidden_states = self.norm.forward(&hidden_states)?;

        hidden_states = self
            .lm_head
            .forward(&extract_logits(&hidden_states, context_lens)?)?;

        Ok(hidden_states)
    }
}

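// In-situ quantization support: exposes the quantizable linear layers of the
// self-attention blocks and collects the tensors that are kept unquantized
// (embeddings, norms, gates, and the cross-attention weights) for serialization.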
impl IsqModel for MLlamaTextModel {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        for (i, layer) in self.layers.iter_mut().enumerate() {
            match layer {
                MLlamaDecoderLayer::CrossAttn(_cross) => {
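                    // Cross-attention layers are not registered for ISQ; their weights
                    // remain in the originally loaded format.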
                }
                MLlamaDecoderLayer::SelfAttn(self_attn) => {
                    tensors.push((&mut self_attn.attn.q_proj, Some(i)));
                    tensors.push((&mut self_attn.attn.k_proj, Some(i)));
                    tensors.push((&mut self_attn.attn.v_proj, Some(i)));
                    tensors.push((&mut self_attn.attn.o_proj, Some(i)));
                    tensors.push((&mut self_attn.mlp.gate_proj, Some(i)));
                    tensors.push((&mut self_attn.mlp.up_proj, Some(i)));
                    tensors.push((&mut self_attn.mlp.down_proj, Some(i)));
                }
            }
        }
        (tensors, &*self.mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("model.embed_tokens").add(&self.embed_tokens);
        uvb.pp("lm_head").add(&self.lm_head);

        let uvb = uvb.pp("model");

        uvb.pp("norm").add(&self.norm);

        for (i, layer) in self.layers.iter().enumerate() {
            let uvb_l = uvb.pp("layers").pp(i);
            match layer {
                MLlamaDecoderLayer::CrossAttn(crossattn) => {
                    uvb_l
                        .pp("post_attention_layernorm")
                        .add(&crossattn.post_attention_layernorm);
                    uvb_l.pp("input_layernorm").add(&crossattn.input_layernorm);
                    uvb_l.add_tensor("cross_attn_attn_gate", crossattn.attn_gate.clone());
                    uvb_l.add_tensor("cross_attn_mlp_gate", crossattn.mlp_gate.clone());

                    let uvb_attn = uvb_l.pp("cross_attn");
                    uvb_attn.pp("q_proj").add(&crossattn.attn.q_proj);
                    uvb_attn.pp("k_proj").add(&crossattn.attn.k_proj);
                    uvb_attn.pp("v_proj").add(&crossattn.attn.v_proj);
                    uvb_attn.pp("o_proj").add(&crossattn.attn.o_proj);
                    uvb_attn.pp("q_norm").add(&crossattn.attn.q_norm);
                    uvb_attn.pp("k_norm").add(&crossattn.attn.k_norm);

                    let uvb_mlp = uvb_l.pp("mlp");
                    uvb_mlp.pp("gate_proj").add(&crossattn.mlp.gate_proj);
                    uvb_mlp.pp("up_proj").add(&crossattn.mlp.up_proj);
                    uvb_mlp.pp("down_proj").add(&crossattn.mlp.down_proj);
                }
                MLlamaDecoderLayer::SelfAttn(selfattn) => {
                    uvb_l
                        .pp("post_attention_layernorm")
                        .add(&selfattn.post_attention_layernorm);
                    uvb_l.pp("input_layernorm").add(&selfattn.input_layernorm);
                }
            }
        }

        uvb.to_safetensors()
    }
}