#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use candle_core::{Device, IndexOp, Module, Result, Tensor, D};
use candle_nn::LayerNorm;
use mistralrs_quant::{
    ColumnParallelLayer, NonZeroOp, QuantMethod, QuantizedConfig, ReplicatedLayer,
    RowParallelLayer, ShardedVarBuilder,
};
use std::{collections::HashMap, sync::Arc};

use crate::{
    amoe::AnyMoeBaseModelMixin,
    attention::SdpaParams,
    device_map::DeviceMapper,
    layers::{
        self, layer_norm, Activation, CausalMasker, MatMul, PhiRopeConfig, PhiRopeScalingConfig,
        PhiRotaryEmbedding, Sdpa,
    },
    layers_masker::{masked_fill, PastKvLenCache},
    paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
    pipeline::{
        extract_logits,
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel, KvCache, NormalCache, NormalLoadingMetadata, NormalModel,
    },
    serde_default_fn,
    utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
};

serde_default_fn!(bool, word_emb_default, false);

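/// Model configuration, deserialized from the checkpoint's `config.json`.
/// Covers the attention geometry, RoPE settings, sliding-window size, and the
/// mixture-of-experts parameters (`num_local_experts`, `router_jitter_noise`).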
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize, Default)]
pub struct Config {
    pub(crate) vocab_size: usize,
    pub(crate) hidden_act: Activation,
    pub(crate) hidden_size: usize,
    pub(crate) intermediate_size: usize,
    pub(crate) num_hidden_layers: usize,
    pub(crate) num_attention_heads: usize,
    pub(crate) num_key_value_heads: usize,
    pub(crate) rms_norm_eps: f64,
    pub(crate) rope_theta: f64,
    pub(crate) rope_scaling: Option<PhiRopeScalingConfig>,
    pub(crate) max_position_embeddings: usize,
    pub(crate) sliding_window: Option<usize>,
    pub(crate) original_max_position_embeddings: usize,

    pub(crate) quantization_config: Option<QuantizedConfig>,
    pub(crate) lm_head_bias: bool,
    pub(crate) attention_bias: bool,
    pub(crate) num_local_experts: usize,
    pub(crate) router_jitter_noise: f64,
    #[serde(default = "word_emb_default")]
    pub(crate) tie_word_embeddings: bool,
}

impl From<Config> for PhiRopeConfig {
    fn from(val: Config) -> Self {
        PhiRopeConfig {
            rope_scaling: val.rope_scaling,
            max_position_embeddings: val.max_position_embeddings,
            original_max_position_embeddings: val.original_max_position_embeddings,
            rope_theta: val.rope_theta,
            head_dim: val.hidden_size / val.num_attention_heads,
            partial_rotary_factor: None,
        }
    }
}

impl Config {
    pub fn head_dim(&self) -> usize {
        self.hidden_size / self.num_attention_heads
    }
}

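/// Multi-head self-attention with grouped-query KV heads, Phi-style rotary
/// embeddings, and tensor-parallel projections (column-parallel Q/K/V,
/// row-parallel O). Uses paged attention when available, otherwise an SDPA
/// kernel over the per-layer KV cache.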
struct Attention {
    q_proj: Arc<dyn QuantMethod>,
    k_proj: Arc<dyn QuantMethod>,
    v_proj: Arc<dyn QuantMethod>,
    o_proj: Arc<dyn QuantMethod>,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    rotary_emb: Arc<PhiRotaryEmbedding>,
    paged_attn: Option<PagedAttention>,
    sdpa_params: SdpaParams,
}

impl Attention {
    fn new(
        rotary_emb: Arc<PhiRotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        paged_attn: Option<PagedAttention>,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        let head_dim = cfg.head_dim();

        let q_proj = ColumnParallelLayer::new(
            cfg.hidden_size,
            num_heads * head_dim,
            &cfg.quantization_config,
            cfg.attention_bias,
            comm,
            vb.pp("q_proj"),
        )?;
        let kv_shard = mistralrs_quant::compute_kv_shard(
            cfg.num_key_value_heads,
            cfg.hidden_size / cfg.num_attention_heads,
            comm,
        );
        let k_proj = ColumnParallelLayer::new_with_shard(
            cfg.hidden_size,
            num_kv_heads * head_dim,
            &cfg.quantization_config,
            cfg.attention_bias,
            comm,
            kv_shard,
            vb.pp("k_proj"),
        )?;
        let v_proj = ColumnParallelLayer::new_with_shard(
            cfg.hidden_size,
            num_kv_heads * head_dim,
            &cfg.quantization_config,
            cfg.attention_bias,
            comm,
            kv_shard,
            vb.pp("v_proj"),
        )?;
        let o_proj = RowParallelLayer::new(
            num_heads * head_dim,
            cfg.hidden_size,
            &cfg.quantization_config,
            cfg.attention_bias,
            comm,
            vb.pp("o_proj"),
        )?;

        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            rotary_emb,
            num_heads: num_heads / comm.world_size(),
            num_kv_heads: (num_kv_heads / comm.world_size()).max(1),
            head_dim,
            paged_attn,
            sdpa_params: SdpaParams {
                n_kv_groups: mistralrs_quant::compute_n_kv_groups(
                    cfg.num_key_value_heads,
                    cfg.num_attention_heads,
                    comm,
                ),
                softcap: None,
                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                sliding_window: cfg.sliding_window,
            },
        })
    }

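    // Projects the hidden states to Q/K/V (casting to the quantized activation
    // dtype when required), applies rotary embeddings, runs either paged
    // attention or SDPA against the KV cache, and projects the result back
    // through o_proj.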
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.q_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut q = MatMul.qmethod_matmul(&xs, &*self.q_proj)?;
        let mut k = MatMul.qmethod_matmul(&xs, &*self.k_proj)?;
        let mut v = MatMul.qmethod_matmul(&xs, &*self.v_proj)?;
        if self.q_proj.quantized_act_type().is_some() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

        let (q, k, v) = if q_len != 1 {
            let q = q
                .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, k) = self
            .rotary_emb
            .forward(&q, &k, seqlen_offsets, position_ids)?;

        let mut attn_output = match &self.paged_attn {
            Some(paged_attn) => match metadata {
                Some(((key_cache, value_cache), input_metadata)) => paged_attn.forward(
                    &q,
                    &k,
                    &v,
                    attention_mask,
                    Some(key_cache),
                    Some(value_cache),
                    input_metadata,
                    &self.sdpa_params,
                    Some(flash_params),
                )?,
                None => {
                    let input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
                    assert!(attention_mask.is_some());
                    paged_attn.forward(
                        &q,
                        &k,
                        &v,
                        attention_mask,
                        None,
                        None,
                        &input_metadata,
                        &self.sdpa_params,
                        Some(flash_params),
                    )?
                }
            },
            _ => {
                let (k, v) = kv_cache.append(&k, &v)?;

                Sdpa.run_attention(
                    &q,
                    &k,
                    &v,
                    attention_mask,
                    Some(flash_params),
                    &self.sdpa_params,
                )?
            }
        };

        if let Some(t) = self.q_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        attn_output = if attention_mask.is_some() {
            attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?
        } else {
            attn_output.reshape((b_sz, q_len, ()))?
        };
        let mut res = MatMul.qmethod_matmul(&attn_output, &*self.o_proj)?;
        if self.q_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

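/// A single expert: a gated feed-forward block computing `w2(act(w1(x)) * w3(x))`,
/// with tensor-parallel up (`w1`, `w3`) and down (`w2`) projections.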
#[derive(Clone)]
struct Mlp {
    w1: Arc<dyn QuantMethod>,
    w2: Arc<dyn QuantMethod>,
    w3: Arc<dyn QuantMethod>,
    act_fn: Activation,
}

impl Mlp {
    fn new(cfg: &Config, vb: ShardedVarBuilder, comm: &Arc<mistralrs_quant::Comm>) -> Result<Self> {
        let hidden_size = cfg.hidden_size;
        let i_size = cfg.intermediate_size;

        let w1 = ColumnParallelLayer::new(
            hidden_size,
            i_size,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("w1"),
        )?;
        let w2 = RowParallelLayer::new(
            i_size,
            hidden_size,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("w2"),
        )?;
        let w3 = ColumnParallelLayer::new(
            hidden_size,
            i_size,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("w3"),
        )?;

        Ok(Self {
            w1,
            w2,
            w3,
            act_fn: cfg.hidden_act,
        })
    }

    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.w1.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut current_hidden_states =
            MatMul.qmethod_matmul(&xs, &*self.w1)?.apply(&self.act_fn)?;
        let rhs = MatMul.qmethod_matmul(&xs, &*self.w3)?;
        current_hidden_states = current_hidden_states.broadcast_mul(&rhs)?;
        let mut res = MatMul.qmethod_matmul(&current_hidden_states, &*self.w2)?;
        if self.w1.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

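/// Sparse mixture-of-experts block: a linear router (`gate`) scores every token
/// against `num_experts` experts, and the top-2 experts selected by
/// `sparsemixer` process each token.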
struct MoeMlp {
    gate: candle_nn::Linear,
    experts: Vec<Mlp>,
    router_jitter_noise: f64,
    num_experts: usize,
}

impl MoeMlp {
    fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        layer_device: Device,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let num_experts = cfg.num_local_experts;
        let gate = layers::linear_no_bias(
            cfg.hidden_size,
            num_experts,
            vb.pp("gate").set_device(layer_device),
        )?;

        let experts_vb = vb.pp("experts");
        let mut experts = Vec::with_capacity(num_experts);
        for i in 0..num_experts {
            experts.push(Mlp::new(cfg, experts_vb.pp(i), comm)?);
        }

        Ok(Self {
            gate,
            experts,
            router_jitter_noise: cfg.router_jitter_noise,
            num_experts,
        })
    }

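    // Top-2 "sparsemixer" routing: pick the arg-max expert, mask out gates whose
    // logit falls more than `2 * jitter_eps` (relative) below it, softmax the
    // remainder and gather the first multiplier; then mask the chosen expert out
    // of the scores and repeat once for the second expert. Returns the
    // concatenated multipliers and expert indices, each shaped (tokens, 2).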
    fn sparsemixer(&self, scores: &Tensor, jitter_eps: f64) -> Result<(Tensor, Tensor)> {
        let selected_experts = scores.argmax_keepdim(D::Minus1)?;

        let mask_logits_threshold = scores.gather(&selected_experts, D::Minus1)?;
        let factor = scores.abs()?.broadcast_minimum(&mask_logits_threshold)?;
        let mask_logits_threshold = mask_logits_threshold
            .broadcast_sub(scores)?
            .broadcast_div(&factor)?
            .gt(2. * jitter_eps)?;

        let masked_gates = masked_fill(scores, &mask_logits_threshold, f64::NEG_INFINITY)?;

        let masked_gates = candle_nn::ops::softmax_last_dim(&masked_gates)?;
        let multiplier = masked_gates.gather(&selected_experts, D::Minus1)?;

        let masked_scores = scores.scatter_add(
            &selected_experts
                .broadcast_as(scores.shape())?
                .contiguous()?,
            &(scores.ones_like()? * f64::NEG_INFINITY)?,
            D::Minus1,
        )?;

        let selected_experts_top2 = masked_scores.argmax_keepdim(D::Minus1)?;

        let mask_logits_threshold = masked_scores.gather(&selected_experts_top2, D::Minus1)?;
        let factor = scores.abs()?.broadcast_minimum(&mask_logits_threshold)?;
        let mask_logits_threshold = mask_logits_threshold
            .broadcast_sub(scores)?
            .broadcast_div(&factor)?
            .gt(2. * jitter_eps)?;

        let masked_gates_top2 =
            masked_fill(&masked_scores, &mask_logits_threshold, f64::NEG_INFINITY)?;
        let masked_gates_top2 = candle_nn::ops::softmax_last_dim(&masked_gates_top2)?;
        let multiplier_top2 = masked_gates_top2.gather(&selected_experts_top2, D::Minus1)?;

        let multiplier = Tensor::cat(&[multiplier, multiplier_top2], D::Minus1)?;
        let selected_experts = Tensor::cat(&[selected_experts, selected_experts_top2], D::Minus1)?;

        Ok((multiplier, selected_experts))
    }

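    // Routing runs on the CPU (argmax/gather/nonzero), while the gate projection
    // and the expert MLPs run on the layer's device. Tokens are grouped per
    // expert via a one-hot mask, each expert processes only its tokens, and the
    // weighted outputs are accumulated back into place with `index_add`.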
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let (bs, seq, hidden) = xs.dims3()?;
        let xs = xs.reshape(((), hidden))?;
        let xs_dev = xs.device();
        let xs = xs.to_device(&Device::Cpu)?;

        let router_logits = self
            .gate
            .forward(&xs.to_device(xs_dev)?)?
            .to_device(&Device::Cpu)?;
        let (routing_weights, selected_experts) = self.sparsemixer(
            &router_logits.to_device(&Device::Cpu)?,
            self.router_jitter_noise,
        )?;

        let mut final_hidden_states = Tensor::zeros((bs * seq, hidden), xs.dtype(), xs.device())?;

        let experts_mask =
            candle_nn::encoding::one_hot(selected_experts, self.num_experts, 1u8, 0u8)?
                .permute((2, 1, 0))?;

        for expert_idx in 0..self.num_experts {
            let expert = &self.experts[expert_idx];
            let expert_mask = experts_mask.i(expert_idx)?;
            assert_eq!(expert_mask.rank(), 2);
            let nonzero_mask = expert_mask.contiguous()?.nonzero()?;
            let idx = nonzero_mask.i((.., 0))?;
            let top_x = nonzero_mask.i((.., 1))?;

            if top_x.dim(0)? == 0 {
                continue;
            }

            let current_state = xs.index_select(&top_x, 0)?.reshape((1, (), hidden))?;
            let current_routing_weights = routing_weights
                .index_select(&top_x, 0)?
                .gather(&idx.unsqueeze(1)?.contiguous()?, 1)?;
            let exp_out = expert
                .forward(&current_state.to_device(xs_dev)?)?
                .to_device(&Device::Cpu)?;

            let current_hidden_states = exp_out.broadcast_mul(&current_routing_weights)?;

            final_hidden_states = final_hidden_states.index_add(
                &top_x.contiguous()?,
                &current_hidden_states.squeeze(0)?.to_dtype(xs.dtype())?,
                0,
            )?;
        }

        final_hidden_states
            .reshape((bs, seq, hidden))?
            .to_device(xs_dev)
    }
}

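/// One transformer decoder layer: pre-norm self-attention followed by a
/// pre-norm sparse MoE block, each wrapped in a residual connection.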
struct DecoderLayer {
    self_attn: Attention,
    mlp: MoeMlp,
    input_layernorm: LayerNorm,
    post_attention_layernorm: LayerNorm,
}

impl DecoderLayer {
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<PhiRotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        paged_attn: Option<PagedAttention>,
        real_device: Device,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let self_attn = Attention::new(
            rotary_emb,
            cfg,
            mapper.set_device(layer_idx, vb.pp("self_attn"), loading_isq),
            paged_attn,
            comm,
        )?;
        let mlp = MoeMlp::new(
            cfg,
            mapper.set_device(layer_idx, vb.pp("block_sparse_moe"), loading_isq),
            mapper
                .device_for(layer_idx, false)
                .cloned()
                .unwrap_or(real_device),
            comm,
        )?;
        let input_layernorm = layer_norm(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let post_attention_layernorm = layer_norm(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        Ok(Self {
            self_attn,
            mlp,
            input_layernorm,
            post_attention_layernorm,
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = self.input_layernorm.forward(xs)?;
        let xs = self.self_attn.forward(
            &xs,
            attention_mask,
            seqlen_offsets,
            position_ids,
            kv_cache,
            metadata,
            flash_params,
        )?;
        let xs = (xs + residual)?;
        let residual = &xs;
        let xs = self
            .mlp
            .forward(&xs.apply(&self.post_attention_layernorm)?)?;
        residual + xs
    }
}

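/// Full causal MoE language model: token embeddings, a stack of decoder layers
/// (device-mapped via `DeviceMapper`), a final layer norm, and the LM head.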
pub struct Model {
    embed_tokens: candle_nn::Embedding,
    layers: Vec<DecoderLayer>,
    norm: LayerNorm,
    lm_head: Arc<dyn QuantMethod>,
    device: Device,
    cache: EitherCache,
    max_seq_len: usize,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    sliding_window: Option<usize>,
    cfg: ModelConfigMetadata,
}

impl Model {
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        _is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        if let Some(ref quant_cfg) = &cfg.quantization_config {
            tracing::info!(
                "Using {} quantization: {}.",
                quant_cfg.name(),
                quant_cfg.get_bits_name(&vb)
            );
        }
        let mapper = normal_loading_metadata.mapper;
        let vb_m = vb.pp("model");

        let embed_tokens = layers::embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
            &cfg.quantization_config,
        )?;
        let mut ropes = HashMap::new();
        for layer_idx in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(PhiRotaryEmbedding::new(vb.dtype(), cfg.clone(), device)?),
            );
        }
        let vb_l = vb_m.pp("layers");
        let layers: Vec<DecoderLayer> = NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        )
        .par_iter_if_isq(|layer_idx| {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            let paged_attn = match &attention_mechanism {
                AttentionImplementation::Eager => None,
                AttentionImplementation::PagedAttention => {
                    Some(PagedAttention::new(cfg.head_dim(), device, None)?)
                }
            };
            let comm = mapper.get_comm_for(layer_idx)?;
            DecoderLayer::new(
                rotary_emb.clone(),
                cfg,
                vb_l.pp(layer_idx),
                &*mapper,
                layer_idx,
                normal_loading_metadata.loading_isq,
                paged_attn,
                normal_loading_metadata.real_device.clone(),
                &comm,
            )
        })?;
        let norm = layer_norm(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb_m.pp("norm"), false),
        )?;
        let lm_head = if !cfg.tie_word_embeddings {
            ReplicatedLayer::new(
                cfg.hidden_size,
                cfg.vocab_size,
                &cfg.quantization_config,
                cfg.lm_head_bias,
                mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
            )?
        } else {
            unreachable!()
        };
        Ok(Self {
            embed_tokens,
            layers,
            norm,
            lm_head,
            device: normal_loading_metadata.real_device,
            cache: EitherCache::Normal(NormalCache::new_sliding(
                cfg.num_hidden_layers,
                cfg.max_position_embeddings,
                cfg.sliding_window,
            )),
            max_seq_len: cfg.max_position_embeddings,
            sliding_window: cfg.sliding_window,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_attn_heads: cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size(),
                num_kv_heads: (cfg.num_key_value_heads / mapper.get_comm_for(0)?.world_size())
                    .max(1),
                sliding_window: cfg.sliding_window,
                k_head_dim: cfg.head_dim(),
                v_head_dim: cfg.head_dim(),
            },
            mapper,
        })
    }

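    // Embeds the input ids, builds a sliding-window causal mask (skipped after
    // the first prompt chunk when paged-attention metadata is present), runs
    // every decoder layer on its mapped device, then applies the final norm and
    // LM head and extracts logits for the requested context positions.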
    pub fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
        context_lens: Vec<(usize, usize)>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut xs = self.embed_tokens.forward(input_ids)?;
        let cache = &mut self.cache.normal().0;
        let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
            input_ids,
            metadata
                .as_ref()
                .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
                .unwrap_or(cache as &dyn PastKvLenCache),
            self.sliding_window,
            xs.dtype(),
            self.cfg.num_attn_heads,
        )?;
        let attention_mask = attention_mask.filter(|_| {
            metadata
                .as_ref()
                .map(|(_, meta)| meta.is_first_prompt_chunk)
                .unwrap_or(true)
        });

        for (i, layer) in self.layers.iter().enumerate() {
            xs = self.mapper.map(xs, i)?;
            xs = layer.forward(
                &xs,
                attention_mask
                    .as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                position_ids,
                &mut cache[i],
                metadata
                    .as_ref()
                    .map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
                flash_params,
            )?
        }
        let xs = xs.to_device(&self.device)?;
        let mut xs = xs.apply(&self.norm)?;
        if let Some(t) = self.lm_head.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        extract_logits(&MatMul.qmethod_matmul(&xs, &*self.lm_head)?, context_lens)
    }
}

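// ISQ support: exposes the quantizable linear layers (attention projections and
// expert weights, plus the LM head) together with the residual tensors
// (embeddings and layer norms) that are kept unquantized.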
impl IsqModel for Model {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((&mut self.lm_head, None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            tensors.push((&mut layer.self_attn.q_proj, Some(i)));
            tensors.push((&mut layer.self_attn.k_proj, Some(i)));
            tensors.push((&mut layer.self_attn.v_proj, Some(i)));
            tensors.push((&mut layer.self_attn.o_proj, Some(i)));
            for expert in &mut layer.mlp.experts {
                tensors.push((&mut expert.w1, Some(i)));
                tensors.push((&mut expert.w2, Some(i)));
                tensors.push((&mut expert.w3, Some(i)));
            }
        }
        (tensors, &*self.mapper)
    }
    fn get_layers_moe_experts_only(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((&mut self.lm_head, None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            for expert in &mut layer.mlp.experts {
                tensors.push((&mut expert.w1, Some(i)));
                tensors.push((&mut expert.w2, Some(i)));
                tensors.push((&mut expert.w3, Some(i)));
            }
        }
        (tensors, &*self.mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        let uvb_m = uvb.pp("model");
        uvb_m.pp("embed_tokens").add(&self.embed_tokens);
        uvb_m.pp("norm").add(&self.norm);

        for (layer_idx, layer) in self.layers.iter().enumerate() {
            let uvb_l = uvb_m.pp("layers").pp(layer_idx);
            uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
            uvb_l
                .pp("post_attention_layernorm")
                .add(&layer.post_attention_layernorm);
        }

        uvb.to_safetensors()
    }

    fn residual_tensors_moe_experts_only(&self) -> Option<Vec<(String, Tensor)>> {
        let uvb = UnVarBuilder::new();

        let uvb_m = uvb.pp("model");
        uvb_m.pp("embed_tokens").add(&self.embed_tokens);
        uvb_m.pp("norm").add(&self.norm);

        for (layer_idx, layer) in self.layers.iter().enumerate() {
            let uvb_l = uvb_m.pp("layers").pp(layer_idx);
            uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
            uvb_l
                .pp("post_attention_layernorm")
                .add(&layer.post_attention_layernorm);

            let uvb_attn = uvb_l.pp("self_attn");
            uvb_attn.pp("q_proj").add(&layer.self_attn.q_proj);
            uvb_attn.pp("k_proj").add(&layer.self_attn.k_proj);
            uvb_attn.pp("v_proj").add(&layer.self_attn.v_proj);
            uvb_attn.pp("o_proj").add(&layer.self_attn.o_proj);
        }

        Some(uvb.to_safetensors())
    }
}

impl NormalModel for Model {
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.forward(
            input_ids,
            seqlen_offsets,
            &position_ids,
            context_lens,
            metadata,
            flash_params,
        )
    }
    fn xlora_forward(
        &self,
        _input_ids: &Tensor,
        _input_ids_full: &Tensor,
        _seqlen_offsets: &[usize],
        _seqlen_offsets_full: &[usize],
        _no_kv_cache: bool,
        _non_granular_state: &Option<crate::xlora_models::NonGranularState>,
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _flash_params: &FlashParams,
        _flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        unimplemented!()
    }
    fn cache(&self) -> &EitherCache {
        &self.cache
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        &mut self.cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn is_xlora(&self) -> bool {
        false
    }
    fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
}

impl AnyMoeBaseModelMixin for Model {}