#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use candle_core::{Device, IndexOp, Module, Result, Tensor, D};
use candle_nn::LayerNorm;
use mistralrs_quant::{
    ColumnParallelLayer, NonZeroOp, QuantMethod, QuantizedConfig, ReplicatedLayer,
    RowParallelLayer, ShardedVarBuilder,
};
use std::{collections::HashMap, sync::Arc};

use crate::{
    amoe::AnyMoeBaseModelMixin,
    attention::SdpaParams,
    device_map::DeviceMapper,
    layers::{
        self, layer_norm, Activation, CausalMasker, MatMul, PhiRopeConfig, PhiRopeScalingConfig,
        PhiRotaryEmbedding, Sdpa,
    },
    layers_masker::{masked_fill, PastKvLenCache},
    paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
    pipeline::{
        extract_logits,
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel, KvCache, NormalCache, NormalLoadingMetadata, NormalModel,
    },
    serde_default_fn,
    utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
};

serde_default_fn!(bool, word_emb_default, false);

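/// Model configuration, deserialized from the checkpoint's `config.json`. The fields
/// follow the Phi-3.5-MoE-style layout: standard attention/MLP sizes plus the MoE
/// routing parameters `num_local_experts` and `router_jitter_noise`.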
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize, Default)]
pub struct Config {
    pub(crate) vocab_size: usize,
    pub(crate) hidden_act: Activation,
    pub(crate) hidden_size: usize,
    pub(crate) intermediate_size: usize,
    pub(crate) num_hidden_layers: usize,
    pub(crate) num_attention_heads: usize,
    pub(crate) num_key_value_heads: usize,
    pub(crate) rms_norm_eps: f64,
    pub(crate) rope_theta: f64,
    pub(crate) rope_scaling: Option<PhiRopeScalingConfig>,
    pub(crate) max_position_embeddings: usize,
    pub(crate) sliding_window: Option<usize>,
    pub(crate) original_max_position_embeddings: usize,

    pub(crate) quantization_config: Option<QuantizedConfig>,
    pub(crate) lm_head_bias: bool,
    pub(crate) attention_bias: bool,
    pub(crate) num_local_experts: usize,
    pub(crate) router_jitter_noise: f64,
    #[serde(default = "word_emb_default")]
    pub(crate) tie_word_embeddings: bool,
}

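// Derive the RoPE configuration from the model config; the full head dimension is
// rotated (no partial rotary factor).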
impl From<Config> for PhiRopeConfig {
    fn from(val: Config) -> Self {
        PhiRopeConfig {
            rope_scaling: val.rope_scaling,
            max_position_embeddings: val.max_position_embeddings,
            original_max_position_embeddings: val.original_max_position_embeddings,
            rope_theta: val.rope_theta,
            head_dim: val.hidden_size / val.num_attention_heads,
            partial_rotary_factor: None,
        }
    }
}

impl Config {
    pub fn head_dim(&self) -> usize {
        self.hidden_size / self.num_attention_heads
    }
}

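/// Self-attention with tensor-parallel Q/K/V/O projections, Phi rotary embeddings,
/// and either paged attention or the standard SDPA path over a per-layer KV cache.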
struct Attention {
    q_proj: Arc<dyn QuantMethod>,
    k_proj: Arc<dyn QuantMethod>,
    v_proj: Arc<dyn QuantMethod>,
    o_proj: Arc<dyn QuantMethod>,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    rotary_emb: Arc<PhiRotaryEmbedding>,
    paged_attn: Option<PagedAttention>,
    sdpa_params: SdpaParams,
}

impl Attention {
    fn new(
        rotary_emb: Arc<PhiRotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        paged_attn: Option<PagedAttention>,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        let head_dim = cfg.head_dim();

        let q_proj = ColumnParallelLayer::new(
            cfg.hidden_size,
            num_heads * head_dim,
            &cfg.quantization_config,
            cfg.attention_bias,
            comm,
            vb.pp("q_proj"),
        )?;
        let kv_shard = mistralrs_quant::compute_kv_shard(
            cfg.num_key_value_heads,
            cfg.hidden_size / cfg.num_attention_heads,
            comm,
        );
        let k_proj = ColumnParallelLayer::new_with_shard(
            cfg.hidden_size,
            num_kv_heads * head_dim,
            &cfg.quantization_config,
            cfg.attention_bias,
            comm,
            kv_shard,
            vb.pp("k_proj"),
        )?;
        let v_proj = ColumnParallelLayer::new_with_shard(
            cfg.hidden_size,
            num_kv_heads * head_dim,
            &cfg.quantization_config,
            cfg.attention_bias,
            comm,
            kv_shard,
            vb.pp("v_proj"),
        )?;
        let o_proj = RowParallelLayer::new(
            num_heads * head_dim,
            cfg.hidden_size,
            &cfg.quantization_config,
            cfg.attention_bias,
            comm,
            vb.pp("o_proj"),
        )?;

        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            rotary_emb,
            num_heads: num_heads / comm.world_size(),
            num_kv_heads: (num_kv_heads / comm.world_size()).max(1),
            head_dim,
            paged_attn,
            sdpa_params: SdpaParams {
                n_kv_groups: mistralrs_quant::compute_n_kv_groups(
                    cfg.num_key_value_heads,
                    cfg.num_attention_heads,
                    comm,
                ),
                softcap: None,
                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                sliding_window: cfg.sliding_window,
            },
        })
    }

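    /// Projects the input to Q/K/V, applies the rotary embedding, runs either paged
    /// attention (when cache metadata is supplied) or SDPA against the per-layer KV
    /// cache, and applies the output projection.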
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.q_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut q = MatMul.qmethod_matmul(&xs, &*self.q_proj)?;
        let mut k = MatMul.qmethod_matmul(&xs, &*self.k_proj)?;
        let mut v = MatMul.qmethod_matmul(&xs, &*self.v_proj)?;
        if self.q_proj.quantized_act_type().is_some() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

        let (q, k, v) = if q_len != 1 {
            let q = q
                .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, k) = self
            .rotary_emb
            .forward(&q, &k, seqlen_offsets, position_ids)?;

        let mut attn_output = match &self.paged_attn {
            Some(paged_attn) => match metadata {
                Some(((key_cache, value_cache), input_metadata)) => paged_attn.forward(
                    &q,
                    &k,
                    &v,
                    attention_mask,
                    Some(key_cache),
                    Some(value_cache),
                    input_metadata,
                    &self.sdpa_params,
                    Some(flash_params),
                )?,
                None => {
                    let input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
                    assert!(attention_mask.is_some());
                    paged_attn.forward(
                        &q,
                        &k,
                        &v,
                        attention_mask,
                        None,
                        None,
                        &input_metadata,
                        &self.sdpa_params,
                        Some(flash_params),
                    )?
                }
            },
            _ => {
                let (k, v) = kv_cache.append(&k, &v)?;

                Sdpa.run_attention(
                    &q,
                    &k,
                    &v,
                    attention_mask,
                    Some(flash_params),
                    &self.sdpa_params,
                )?
            }
        };

        if let Some(t) = self.q_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        attn_output = if attention_mask.is_some() {
            attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?
        } else {
            attn_output.reshape((b_sz, q_len, ()))?
        };
        let mut res = MatMul.qmethod_matmul(&attn_output, &*self.o_proj)?;
        if self.q_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

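/// Gated feed-forward network used for each MoE expert (weights `w1`, `w2`, `w3`).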
#[derive(Clone)]
struct Mlp {
    w1: Arc<dyn QuantMethod>,
    w2: Arc<dyn QuantMethod>,
    w3: Arc<dyn QuantMethod>,
    act_fn: Activation,
}

impl Mlp {
    fn new(cfg: &Config, vb: ShardedVarBuilder, comm: &Arc<mistralrs_quant::Comm>) -> Result<Self> {
        let hidden_size = cfg.hidden_size;
        let i_size = cfg.intermediate_size;

        let w1 = ColumnParallelLayer::new(
            hidden_size,
            i_size,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("w1"),
        )?;
        let w2 = RowParallelLayer::new(
            i_size,
            hidden_size,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("w2"),
        )?;
        let w3 = ColumnParallelLayer::new(
            hidden_size,
            i_size,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("w3"),
        )?;

        Ok(Self {
            w1,
            w2,
            w3,
            act_fn: cfg.hidden_act,
        })
    }

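    /// Computes `w2(act_fn(w1(x)) * w3(x))`, casting around the matmuls when a
    /// quantized activation dtype is required.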
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.w1.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut current_hidden_states =
            MatMul.qmethod_matmul(&xs, &*self.w1)?.apply(&self.act_fn)?;
        let rhs = MatMul.qmethod_matmul(&xs, &*self.w3)?;
        current_hidden_states = current_hidden_states.broadcast_mul(&rhs)?;
        let mut res = MatMul.qmethod_matmul(&current_hidden_states, &*self.w2)?;
        if self.w1.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

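/// Sparse mixture-of-experts block: a linear router (`gate`) scores each token and
/// `sparsemixer` dispatches it to its top-2 experts.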
struct MoeMlp {
    gate: candle_nn::Linear,
    experts: Vec<Mlp>,
    router_jitter_noise: f64,
    num_experts: usize,
}

impl MoeMlp {
    fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        layer_device: Device,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let num_experts = cfg.num_local_experts;
        let gate = layers::linear_no_bias(
            cfg.hidden_size,
            num_experts,
            vb.pp("gate").set_device(layer_device),
        )?;

        let experts_vb = vb.pp("experts");
        let mut experts = Vec::with_capacity(num_experts);
        for i in 0..num_experts {
            experts.push(Mlp::new(cfg, experts_vb.pp(i), comm)?);
        }

        Ok(Self {
            gate,
            experts,
            router_jitter_noise: cfg.router_jitter_noise,
            num_experts,
        })
    }

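    /// Top-2 expert selection (sparsemixer-style routing). The top-1 expert is chosen
    /// by `argmax`; logits falling more than `2 * jitter_eps` (relative) below the
    /// winner are masked to `-inf` before the softmax that yields its routing weight.
    /// The winner is then masked out of the scores and the same procedure picks the
    /// second expert. Returns `(routing_weights, selected_experts)`, each holding two
    /// entries per token.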
    fn sparsemixer(&self, scores: &Tensor, jitter_eps: f64) -> Result<(Tensor, Tensor)> {
        let selected_experts = scores.argmax_keepdim(D::Minus1)?;

        let mask_logits_threshold = scores.gather(&selected_experts, D::Minus1)?;
        let factor = scores.abs()?.broadcast_minimum(&mask_logits_threshold)?;
        let mask_logits_threshold = mask_logits_threshold
            .broadcast_sub(scores)?
            .broadcast_div(&factor)?
            .gt(2. * jitter_eps)?;

        let masked_gates = masked_fill(scores, &mask_logits_threshold, f64::NEG_INFINITY)?;

        let masked_gates = candle_nn::ops::softmax_last_dim(&masked_gates)?;
        let multiplier = masked_gates.gather(&selected_experts, D::Minus1)?;

        let masked_scores = scores.scatter_add(
            &selected_experts
                .broadcast_as(scores.shape())?
                .contiguous()?,
            &(scores.ones_like()? * f64::NEG_INFINITY)?,
            D::Minus1,
        )?;

        let selected_experts_top2 = masked_scores.argmax_keepdim(D::Minus1)?;

        let mask_logits_threshold = masked_scores.gather(&selected_experts_top2, D::Minus1)?;
        let factor = scores.abs()?.broadcast_minimum(&mask_logits_threshold)?;
        let mask_logits_threshold = mask_logits_threshold
            .broadcast_sub(scores)?
            .broadcast_div(&factor)?
            .gt(2. * jitter_eps)?;

        let masked_gates_top2 =
            masked_fill(&masked_scores, &mask_logits_threshold, f64::NEG_INFINITY)?;
        let masked_gates_top2 = candle_nn::ops::softmax_last_dim(&masked_gates_top2)?;
        let multiplier_top2 = masked_gates_top2.gather(&selected_experts_top2, D::Minus1)?;

        let multiplier = Tensor::cat(&[multiplier, multiplier_top2], D::Minus1)?;
        let selected_experts = Tensor::cat(&[selected_experts, selected_experts_top2], D::Minus1)?;

        Ok((multiplier, selected_experts))
    }

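    /// Routes each token through its top-2 experts: flattens the input to
    /// `(bs * seq, hidden)`, computes router logits and runs `sparsemixer` on the CPU,
    /// then for every expert gathers its assigned tokens with `index_select`, runs the
    /// expert MLP on the layer device, scales by the routing weights, and accumulates
    /// the results into the output with `index_add`.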
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let (bs, seq, hidden) = xs.dims3()?;
        let xs = xs.reshape(((), hidden))?;
        let xs_dev = xs.device();
        let xs = xs.to_device(&Device::Cpu)?;

        let router_logits = self
            .gate
            .forward(&xs.to_device(xs_dev)?)?
            .to_device(&Device::Cpu)?;
        let (routing_weights, selected_experts) = self.sparsemixer(
            &router_logits.to_device(&Device::Cpu)?,
            self.router_jitter_noise,
        )?;

        let mut final_hidden_states = Tensor::zeros((bs * seq, hidden), xs.dtype(), xs_dev)?;

        let experts_mask =
            candle_nn::encoding::one_hot(selected_experts, self.num_experts, 1u8, 0u8)?
                .permute((2, 1, 0))?;

        for expert_idx in 0..self.num_experts {
            let expert = &self.experts[expert_idx];
            let expert_mask = experts_mask.i(expert_idx)?;
            assert_eq!(expert_mask.rank(), 2);
            let nonzero_mask = expert_mask.contiguous()?.nonzero()?;
            let idx = nonzero_mask.i((.., 0))?;
            let top_x = nonzero_mask.i((.., 1))?;

            if top_x.dim(0)? == 0 {
                continue;
            }

            let current_state = xs.index_select(&top_x, 0)?.reshape((1, (), hidden))?;
            let current_routing_weights = routing_weights
                .index_select(&top_x, 0)?
                .gather(&idx.unsqueeze(1)?.contiguous()?, 1)?;
            let exp_out = expert
                .forward(&current_state.to_device(xs_dev)?)?
                .to_device(&Device::Cpu)?;

            let current_hidden_states = exp_out.broadcast_mul(&current_routing_weights)?;

            final_hidden_states = final_hidden_states.index_add(
                &top_x.contiguous()?.to_device(xs_dev)?,
                &current_hidden_states
                    .squeeze(0)?
                    .to_dtype(xs.dtype())?
                    .to_device(xs_dev)?,
                0,
            )?;
        }

        final_hidden_states
            .reshape((bs, seq, hidden))?
            .to_device(xs_dev)
    }
}

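/// One transformer block: pre-norm self-attention followed by a pre-norm sparse MoE
/// MLP, each wrapped in a residual connection.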
struct DecoderLayer {
    self_attn: Attention,
    mlp: MoeMlp,
    input_layernorm: LayerNorm,
    post_attention_layernorm: LayerNorm,
}

impl DecoderLayer {
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<PhiRotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        paged_attn: Option<PagedAttention>,
        real_device: Device,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let self_attn = Attention::new(
            rotary_emb,
            cfg,
            mapper.set_device(layer_idx, vb.pp("self_attn"), loading_isq),
            paged_attn,
            comm,
        )?;
        let mlp = MoeMlp::new(
            cfg,
            mapper.set_device(layer_idx, vb.pp("block_sparse_moe"), loading_isq),
            mapper
                .device_for(layer_idx, false)
                .cloned()
                .unwrap_or(real_device),
            comm,
        )?;
        let input_layernorm = layer_norm(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let post_attention_layernorm = layer_norm(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        Ok(Self {
            self_attn,
            mlp,
            input_layernorm,
            post_attention_layernorm,
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = self.input_layernorm.forward(xs)?;
        let xs = self.self_attn.forward(
            &xs,
            attention_mask,
            seqlen_offsets,
            position_ids,
            kv_cache,
            metadata,
            flash_params,
        )?;
        let xs = (xs + residual)?;
        let residual = &xs;
        let xs = self
            .mlp
            .forward(&xs.apply(&self.post_attention_layernorm)?)?;
        residual + xs
    }
}

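/// The complete causal LM: token embeddings, a device-mapped stack of MoE decoder
/// layers, a final layer norm, and the LM head, plus the KV cache and metadata
/// needed by the pipeline.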
pub struct Model {
    embed_tokens: candle_nn::Embedding,
    layers: Vec<DecoderLayer>,
    norm: LayerNorm,
    lm_head: Arc<dyn QuantMethod>,
    device: Device,
    cache: EitherCache,
    max_seq_len: usize,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    sliding_window: Option<usize>,
    cfg: ModelConfigMetadata,
}

impl Model {
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        _is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        if let Some(ref quant_cfg) = &cfg.quantization_config {
            tracing::info!(
                "Using {} quantization: {}.",
                quant_cfg.name(),
                quant_cfg.get_bits_name(&vb)
            );
        }
        let mapper = normal_loading_metadata.mapper;
        let vb_m = vb.pp("model");

        let embed_tokens = layers::embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
            &cfg.quantization_config,
        )?;
        let mut ropes = HashMap::new();
        for layer_idx in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(PhiRotaryEmbedding::new(vb.dtype(), cfg.clone(), device)?),
            );
        }
        let vb_l = vb_m.pp("layers");
        let layers: Vec<DecoderLayer> = NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        )
        .par_iter_if_isq(|layer_idx| {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            let paged_attn = match &attention_mechanism {
                AttentionImplementation::Eager => None,
                AttentionImplementation::PagedAttention => {
                    Some(PagedAttention::new(cfg.head_dim(), device, None)?)
                }
            };
            let comm = mapper.get_comm_for(layer_idx)?;
            DecoderLayer::new(
                rotary_emb.clone(),
                cfg,
                vb_l.pp(layer_idx),
                &*mapper,
                layer_idx,
                normal_loading_metadata.loading_isq,
                paged_attn,
                normal_loading_metadata.real_device.clone(),
                &comm,
            )
        })?;
        let norm = layer_norm(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb_m.pp("norm"), false),
        )?;
        let lm_head = if !cfg.tie_word_embeddings {
            ReplicatedLayer::new(
                cfg.hidden_size,
                cfg.vocab_size,
                &cfg.quantization_config,
                cfg.lm_head_bias,
                mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
            )?
        } else {
            unreachable!()
        };
        Ok(Self {
            embed_tokens,
            layers,
            norm,
            lm_head,
            device: normal_loading_metadata.real_device,
            cache: EitherCache::Normal(NormalCache::new_sliding(
                cfg.num_hidden_layers,
                cfg.max_position_embeddings,
                cfg.sliding_window,
            )),
            max_seq_len: cfg.max_position_embeddings,
            sliding_window: cfg.sliding_window,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_attn_heads: cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size(),
                num_kv_heads: (cfg.num_key_value_heads / mapper.get_comm_for(0)?.world_size())
                    .max(1),
                sliding_window: cfg.sliding_window,
                k_head_dim: cfg.head_dim(),
                v_head_dim: cfg.head_dim(),
            },
            mapper,
        })
    }

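    /// Full forward pass: embeds `input_ids`, builds the (sliding-window) causal mask,
    /// runs each decoder layer on its mapped device, applies the final layer norm and
    /// LM head, and extracts the logits selected by `context_lens`.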
    pub fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
        context_lens: Vec<(usize, usize)>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut xs = self.embed_tokens.forward(input_ids)?;
        let cache = &mut self.cache.normal().0;
        let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
            input_ids,
            metadata
                .as_ref()
                .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
                .unwrap_or(cache as &dyn PastKvLenCache),
            self.sliding_window,
            xs.dtype(),
            self.cfg.num_attn_heads,
        )?;
        let attention_mask = attention_mask.filter(|_| {
            metadata
                .as_ref()
                .map(|(_, meta)| meta.is_first_prompt_chunk)
                .unwrap_or(true)
        });

        for (i, layer) in self.layers.iter().enumerate() {
            xs = self.mapper.map(xs, i)?;
            xs = layer.forward(
                &xs,
                attention_mask
                    .as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                position_ids,
                &mut cache[i],
                metadata
                    .as_ref()
                    .map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
                flash_params,
            )?
        }
        let xs = xs.to_device(&self.device)?;
        let mut xs = xs.apply(&self.norm)?;
        if let Some(t) = self.lm_head.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        extract_logits(&MatMul.qmethod_matmul(&xs, &*self.lm_head)?, context_lens)
    }
}

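// ISQ (in-situ quantization) support: exposes the quantizable linear layers
// (attention projections, expert weights, and the LM head) and the residual
// tensors that stay unquantized.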
impl IsqModel for Model {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((&mut self.lm_head, None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            tensors.push((&mut layer.self_attn.q_proj, Some(i)));
            tensors.push((&mut layer.self_attn.k_proj, Some(i)));
            tensors.push((&mut layer.self_attn.v_proj, Some(i)));
            tensors.push((&mut layer.self_attn.o_proj, Some(i)));
            for expert in &mut layer.mlp.experts {
                tensors.push((&mut expert.w1, Some(i)));
                tensors.push((&mut expert.w2, Some(i)));
                tensors.push((&mut expert.w3, Some(i)));
            }
        }
        (tensors, &*self.mapper)
    }
    fn get_layers_moe_experts_only(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((&mut self.lm_head, None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            for expert in &mut layer.mlp.experts {
                tensors.push((&mut expert.w1, Some(i)));
                tensors.push((&mut expert.w2, Some(i)));
                tensors.push((&mut expert.w3, Some(i)));
            }
        }
        (tensors, &*self.mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        let uvb_m = uvb.pp("model");
        uvb_m.pp("embed_tokens").add(&self.embed_tokens);
        uvb_m.pp("norm").add(&self.norm);

        for (layer_idx, layer) in self.layers.iter().enumerate() {
            let uvb_l = uvb_m.pp("layers").pp(layer_idx);
            uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
            uvb_l
                .pp("post_attention_layernorm")
                .add(&layer.post_attention_layernorm);
        }

        uvb.to_safetensors()
    }

    fn residual_tensors_moe_experts_only(&self) -> Option<Vec<(String, Tensor)>> {
        let uvb = UnVarBuilder::new();

        let uvb_m = uvb.pp("model");
        uvb_m.pp("embed_tokens").add(&self.embed_tokens);
        uvb_m.pp("norm").add(&self.norm);

        for (layer_idx, layer) in self.layers.iter().enumerate() {
            let uvb_l = uvb_m.pp("layers").pp(layer_idx);
            uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
            uvb_l
                .pp("post_attention_layernorm")
                .add(&layer.post_attention_layernorm);

            let uvb_attn = uvb_l.pp("self_attn");
            uvb_attn.pp("q_proj").add(&layer.self_attn.q_proj);
            uvb_attn.pp("k_proj").add(&layer.self_attn.k_proj);
            uvb_attn.pp("v_proj").add(&layer.self_attn.v_proj);
            uvb_attn.pp("o_proj").add(&layer.self_attn.o_proj);
        }

        Some(uvb.to_safetensors())
    }
}

impl NormalModel for Model {
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.forward(
            input_ids,
            seqlen_offsets,
            &position_ids,
            context_lens,
            metadata,
            flash_params,
        )
    }
    fn xlora_forward(
        &self,
        _input_ids: &Tensor,
        _input_ids_full: &Tensor,
        _seqlen_offsets: &[usize],
        _seqlen_offsets_full: &[usize],
        _no_kv_cache: bool,
        _non_granular_state: &Option<crate::xlora_models::NonGranularState>,
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _flash_params: &FlashParams,
        _flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        unimplemented!()
    }
    fn cache(&self) -> &EitherCache {
        &self.cache
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        &mut self.cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn is_xlora(&self) -> bool {
        false
    }
    fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
}

impl AnyMoeBaseModelMixin for Model {}