#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::{collections::HashMap, sync::Arc};

use candle_core::{Device, IndexOp, Result, Tensor};
use candle_nn::{Activation, Embedding, Module};
use mistralrs_quant::{
    ColumnParallelLayer, QuantMethod, ReplicatedLayer, RowParallelLayer, ShardedVarBuilder,
};

use crate::{
    attention::SdpaParams,
    device_map::DeviceMapper,
    layers::{embedding, CausalMasker, Llama3RotaryEmbedding, RmsNorm, Sdpa},
    layers_masker::PastKvLenCache,
    paged_attention::{AttentionImplementation, ModelConfigMetadata},
    pipeline::{
        extract_logits, EitherCache, IsqModel, KvCache, NormalCache, NormalLoadingMetadata,
    },
    utils::unvarbuilder::UnVarBuilder,
};

use super::config::MLlamaTextConfig;

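// Gated feed-forward block: `down_proj(act(gate_proj(x)) * up_proj(x))`. The gate/up
// projections are column-parallel and the down projection is row-parallel so the MLP can
// be sharded across the tensor-parallel communicator.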
struct MLlamaTextMlp {
    gate_proj: Arc<dyn QuantMethod>,
    up_proj: Arc<dyn QuantMethod>,
    down_proj: Arc<dyn QuantMethod>,
    act: Activation,
}

impl MLlamaTextMlp {
    fn new(
        cfg: &MLlamaTextConfig,
        vb: ShardedVarBuilder,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        Ok(Self {
            gate_proj: ColumnParallelLayer::new(
                cfg.hidden_size,
                cfg.intermediate_size,
                &cfg.quantization_config,
                false,
                comm,
                vb.pp("gate_proj"),
            )?,
            up_proj: ColumnParallelLayer::new(
                cfg.hidden_size,
                cfg.intermediate_size,
                &cfg.quantization_config,
                false,
                comm,
                vb.pp("up_proj"),
            )?,
            down_proj: RowParallelLayer::new(
                cfg.intermediate_size,
                cfg.hidden_size,
                &cfg.quantization_config,
                false,
                comm,
                vb.pp("down_proj"),
            )?,
            act: cfg.hidden_act,
        })
    }

    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.gate_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut res = self.down_proj.forward(
            &self
                .act
                .forward(&self.gate_proj.forward(&xs)?)?
                .broadcast_mul(&self.up_proj.forward(&xs)?)?,
        )?;
        if self.gate_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

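// Grouped-query self-attention with Llama 3 rotary embeddings. Head counts are divided by
// the tensor-parallel world size, and keys/values are appended to this layer's KV cache.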
struct MLlamaTextSelfAttention {
    q_proj: Arc<dyn QuantMethod>,
    k_proj: Arc<dyn QuantMethod>,
    v_proj: Arc<dyn QuantMethod>,
    o_proj: Arc<dyn QuantMethod>,
    sdpa_params: SdpaParams,
    rope: Arc<Llama3RotaryEmbedding>,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
}

impl MLlamaTextSelfAttention {
    fn new(
        cfg: &MLlamaTextConfig,
        vb: ShardedVarBuilder,
        rope: Arc<Llama3RotaryEmbedding>,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let head_dim = cfg.hidden_size / cfg.num_attention_heads;

        Ok(Self {
            q_proj: ColumnParallelLayer::new(
                cfg.hidden_size,
                cfg.num_attention_heads * cfg.head_dim(),
                &cfg.quantization_config,
                false,
                comm,
                vb.pp("q_proj"),
            )?,
            k_proj: ColumnParallelLayer::new(
                cfg.hidden_size,
                cfg.num_key_value_heads * cfg.head_dim(),
                &cfg.quantization_config,
                false,
                comm,
                vb.pp("k_proj"),
            )?,
            v_proj: ColumnParallelLayer::new(
                cfg.hidden_size,
                cfg.num_key_value_heads * cfg.head_dim(),
                &cfg.quantization_config,
                false,
                comm,
                vb.pp("v_proj"),
            )?,
            o_proj: RowParallelLayer::new(
                cfg.num_attention_heads * cfg.head_dim(),
                cfg.hidden_size,
                &cfg.quantization_config,
                false,
                comm,
                vb.pp("o_proj"),
            )?,
            sdpa_params: SdpaParams {
                n_kv_groups: cfg.num_attention_heads / cfg.num_key_value_heads,
                use_flash_attn: false,
                softcap: None,
                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                sliding_window: None,
            },
            rope,
            num_heads: cfg.num_attention_heads / comm.world_size(),
            num_kv_heads: (cfg.num_key_value_heads / comm.world_size()).max(1),
            head_dim,
        })
    }

    fn forward(
        &self,
        hidden_states: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut KvCache,
    ) -> Result<Tensor> {
        let (bs, q_len, _) = hidden_states.dims3()?;

        let mut hidden_states = hidden_states.clone();
        let original_dtype = hidden_states.dtype();
        if let Some(t) = self.q_proj.quantized_act_type() {
            hidden_states = hidden_states.to_dtype(t)?;
        }
        let mut q = self.q_proj.forward(&hidden_states)?;
        let mut k = self.k_proj.forward(&hidden_states)?;
        let mut v = self.v_proj.forward(&hidden_states)?;
        if self.q_proj.quantized_act_type().is_some() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

        let (q, k, mut v) = if q_len != 1 {
            let q = q
                .reshape((bs, q_len, self.num_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((bs, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((bs, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((bs, self.num_heads, q_len, self.head_dim))?;
            let k = k.reshape((bs, self.num_kv_heads, q_len, self.head_dim))?;
            let v = v.reshape((bs, self.num_kv_heads, q_len, self.head_dim))?;
            (q, k, v)
        };

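        // Apply the Llama 3 rotary embeddings, then append the new keys/values to this
        // layer's KV cache before running scaled-dot-product attention.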
        let (q, mut k) = self.rope.forward(&q, &k, seqlen_offsets)?;

        (k, v) = kv_cache.append(&k, &v)?;

        let mut attn_output = Sdpa
            .run_attention(
                &q.contiguous()?,
                &k.contiguous()?,
                &v.contiguous()?,
                attention_mask,
                None,
                &self.sdpa_params,
            )?
            .transpose(1, 2)?
            .contiguous()?
            .reshape((bs, q_len, ()))?
            .to_dtype(q.dtype())?;

        if let Some(t) = self.q_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        let mut res = self.o_proj.forward(&attn_output)?;
        if self.q_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

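// Standard pre-norm decoder block: RMSNorm -> self-attention -> residual add, then
// RMSNorm -> MLP -> residual add.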
struct MLlamaSelfAttentionDecoderLayer {
    attn: MLlamaTextSelfAttention,
    mlp: MLlamaTextMlp,
    input_layernorm: RmsNorm,
    post_attention_layernorm: RmsNorm,
}

impl MLlamaSelfAttentionDecoderLayer {
    fn new(
        cfg: &MLlamaTextConfig,
        vb: ShardedVarBuilder,
        rope: Arc<Llama3RotaryEmbedding>,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let mlp = MLlamaTextMlp::new(
            cfg,
            mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq),
            comm,
        )?;
        let input_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let post_attention_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        let attn = MLlamaTextSelfAttention::new(
            cfg,
            mapper.set_device(layer_idx, vb.pp("self_attn"), loading_isq),
            rope,
            comm,
        )?;

        Ok(Self {
            attn,
            mlp,
            input_layernorm,
            post_attention_layernorm,
        })
    }

    fn forward(
        &self,
        hidden_states: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut KvCache,
    ) -> Result<Tensor> {
        let residual = hidden_states;

        let mut hidden_states = self.input_layernorm.forward(hidden_states)?;

        hidden_states =
            self.attn
                .forward(&hidden_states, attention_mask, seqlen_offsets, kv_cache)?;
        hidden_states = (residual + hidden_states)?;

        let residual = &hidden_states;
        let mut hidden_states = self.post_attention_layernorm.forward(&hidden_states)?;
        hidden_states = self.mlp.forward(&hidden_states)?;

        residual + hidden_states
    }
}

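// Cross-attention over the vision states: queries come from the text hidden states while
// keys/values are projected from `cross_attn_states`; Q and K are RMS-normalized per head.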
struct MLlamaTextCrossAttention {
    q_proj: Arc<dyn QuantMethod>,
    k_proj: Arc<dyn QuantMethod>,
    v_proj: Arc<dyn QuantMethod>,
    o_proj: Arc<dyn QuantMethod>,
    q_norm: RmsNorm,
    k_norm: RmsNorm,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    sdpa_params: SdpaParams,
}

impl MLlamaTextCrossAttention {
    fn new(
        cfg: &MLlamaTextConfig,
        vb: ShardedVarBuilder,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        Ok(Self {
            q_proj: ColumnParallelLayer::new(
                cfg.hidden_size,
                cfg.num_attention_heads * cfg.head_dim(),
                &cfg.quantization_config,
                false,
                comm,
                vb.pp("q_proj"),
            )?,
            k_proj: ColumnParallelLayer::new(
                cfg.hidden_size,
                cfg.num_key_value_heads * cfg.head_dim(),
                &cfg.quantization_config,
                false,
                comm,
                vb.pp("k_proj"),
            )?,
            v_proj: ColumnParallelLayer::new(
                cfg.hidden_size,
                cfg.num_key_value_heads * cfg.head_dim(),
                &cfg.quantization_config,
                false,
                comm,
                vb.pp("v_proj"),
            )?,
            o_proj: RowParallelLayer::new(
                cfg.num_attention_heads * cfg.head_dim(),
                cfg.hidden_size,
                &cfg.quantization_config,
                false,
                comm,
                vb.pp("o_proj"),
            )?,
            q_norm: RmsNorm::new(
                cfg.head_dim(),
                cfg.rms_norm_eps,
                mapper.set_device(layer_idx, vb.pp("q_norm"), false),
            )?,
            k_norm: RmsNorm::new(
                cfg.head_dim(),
                cfg.rms_norm_eps,
                mapper.set_device(layer_idx, vb.pp("k_norm"), false),
            )?,
            num_heads: cfg.num_attention_heads / comm.world_size(),
            num_kv_heads: (cfg.num_key_value_heads / comm.world_size()).max(1),
            head_dim: cfg.head_dim(),
            sdpa_params: SdpaParams {
                n_kv_groups: cfg.num_attention_heads / cfg.num_key_value_heads,
                use_flash_attn: false,
                softcap: None,
                softmax_scale: 1.0 / (cfg.head_dim() as f32).sqrt(),
                sliding_window: None,
            },
        })
    }

    fn forward(
        &self,
        hidden_states: &Tensor,
        cross_attn_states: Option<&Tensor>,
        attention_mask: Option<&Tensor>,
    ) -> Result<Tensor> {
        let (bs, q_len, _) = hidden_states.dims3()?;

        let mut hidden_states = hidden_states.clone();
        let original_dtype = hidden_states.dtype();
        if let Some(t) = self.q_proj.quantized_act_type() {
            hidden_states = hidden_states.to_dtype(t)?;
        }
        let mut q = self.q_proj.forward(&hidden_states)?;
        if self.q_proj.quantized_act_type().is_some() {
            q = q.to_dtype(original_dtype)?;
        }
        q = q
            .reshape((bs, q_len, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;
        q = self.q_norm.forward(&q)?;

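        // Keys and values are projected from the cross-attention (vision) states, so their
        // sequence length is the number of image tokens rather than the text length.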
        let (k, v) = if let Some(cross_attn_states) = cross_attn_states {
            let mut cross_attn_states = cross_attn_states.clone();
            let original_dtype = cross_attn_states.dtype();
            if let Some(t) = self.k_proj.quantized_act_type() {
                cross_attn_states = cross_attn_states.to_dtype(t)?;
            }
            let mut k = self.k_proj.forward(&cross_attn_states)?;
            k = k
                .reshape((bs, (), self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            if self.q_proj.quantized_act_type().is_some() {
                k = k.to_dtype(original_dtype)?;
            }
            k = self.k_norm.forward(&k)?;

            let mut v = self.v_proj.forward(&cross_attn_states)?;
            if self.q_proj.quantized_act_type().is_some() {
                v = v.to_dtype(original_dtype)?;
            }
            v = v
                .reshape((bs, (), self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;

            (k, v)
        } else {
            candle_core::bail!("Cross attn cannot find k,v cache or cross attn hidden states!")
        };

        let mut attn_output = Sdpa
            .run_attention(
                &q.contiguous()?,
                &k.contiguous()?,
                &v.contiguous()?,
                attention_mask
                    .map(|m| m.repeat((1, self.num_heads, 1, 1)).unwrap())
                    .as_ref(),
                None,
                &self.sdpa_params,
            )?
            .transpose(1, 2)?
            .contiguous()?
            .reshape((bs, q_len, ()))?
            .to_dtype(q.dtype())?;

        if let Some(t) = self.q_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        let mut res = self.o_proj.forward(&attn_output)?;
        if self.q_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

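// Cross-attention decoder block. The attention and MLP outputs are scaled by learned tanh
// gates (`cross_attn_attn_gate`, `cross_attn_mlp_gate`) before the residual adds, and rows
// flagged by `full_text_row_masked_out_mask` have their MLP output zeroed.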
struct MLlamaCrossAttentionDecoderLayer {
    attn: MLlamaTextCrossAttention,
    attn_gate: Tensor,
    mlp: MLlamaTextMlp,
    mlp_gate: Tensor,
    input_layernorm: RmsNorm,
    post_attention_layernorm: RmsNorm,
}

impl MLlamaCrossAttentionDecoderLayer {
    fn new(
        cfg: &MLlamaTextConfig,
        vb: ShardedVarBuilder,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let mlp = MLlamaTextMlp::new(
            cfg,
            mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq),
            comm,
        )?;
        let input_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let post_attention_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        let attn = MLlamaTextCrossAttention::new(
            cfg,
            mapper.set_device(layer_idx, vb.pp("cross_attn"), loading_isq),
            mapper,
            layer_idx,
            comm,
        )?;

        Ok(Self {
            attn,
            mlp,
            input_layernorm,
            post_attention_layernorm,
            attn_gate: mapper
                .set_device(layer_idx, vb.clone(), false)
                .get((1,), "cross_attn_attn_gate")?,
            mlp_gate: mapper
                .set_device(layer_idx, vb.clone(), false)
                .get((1,), "cross_attn_mlp_gate")?,
        })
    }

    fn forward(
        &self,
        hidden_states: &Tensor,
        cross_attn_states: Option<&Tensor>,
        attention_mask: Option<&Tensor>,
        full_text_row_masked_out_mask: Option<&Tensor>,
    ) -> Result<Tensor> {
        let residual = hidden_states;

        let mut hidden_states = self.input_layernorm.forward(hidden_states)?;

        hidden_states = self
            .attn
            .forward(&hidden_states, cross_attn_states, attention_mask)?;
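        // Scale the cross-attention output by its learned tanh gate before the residual add.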
        hidden_states = (residual + hidden_states.broadcast_mul(&self.attn_gate.tanh()?)?)?;

        let residual = &hidden_states;
        let mut hidden_states = self.post_attention_layernorm.forward(&hidden_states)?;
        hidden_states = self.mlp.forward(&hidden_states)?;
        if let Some(full_text_row_masked_out_mask) = full_text_row_masked_out_mask {
            hidden_states = full_text_row_masked_out_mask
                .to_dtype(hidden_states.dtype())?
                .i((.., 0))?
                .broadcast_mul(&hidden_states)?;
        }

        residual + hidden_states.broadcast_mul(&self.mlp_gate.tanh()?)?
    }
}

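// The decoder interleaves self-attention layers with cross-attention layers at the indices
// given by `cfg.cross_attention_layers`.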
enum MLlamaDecoderLayer {
    CrossAttn(MLlamaCrossAttentionDecoderLayer),
    SelfAttn(MLlamaSelfAttentionDecoderLayer),
}

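// The MLlama text decoder: token embeddings, the mixed self-/cross-attention layer stack,
// a final RMSNorm, and the (possibly tied) LM head.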
pub(super) struct MLlamaTextModel {
    embed_tokens: Embedding,
    lm_head: Arc<dyn QuantMethod>,
    norm: RmsNorm,
    layers: Vec<MLlamaDecoderLayer>,
    pub(crate) cfg: ModelConfigMetadata,
    pub(crate) cache: EitherCache,
    pub(crate) device: Device,
    pub(crate) max_position_embeddings: usize,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
}

impl MLlamaTextModel {
    pub(super) fn new(
        cfg: &MLlamaTextConfig,
        vb: ShardedVarBuilder,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        if let Some(ref quant_cfg) = &cfg.quantization_config {
            tracing::info!(
                "Using {} quantization: {}.",
                quant_cfg.name(),
                quant_cfg.get_bits_name(&vb)
            );
        }
        if !matches!(attention_mechanism, AttentionImplementation::Eager) {
            candle_core::bail!("Expected eager attention implementation");
        }
        let mapper = normal_loading_metadata.mapper;

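        // The embedding table has `cfg.vocab_size + 8` rows; the extra rows hold the
        // additional multimodal special tokens.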
        let embed_tokens = embedding(
            cfg.vocab_size + 8,
            cfg.hidden_size,
            mapper.set_nm_device(vb.pp("model.embed_tokens"), false),
            &cfg.quantization_config,
        )?;

        let lm_head = if !cfg.tie_word_embeddings {
            ReplicatedLayer::new(
                cfg.hidden_size,
                cfg.vocab_size,
                &None,
                false,
                mapper.set_nm_device(vb.pp("lm_head"), false),
            )?
        } else {
            ReplicatedLayer::from_linear(candle_nn::Linear::new(
                mapper.cast_nm_device(embed_tokens.embeddings(), false)?,
                None,
            ))?
        };

        let vb = vb.pp("model");

        let norm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb.pp("norm"), false),
        )?;

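        // Create one rotary embedding per device so layers mapped to different devices
        // each use a local copy.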
        let mut ropes = HashMap::new();
        for layer_idx in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(Llama3RotaryEmbedding::new_mllama3(
                    vb.dtype(),
                    cfg,
                    device,
                    is_gptx,
                )?),
            );
        }

        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        for i in 0..cfg.num_hidden_layers {
            let comm = mapper.get_comm_for(i)?;
            if cfg.cross_attention_layers.contains(&i) {
                layers.push(MLlamaDecoderLayer::CrossAttn(
                    MLlamaCrossAttentionDecoderLayer::new(
                        cfg,
                        vb.pp(format!("layers.{i}")),
                        &*mapper,
                        i,
                        false,
                        &comm,
                    )?,
                ))
            } else {
                let device = mapper
                    .device_for(i, false)
                    .unwrap_or(&normal_loading_metadata.real_device);
                layers.push(MLlamaDecoderLayer::SelfAttn(
                    MLlamaSelfAttentionDecoderLayer::new(
                        cfg,
                        vb.pp(format!("layers.{i}")),
                        ropes
                            .get(&device.location())
                            .expect("No RoPE for device location!")
                            .clone(),
                        &*mapper,
                        i,
                        normal_loading_metadata.loading_isq,
                        &comm,
                    )?,
                ))
            }
        }

        Ok(Self {
            embed_tokens,
            layers,
            norm,
            lm_head,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_attn_heads: cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size(),
                num_kv_heads: (cfg.num_key_value_heads / mapper.get_comm_for(0)?.world_size())
                    .max(1),
                sliding_window: None,
                k_head_dim: cfg.head_dim(),
                v_head_dim: cfg.head_dim(),
            },
            cache: EitherCache::Normal(NormalCache::new(
                cfg.num_hidden_layers,
                cfg.max_position_embeddings,
            )),
            device: normal_loading_metadata.real_device,
            max_position_embeddings: cfg.max_position_embeddings,
            mapper,
        })
    }

    #[allow(clippy::too_many_arguments)]
    pub(super) fn forward(
        &self,
        input_ids: &Tensor,
        cross_attn_states: Option<&Tensor>,
        cross_attention_mask: Option<&Tensor>,
        full_text_row_masked_out_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
    ) -> Result<Tensor> {
        let mut hidden_states = self.embed_tokens.forward(input_ids)?;

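        // Build the causal self-attention mask from the current KV cache lengths; the
        // cross-attention layers use the separately provided `cross_attention_mask`.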
        let cache = &mut self.cache.normal().0;
        let self_mask = CausalMasker.make_causal_mask_matrix(
            input_ids,
            cache as &dyn PastKvLenCache,
            hidden_states.dtype(),
            self.cfg.num_attn_heads,
        )?;

        for (i, layer) in self.layers.iter().enumerate() {
            hidden_states = self.mapper.map(hidden_states, i)?;
            match layer {
                MLlamaDecoderLayer::SelfAttn(attn) => {
                    hidden_states = attn.forward(
                        &hidden_states,
                        self_mask
                            .as_ref()
                            .map(|m| m.to_device(hidden_states.device()).unwrap())
                            .as_ref(),
                        seqlen_offsets,
                        &mut cache[i],
                    )?;
                }
                MLlamaDecoderLayer::CrossAttn(attn) => {
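                    // Cross-attention layers are skipped when no vision states are
                    // provided (text-only decoding steps).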
                    if cross_attn_states.is_none() {
                        continue;
                    }
                    hidden_states = attn.forward(
                        &hidden_states,
                        cross_attn_states
                            .as_ref()
                            .map(|x| x.to_device(hidden_states.device()).unwrap())
                            .as_ref(),
                        cross_attention_mask
                            .as_ref()
                            .map(|m| m.to_device(hidden_states.device()).unwrap())
                            .as_ref(),
                        full_text_row_masked_out_mask
                            .as_ref()
                            .map(|m| m.to_device(hidden_states.device()).unwrap())
                            .as_ref(),
                    )?;
                }
            }
        }

        hidden_states = hidden_states.to_device(&self.device)?;
        hidden_states = self.norm.forward(&hidden_states)?;

        hidden_states = self
            .lm_head
            .forward(&extract_logits(&hidden_states, context_lens)?)?;

        Ok(hidden_states)
    }
}

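// ISQ support: expose the quantizable linear layers and the residual (non-quantized)
// tensors of the model.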
impl IsqModel for MLlamaTextModel {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        for (i, layer) in self.layers.iter_mut().enumerate() {
            match layer {
                MLlamaDecoderLayer::CrossAttn(_cross) => {
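                    // Cross-attention layers contribute no ISQ tensors and are left
                    // unquantized here.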
                }
                MLlamaDecoderLayer::SelfAttn(self_attn) => {
                    tensors.push((&mut self_attn.attn.q_proj, Some(i)));
                    tensors.push((&mut self_attn.attn.k_proj, Some(i)));
                    tensors.push((&mut self_attn.attn.v_proj, Some(i)));
                    tensors.push((&mut self_attn.attn.o_proj, Some(i)));
                    tensors.push((&mut self_attn.mlp.gate_proj, Some(i)));
                    tensors.push((&mut self_attn.mlp.up_proj, Some(i)));
                    tensors.push((&mut self_attn.mlp.down_proj, Some(i)));
                }
            }
        }
        (tensors, &*self.mapper)
    }

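    // Residual (non-quantized) tensors: embeddings, norms, and the cross-attention
    // projections and gates, collected for serialization.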
    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        uvb.pp("model.embed_tokens").add(&self.embed_tokens);
        uvb.pp("lm_head").add(&self.lm_head);

        let uvb = uvb.pp("model");

        uvb.pp("norm").add(&self.norm);

        for (i, layer) in self.layers.iter().enumerate() {
            let uvb_l = uvb.pp("layers").pp(i);
            match layer {
                MLlamaDecoderLayer::CrossAttn(crossattn) => {
                    uvb_l
                        .pp("post_attention_layernorm")
                        .add(&crossattn.post_attention_layernorm);
                    uvb_l.pp("input_layernorm").add(&crossattn.input_layernorm);
                    uvb_l.add_tensor("cross_attn_attn_gate", crossattn.attn_gate.clone());
                    uvb_l.add_tensor("cross_attn_mlp_gate", crossattn.mlp_gate.clone());

                    let uvb_attn = uvb_l.pp("cross_attn");
                    uvb_attn.pp("q_proj").add(&crossattn.attn.q_proj);
                    uvb_attn.pp("k_proj").add(&crossattn.attn.k_proj);
                    uvb_attn.pp("v_proj").add(&crossattn.attn.v_proj);
                    uvb_attn.pp("o_proj").add(&crossattn.attn.o_proj);
                    uvb_attn.pp("q_norm").add(&crossattn.attn.q_norm);
                    uvb_attn.pp("k_norm").add(&crossattn.attn.k_norm);

                    let uvb_mlp = uvb_l.pp("mlp");
                    uvb_mlp.pp("gate_proj").add(&crossattn.mlp.gate_proj);
                    uvb_mlp.pp("up_proj").add(&crossattn.mlp.up_proj);
                    uvb_mlp.pp("down_proj").add(&crossattn.mlp.down_proj);
                }
                MLlamaDecoderLayer::SelfAttn(selfattn) => {
                    uvb_l
                        .pp("post_attention_layernorm")
                        .add(&selfattn.post_attention_layernorm);
                    uvb_l.pp("input_layernorm").add(&selfattn.input_layernorm);
                }
            }
        }

        uvb.to_safetensors()
    }
}