#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use candle_core::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{Embedding, Module};
use mistralrs_quant::{
    distributed::layers::PackedExperts, linear_no_bias, ColumnParallelLayer, QuantMethod,
    QuantizedConfig, ReplicatedLayer, RowParallelLayer, ShardedVarBuilder, SumAllReduce,
};
use std::{collections::HashMap, sync::Arc};

use crate::{
    amoe::AnyMoeBaseModelMixin,
    attention::SdpaParams,
    device_map::DeviceMapper,
    layers::{embedding, Activation, CausalMasker, Llama3RotaryEmbedding, RmsNorm, Sdpa},
    layers_masker::PastKvLenCache,
    ops::{TopKLastDimOp, TopKOutput},
    paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
    pipeline::{
        extract_logits,
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel, KvCache, NormalCache, NormalLoadingMetadata, NormalModel,
    },
    utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
};

use super::config::TextConfig;

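/// Multi-head self-attention for the Llama 4 text decoder.
///
/// Every 4th layer skips rotary embeddings (a "NoPE" layer); the remaining layers apply RoPE
/// and, when `use_qk_norm` is set, an extra RmsNorm over the query/key heads. NoPE layers may
/// additionally rescale the queries via attention temperature tuning (see `forward`).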
struct CausalSelfAttention {
    q_proj: Arc<dyn QuantMethod>,
    k_proj: Arc<dyn QuantMethod>,
    v_proj: Arc<dyn QuantMethod>,
    o_proj: Arc<dyn QuantMethod>,
    num_attention_heads: usize,
    num_key_value_heads: usize,
    head_dim: usize,
    rotary_emb: Arc<Llama3RotaryEmbedding>,
    max_seq_len: usize,
    paged_attn: Option<PagedAttention>,
    sdpa_params: SdpaParams,
    norm: Option<RmsNorm>,
    use_rope: bool,
    floor_scale: Option<f32>,
    attn_scale: Option<f32>,
    attn_temperature_tuning: Option<f32>,
}

impl CausalSelfAttention {
    #[allow(clippy::too_many_arguments)]
    fn new(
        vb: ShardedVarBuilder,
        cfg: &TextConfig,
        layer_idx: usize,
        loading_isq: bool,
        mapper: &dyn DeviceMapper,
        rope: Arc<Llama3RotaryEmbedding>,
        paged_attn: Option<PagedAttention>,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let size_in = cfg.hidden_size;
        let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
        let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
        let q_proj = ColumnParallelLayer::new(
            size_in,
            size_q,
            &cfg.quantization_config,
            false,
            comm,
            mapper.set_device(layer_idx, vb.pp("q_proj"), loading_isq),
        )?;
        let kv_shard = mistralrs_quant::compute_kv_shard(
            cfg.num_key_value_heads,
            cfg.hidden_size / cfg.num_attention_heads,
            comm,
        );
        let k_proj = ColumnParallelLayer::new_with_shard(
            size_in,
            size_kv,
            &cfg.quantization_config,
            false,
            comm,
            kv_shard,
            mapper.set_device(layer_idx, vb.pp("k_proj"), loading_isq),
        )?;
        let v_proj = ColumnParallelLayer::new_with_shard(
            size_in,
            size_kv,
            &cfg.quantization_config,
            false,
            comm,
            kv_shard,
            mapper.set_device(layer_idx, vb.pp("v_proj"), loading_isq),
        )?;
        let o_proj = RowParallelLayer::new(
            size_q,
            size_in,
            &cfg.quantization_config,
            false,
            comm,
            mapper.set_device(layer_idx, vb.pp("o_proj"), loading_isq),
        )?;
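        // Every 4th layer skips RoPE (NoPE); RoPE layers optionally get a QK RmsNorm below.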
        let use_rope = (layer_idx + 1) % 4 != 0;
        let head_dim = cfg.hidden_size / cfg.num_attention_heads;
        let norm = if cfg.use_qk_norm && use_rope {
            let vb = mapper.set_device(layer_idx, vb, false);
            Some(RmsNorm::from_w(
                Tensor::ones(head_dim, vb.dtype(), vb.device())?,
                1e-6,
            )?)
        } else {
            None
        };

        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            num_attention_heads: cfg.num_attention_heads / comm.world_size(),
            num_key_value_heads: (cfg.num_key_value_heads / comm.world_size()).max(1),
            head_dim,
            rotary_emb: rope,
            max_seq_len: cfg.max_position_embeddings,
            paged_attn,
            sdpa_params: SdpaParams {
                n_kv_groups: mistralrs_quant::compute_n_kv_groups(
                    cfg.num_key_value_heads,
                    cfg.num_attention_heads,
                    comm,
                ),
                use_flash_attn: cfg.use_flash_attn,
                softcap: None,
                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                sliding_window: None,
            },
            norm,
            use_rope,
            floor_scale: cfg.floor_scale,
            attn_scale: cfg.attn_scale,
            attn_temperature_tuning: cfg.attn_temperature_tuning,
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        x: &Tensor,
        position_ids: &Tensor,
        attention_mask: &Option<Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, seq_len, _) = x.dims3()?;

        let mut q = self.q_proj.forward_autocast(x)?;
        let mut k = self.k_proj.forward_autocast(x)?;
        let mut v = self.v_proj.forward_autocast(x)?;

        q = q
            .reshape((b_sz, seq_len, self.num_attention_heads, self.head_dim))?
            .transpose(1, 2)?;
        k = k
            .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
            .transpose(1, 2)?;
        v = v
            .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
            .transpose(1, 2)?;

        if self.use_rope {
            (q, k) = self.rotary_emb.forward(&q, &k, seqlen_offsets)?;
        }

        if let Some(qk_norm) = &self.norm {
            q = qk_norm.forward(&q)?;
            k = qk_norm.forward(&k)?;
        }

        if self.attn_temperature_tuning.is_some() && !self.use_rope {
            let floor_scale = self.floor_scale.unwrap() as f64;
            let attn_scale = self.attn_scale.unwrap() as f64;
            let floor = ((position_ids.to_dtype(DType::F32)? + 1.)? / floor_scale)?.floor()?;
            let attn_scales = (((floor + 1.0)?.log()? * attn_scale)? + 1.0)?;

            q = q
                .to_dtype(DType::F32)?
                .broadcast_mul(&attn_scales.unsqueeze(D::Minus1)?)?
                .to_dtype(q.dtype())?;
        }

        let mut y = match &self.paged_attn {
            Some(paged_attn) => match metadata {
                Some(((key_cache, value_cache), input_metadata)) => paged_attn.forward(
                    &q,
                    &k,
                    &v,
                    attention_mask.clone().as_ref(),
                    Some(key_cache),
                    Some(value_cache),
                    input_metadata,
                    &self.sdpa_params,
                    Some(flash_params),
                )?,
                None => {
                    let input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
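                    // No paged KV cache or metadata was provided for this call, so the dummy
                    // metadata above is only a placeholder; the causal mask must be present.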
                    assert!(attention_mask.is_some());
                    paged_attn.forward(
                        &q,
                        &k,
                        &v,
                        attention_mask.clone().as_ref(),
                        None,
                        None,
                        &input_metadata,
                        &self.sdpa_params,
                        Some(flash_params),
                    )?
                }
            },
            None => {
                let (k, v) = kv_cache.append(&k, &v)?;

                Sdpa.run_attention(
                    &q.contiguous()?,
                    &k.contiguous()?,
                    &v.contiguous()?,
                    attention_mask.clone().as_ref(),
                    Some(flash_params),
                    &self.sdpa_params,
                )?
            }
        };

        y = if attention_mask.is_some() {
            y.transpose(1, 2)?.reshape((b_sz, seq_len, ()))?
        } else {
            y.reshape((b_sz, seq_len, ()))?
        };
        self.o_proj.forward_autocast(&y)
    }
}

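/// Dense gated feed-forward block: `down_proj(mul_and_act(gate_proj(x), up_proj(x)))`, with
/// the gate/up projections column-parallel and the down projection row-parallel across the
/// tensor-parallel group.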
struct Mlp {
    gate: Arc<dyn QuantMethod>,
    up: Arc<dyn QuantMethod>,
    down: Arc<dyn QuantMethod>,
    act: Activation,
}

impl Mlp {
    fn new(
        vb: ShardedVarBuilder,
        hidden_size: usize,
        intermediate_size: usize,
        quantization_config: &Option<QuantizedConfig>,
        hidden_act: Activation,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        Ok(Self {
            gate: ColumnParallelLayer::new(
                hidden_size,
                intermediate_size,
                quantization_config,
                false,
                comm,
                vb.pp("gate_proj"),
            )?,
            up: ColumnParallelLayer::new(
                hidden_size,
                intermediate_size,
                quantization_config,
                false,
                comm,
                vb.pp("up_proj"),
            )?,
            down: RowParallelLayer::new(
                intermediate_size,
                hidden_size,
                quantization_config,
                false,
                comm,
                vb.pp("down_proj"),
            )?,
            act: hidden_act,
        })
    }

    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let lhs = self.gate.forward_autocast(xs)?;
        let rhs = self.up.forward_autocast(xs)?;

        self.down.forward_autocast(&candle_nn::ops::mul_and_act(
            &lhs,
            &rhs,
            self.act.try_into()?,
        )?)
    }
}

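/// The routed experts of an MoE layer, loaded through `PackedExperts`.
///
/// With a single packed layer per projection, tokens are dispatched via the fused
/// `gather_forward_autocast` path; otherwise each token is run through its routed expert in a
/// loop. Either way, the partial results are summed across the tensor-parallel group.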
struct TextExperts {
    gate_proj: Vec<Arc<dyn QuantMethod>>,
    up_proj: Vec<Arc<dyn QuantMethod>>,
    down_proj: Vec<Arc<dyn QuantMethod>>,
    act: Activation,
    hidden_size: usize,
    sum_all_reduce: SumAllReduce,
}

impl TextExperts {
    fn new(
        vb: ShardedVarBuilder,
        cfg: &TextConfig,
        quantization_config: &Option<QuantizedConfig>,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let PackedExperts {
            gate_proj,
            up_proj,
            down_proj,
        } = PackedExperts::new(
            cfg.num_local_experts,
            cfg.hidden_size,
            cfg.intermediate_size,
            quantization_config,
            false,
            comm,
            vb,
        )?;
        Ok(Self {
            gate_proj,
            up_proj,
            down_proj,
            act: cfg.hidden_act,
            hidden_size: cfg.hidden_size,
            sum_all_reduce: SumAllReduce::new(comm),
        })
    }

    fn forward(&self, xs: &Tensor, indices: &Tensor) -> Result<Tensor> {
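        // xs: (batch * seq_len, hidden_size), already scaled by the router scores;
        // indices: the routed expert id for each token.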
        let xs = xs.unsqueeze(1)?;

        if self.gate_proj.len() == 1 {
            let gate = self.gate_proj[0].gather_forward_autocast(&xs, indices)?;
            let up = self.up_proj[0].gather_forward_autocast(&xs, indices)?;
            let mut xs = self.down_proj[0]
                .gather_forward_autocast(&(up * gate.apply(&self.act)?)?, indices)?;
            xs = self.sum_all_reduce.sum_all_reduce(&xs)?;
            xs.reshape(((), self.hidden_size))
        } else {
            let indices = indices.to_vec1::<u32>()?;
            let mut results = Vec::new();
            for (tok, id) in indices.into_iter().enumerate() {
                let xs = xs.i(tok)?.reshape((1, self.hidden_size))?;

                let res = {
                    let gate = self.gate_proj[id as usize].forward_autocast(&xs)?;
                    let up = self.up_proj[id as usize].forward_autocast(&xs)?;
                    self.down_proj[id as usize].forward_autocast(&(up * gate.apply(&self.act)?)?)?
                };
                results.push(res);
            }
            let mut xs = Tensor::cat(&results, 0)?;
            xs = self.sum_all_reduce.sum_all_reduce(&xs)?;
            xs.reshape(((), self.hidden_size))
        }
    }
}

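/// Sparse mixture-of-experts feed-forward. The router scores all experts, the top
/// `num_experts_per_tok` are selected per token, the routed input is scaled by the sigmoid of
/// the winning router logits, and the routed output is added to an always-active shared expert.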
struct TextMoe {
    experts: TextExperts,
    shared_expert: Mlp,
    router: Arc<dyn QuantMethod>,
    topk: usize,
}

impl TextMoe {
    fn new(
        vb: ShardedVarBuilder,
        cfg: &TextConfig,
        quantization_config: &Option<QuantizedConfig>,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let experts = TextExperts::new(vb.pp("experts"), cfg, quantization_config, comm)?;
        let router = linear_no_bias(
            cfg.hidden_size,
            cfg.num_local_experts,
            quantization_config,
            vb.pp("router"),
        )?;
        let shared_expert = Mlp::new(
            vb.pp("shared_expert"),
            cfg.hidden_size,
            cfg.intermediate_size,
            quantization_config,
            cfg.hidden_act,
            comm,
        )?;
        Ok(Self {
            experts,
            shared_expert,
            router,
            topk: cfg.num_experts_per_tok,
        })
    }

    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let (bs, seq_len, hidden_dim) = xs.dims3()?;
        let xs = xs.reshape(((), hidden_dim))?;
        let router_logits = self.router.forward_autocast(&xs)?;

        let TopKOutput {
            values: router_top_value,
            indices: router_indices,
        } = router_logits.topk(self.topk)?;

        let router_scores = candle_nn::ops::sigmoid(&router_top_value.to_dtype(DType::F32)?)?
            .to_dtype(router_top_value.dtype())?;

        let routed_in = xs.broadcast_mul(&router_scores)?;
        let routed_out = self
            .experts
            .forward(&routed_in, &router_indices.squeeze(D::Minus1)?)?
            .reshape((bs, seq_len, hidden_dim))?;
        let out = self
            .shared_expert
            .forward(&xs.reshape((bs, seq_len, hidden_dim))?)?;

        out + routed_out
    }
}

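/// The feed-forward for a layer: a dense `Mlp` on most layers, a sparse `TextMoe` on the
/// layers returned by `cfg.moe_layers()`.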
enum MoeOrMlp {
    Mlp(Mlp),
    Moe(TextMoe),
}

impl MoeOrMlp {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        match self {
            Self::Mlp(l) => l.forward(xs),
            Self::Moe(l) => l.forward(xs),
        }
    }
}

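/// A single decoder layer: pre-norm attention and pre-norm feed-forward, each wrapped in a
/// residual connection. RoPE layers attend through the chunked mask; every 4th (NoPE) layer
/// uses the full causal mask.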
struct Block {
    rms_1: RmsNorm,
    attn: CausalSelfAttention,
    rms_2: RmsNorm,
    ff: MoeOrMlp,
    use_chunked_attention: bool,
}

impl Block {
    #[allow(clippy::too_many_arguments)]
    fn new(
        vb: ShardedVarBuilder,
        cfg: &TextConfig,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        rope: Arc<Llama3RotaryEmbedding>,
        paged_attn: Option<PagedAttention>,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let use_chunked_attention = (layer_idx + 1) % 4 != 0;
        let attn = CausalSelfAttention::new(
            vb.pp("self_attn"),
            cfg,
            layer_idx,
            loading_isq,
            mapper,
            rope,
            paged_attn,
            comm,
        )?;
        let is_moe_layer = cfg.moe_layers().contains(&layer_idx);
        let ff = if is_moe_layer {
            let moe = TextMoe::new(
                mapper.set_device(layer_idx, vb.pp("feed_forward"), loading_isq),
                cfg,
                &cfg.quantization_config,
                comm,
            )?;
            MoeOrMlp::Moe(moe)
        } else {
            let mlp = Mlp::new(
                mapper.set_device(layer_idx, vb.pp("feed_forward"), loading_isq),
                cfg.hidden_size,
                cfg.intermediate_size_mlp,
                &cfg.quantization_config,
                cfg.hidden_act,
                comm,
            )?;
            MoeOrMlp::Mlp(mlp)
        };
        let rms_1 = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let rms_2 = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        Ok(Self {
            rms_1,
            attn,
            rms_2,
            ff,
            use_chunked_attention,
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        x: &Tensor,
        position_ids: &Tensor,
        attention_mask: &Option<Tensor>,
        chunked_mask: &Option<Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = x;
        let x = self.rms_1.forward(x)?;
        let mask = if self.use_chunked_attention {
            chunked_mask
        } else {
            attention_mask
        };
        let x = (self.attn.forward(
            &x,
            position_ids,
            mask,
            seqlen_offsets,
            kv_cache,
            metadata,
            flash_params,
        )? + residual)?;
        let residual = &x;
        let x = (self.ff.forward(&self.rms_2.forward(&x)?)? + residual)?;
        Ok(x)
    }
}

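/// The Llama 4 text decoder: token embeddings, the stack of `Block`s, a final RmsNorm, and an
/// LM head that is tied to the embeddings when `tie_word_embeddings` is set.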
pub struct TextModel {
    wte: Embedding,
    blocks: Vec<Block>,
    ln_f: RmsNorm,
    lm_head: Arc<dyn QuantMethod>,
    kv_cache: crate::pipeline::EitherCache,
    device: Device,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    cfg: ModelConfigMetadata,
    attention_chunk_size: usize,
}

impl TextModel {
    pub fn new(
        cfg: &TextConfig,
        vb: ShardedVarBuilder,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        let vb_m = vb.pp("model");
        let vb_lm_head = vb.pp("lm_head");
        Self::new_inner(
            cfg,
            vb_m,
            vb_lm_head,
            is_gptx,
            normal_loading_metadata,
            attention_mechanism,
        )
    }

    pub fn new_inner(
        cfg: &TextConfig,
        vb_m: ShardedVarBuilder,
        vb_lm_head: ShardedVarBuilder,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        if let Some(ref quant_cfg) = &cfg.quantization_config {
            tracing::info!(
                "Using {} quantization: {}.",
                quant_cfg.name(),
                quant_cfg.get_bits_name(&vb_m)
            );
        }
        let mapper = normal_loading_metadata.mapper;

        let wte = embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
            &cfg.quantization_config,
        )?;
        let lm_head = if !cfg.tie_word_embeddings {
            ReplicatedLayer::new(
                cfg.hidden_size,
                cfg.vocab_size,
                &None,
                false,
                mapper.set_nm_device(vb_lm_head, normal_loading_metadata.loading_isq),
            )?
        } else {
            ReplicatedLayer::from_linear(candle_nn::Linear::new(
                mapper.cast_nm_device(wte.embeddings(), normal_loading_metadata.loading_isq)?,
                None,
            ))?
        };
        let ln_f = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb_m.pp("norm"), false),
        )?;
        let head_dim = cfg.hidden_size / cfg.num_attention_heads;
        let mut ropes = HashMap::new();
        for i in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(i, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(Llama3RotaryEmbedding::new_llama4(
                    vb_m.dtype(),
                    cfg,
                    device,
                    is_gptx,
                )?),
            );
        }
        let blocks: Vec<_> = NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading text repeating layers",
            &normal_loading_metadata.multi_progress,
        )
        .into_iter()
        .map(|i| {
            let device = mapper
                .device_for(i, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            let paged_attn = match &attention_mechanism {
                AttentionImplementation::Eager => None,
                AttentionImplementation::PagedAttention => Some(
                    PagedAttention::new(head_dim, device, None)
                        .expect("Failed to create PagedAttention"),
                ),
            };
            let comm = mapper.get_comm_for(i).unwrap();
            Block::new(
                vb_m.pp(format!("layers.{i}")),
                cfg,
                &*mapper,
                i,
                normal_loading_metadata.loading_isq,
                rotary_emb,
                paged_attn,
                &comm,
            )
            .expect("Failed to load block.")
        })
        .collect();

        Ok(Self {
            wte,
            blocks,
            ln_f,
            lm_head,
            kv_cache: EitherCache::Normal(NormalCache::new(
                cfg.num_hidden_layers,
                cfg.max_position_embeddings,
            )),
            device: normal_loading_metadata.real_device,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_kv_heads: (cfg.num_key_value_heads / mapper.get_comm_for(0)?.world_size())
                    .max(1),
                num_attn_heads: cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size(),
                sliding_window: None,
                k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
                v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            },
            mapper,
            attention_chunk_size: cfg.attention_chunk_size,
        })
    }

    pub fn get_input_embeddings(&self, input_ids: &Tensor) -> Result<Tensor> {
        self.wte.forward(input_ids)
    }

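    /// Run the decoder over precomputed input embeddings. Both the full causal mask and the
    /// chunked-attention mask are built here; with PagedAttention metadata they are only kept
    /// for the first prompt chunk.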
    #[allow(clippy::too_many_arguments)]
    pub fn forward_embeds(
        &self,
        input_ids: &Tensor,
        input_embeds: Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut x = input_embeds;
        let cache = &mut self.kv_cache.normal().0;
        let cache_for_mask = metadata
            .as_ref()
            .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
            .unwrap_or(cache as &dyn PastKvLenCache);

        let position_ids = Tensor::new(
            seqlen_offsets.iter().map(|o| *o as i32).collect::<Vec<_>>(),
            input_ids.device(),
        )?;

        let mask = CausalMasker.make_causal_mask_matrix(
            input_ids,
            cache_for_mask,
            x.dtype(),
            self.blocks[0].attn.num_attention_heads,
        )?;
        let chunked_mask = CausalMasker.make_chunked_mask_matrix(
            input_ids,
            self.attention_chunk_size,
            cache_for_mask,
            x.dtype(),
            self.blocks[0].attn.num_attention_heads,
        )?;
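        // Drop both masks unless this is the first prompt chunk (or no PagedAttention
        // metadata was supplied at all).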
        let mask = mask.filter(|_| {
            metadata
                .as_ref()
                .map(|(_, meta)| meta.is_first_prompt_chunk)
                .unwrap_or(true)
        });
        let chunked_mask = chunked_mask.filter(|_| {
            metadata
                .as_ref()
                .map(|(_, meta)| meta.is_first_prompt_chunk)
                .unwrap_or(true)
        });
        for (block_idx, block) in self.blocks.iter().enumerate() {
            x = self.mapper.map(x, block_idx)?;
            x = block.forward(
                &x,
                &position_ids.to_device(x.device())?,
                &mask.clone().map(|m| m.to_device(x.device()).unwrap()),
                &chunked_mask
                    .clone()
                    .map(|m| m.to_device(x.device()).unwrap()),
                seqlen_offsets,
                &mut cache[block_idx],
                metadata
                    .as_ref()
                    .map(|(kv_cache, metadata)| (kv_cache[block_idx].clone(), *metadata)),
                flash_params,
            )?;
        }
        let mut x = x.to_device(&self.device)?;
        x = self.ln_f.forward(&x)?;
        x = self.lm_head.forward_autocast(&x)?;
        extract_logits(&x, context_lens)
    }

    pub fn residual_tensors_m(&self, uvb_m: UnVarBuilder) -> Vec<(String, Tensor)> {
        uvb_m.pp("embed_tokens").add(&self.wte);
        uvb_m.pp("norm").add(&self.ln_f);

        for (layer_idx, layer) in self.blocks.iter().enumerate() {
            let uvb_l = uvb_m.pp("layers").pp(layer_idx);
            uvb_l.pp("input_layernorm").add(&layer.rms_1);
            uvb_l.pp("post_attention_layernorm").add(&layer.rms_2);
        }

        uvb_m.to_safetensors()
    }
}

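// ISQ support: expose every quantizable linear layer (lm_head, attention projections, dense
// MLP gate/up/down, MoE routers, routed and shared experts) so they can be quantized in place.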
impl IsqModel for TextModel {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((&mut self.lm_head, None));
        for (i, layer) in self.blocks.iter_mut().enumerate() {
            tensors.push((&mut layer.attn.q_proj, Some(i)));
            tensors.push((&mut layer.attn.k_proj, Some(i)));
            tensors.push((&mut layer.attn.v_proj, Some(i)));
            tensors.push((&mut layer.attn.o_proj, Some(i)));
            match &mut layer.ff {
                MoeOrMlp::Mlp(x) => {
                    tensors.push((&mut x.gate, Some(i)));
                    tensors.push((&mut x.up, Some(i)));
                    tensors.push((&mut x.down, Some(i)));
                }
                MoeOrMlp::Moe(x) => {
                    tensors.push((&mut x.router, Some(i)));
                    for g in &mut x.experts.gate_proj {
                        tensors.push((g, Some(i)));
                    }
                    for u in &mut x.experts.up_proj {
                        tensors.push((u, Some(i)));
                    }
                    for d in &mut x.experts.down_proj {
                        tensors.push((d, Some(i)));
                    }
                    tensors.push((&mut x.shared_expert.gate, Some(i)));
                    tensors.push((&mut x.shared_expert.up, Some(i)));
                    tensors.push((&mut x.shared_expert.down, Some(i)));
                }
            }
        }
        (tensors, &*self.mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();
        self.residual_tensors_m(uvb.pp("model"))
    }
}

impl NormalModel for TextModel {
    fn forward(
        &self,
        _input_ids: &Tensor,
        _seqlen_offsets: &[usize],
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        _flash_params: &FlashParams,
    ) -> Result<Tensor> {
        unreachable!()
    }
    fn xlora_forward(
        &self,
        _input_ids: &Tensor,
        _input_ids_full: &Tensor,
        _seqlen_offsets: &[usize],
        _seqlen_offsets_full: &[usize],
        _no_kv_cache: bool,
        _non_granular_state: &Option<crate::xlora_models::NonGranularState>,
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _flash_params: &FlashParams,
        _flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        unimplemented!()
    }
    fn cache(&self) -> &crate::pipeline::EitherCache {
        &self.kv_cache
    }
    fn cache_mut(&mut self) -> &mut crate::pipeline::EitherCache {
        &mut self.kv_cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn is_xlora(&self) -> bool {
        false
    }
    fn max_seq_len(&self) -> usize {
        self.blocks[0].attn.max_seq_len
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
}

impl AnyMoeBaseModelMixin for TextModel {}