#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use candle_core::{Device, IndexOp, Module, Result, Tensor, D};
use candle_nn::LayerNorm;
use mistralrs_quant::{
    ColumnParallelLayer, QuantMethod, QuantizedConfig, ReplicatedLayer, RowParallelLayer,
    ShardedVarBuilder,
};
use std::{collections::HashMap, sync::Arc};

use crate::{
    amoe::AnyMoeBaseModelMixin,
    attention::SdpaParams,
    device_map::DeviceMapper,
    layers::{
        self, layer_norm, Activation, CausalMasker, MatMul, PhiRopeConfig, PhiRopeScalingConfig,
        PhiRotaryEmbedding, Sdpa,
    },
    layers_masker::{masked_fill, PastKvLenCache},
    ops::NonZeroOp,
    paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
    pipeline::{
        extract_logits,
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel, KvCache, NormalCache, NormalLoadingMetadata, NormalModel,
    },
    serde_default_fn,
    utils::{progress::NiceProgressBar, unvarbuilder::UnVarBuilder},
};

serde_default_fn!(bool, word_emb_default, false);

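// Deserialized model configuration; the field names appear to mirror the
// PhiMoE-style (sparse mixture-of-experts Phi) `config.json` layout.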
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize, Default)]
pub struct Config {
    pub(crate) vocab_size: usize,
    pub(crate) hidden_act: Activation,
    pub(crate) hidden_size: usize,
    pub(crate) intermediate_size: usize,
    pub(crate) num_hidden_layers: usize,
    pub(crate) num_attention_heads: usize,
    pub(crate) num_key_value_heads: usize,
    pub(crate) rms_norm_eps: f64,
    pub(crate) rope_theta: f64,
    pub(crate) rope_scaling: Option<PhiRopeScalingConfig>,
    pub(crate) max_position_embeddings: usize,
    pub(crate) use_flash_attn: bool,
    pub(crate) sliding_window: Option<usize>,
    pub(crate) original_max_position_embeddings: usize,
    pub(crate) quantization_config: Option<QuantizedConfig>,
    pub(crate) lm_head_bias: bool,
    pub(crate) attention_bias: bool,
    pub(crate) num_local_experts: usize,
    pub(crate) router_jitter_noise: f64,
    #[serde(default = "word_emb_default")]
    pub(crate) tie_word_embeddings: bool,
}

impl From<Config> for PhiRopeConfig {
    fn from(val: Config) -> Self {
        PhiRopeConfig {
            rope_scaling: val.rope_scaling,
            max_position_embeddings: val.max_position_embeddings,
            original_max_position_embeddings: val.original_max_position_embeddings,
            rope_theta: val.rope_theta,
            head_dim: val.hidden_size / val.num_attention_heads,
            partial_rotary_factor: None,
        }
    }
}

impl Config {
    pub fn head_dim(&self) -> usize {
        self.hidden_size / self.num_attention_heads
    }
}

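// Multi-head attention with tensor-parallel Q/K/V/O projections. The K/V
// projections are sharded per rank via `compute_kv_shard`, and attention runs
// through either PagedAttention or the standard SDPA path with a KV cache.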
struct Attention {
    q_proj: Arc<dyn QuantMethod>,
    k_proj: Arc<dyn QuantMethod>,
    v_proj: Arc<dyn QuantMethod>,
    o_proj: Arc<dyn QuantMethod>,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    rotary_emb: Arc<PhiRotaryEmbedding>,
    paged_attn: Option<PagedAttention>,
    sdpa_params: SdpaParams,
}

impl Attention {
    fn new(
        rotary_emb: Arc<PhiRotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        paged_attn: Option<PagedAttention>,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        let head_dim = cfg.head_dim();

        let q_proj = ColumnParallelLayer::new(
            cfg.hidden_size,
            num_heads * head_dim,
            &cfg.quantization_config,
            cfg.attention_bias,
            comm,
            vb.pp("q_proj"),
        )?;
        let kv_shard = mistralrs_quant::compute_kv_shard(
            cfg.num_key_value_heads,
            cfg.hidden_size / cfg.num_attention_heads,
            comm,
        );
        let k_proj = ColumnParallelLayer::new_with_shard(
            cfg.hidden_size,
            num_kv_heads * head_dim,
            &cfg.quantization_config,
            cfg.attention_bias,
            comm,
            kv_shard,
            vb.pp("k_proj"),
        )?;
        let v_proj = ColumnParallelLayer::new_with_shard(
            cfg.hidden_size,
            num_kv_heads * head_dim,
            &cfg.quantization_config,
            cfg.attention_bias,
            comm,
            kv_shard,
            vb.pp("v_proj"),
        )?;
        let o_proj = RowParallelLayer::new(
            num_heads * head_dim,
            cfg.hidden_size,
            &cfg.quantization_config,
            cfg.attention_bias,
            comm,
            vb.pp("o_proj"),
        )?;

        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            rotary_emb,
            num_heads: num_heads / comm.world_size(),
            num_kv_heads: (num_kv_heads / comm.world_size()).max(1),
            head_dim,
            paged_attn,
            sdpa_params: SdpaParams {
                n_kv_groups: mistralrs_quant::compute_n_kv_groups(
                    cfg.num_key_value_heads,
                    cfg.num_attention_heads,
                    comm,
                ),
                use_flash_attn: cfg.use_flash_attn,
                softcap: None,
                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                sliding_window: cfg.sliding_window,
            },
        })
    }

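    // Projects to Q/K/V, applies rotary embeddings, then attends either through
    // PagedAttention (when KV-cache metadata is supplied) or the SDPA fallback
    // with the in-place KV cache, and finally applies the output projection.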
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.q_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut q = MatMul.qmethod_matmul(&xs, &*self.q_proj)?;
        let mut k = MatMul.qmethod_matmul(&xs, &*self.k_proj)?;
        let mut v = MatMul.qmethod_matmul(&xs, &*self.v_proj)?;
        if self.q_proj.quantized_act_type().is_some() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

        let (q, k, v) = if q_len != 1 {
            let q = q
                .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, k) = self
            .rotary_emb
            .forward(&q, &k, seqlen_offsets, position_ids)?;

        let mut attn_output = match &self.paged_attn {
            Some(paged_attn) => match metadata {
                Some(((key_cache, value_cache), input_metadata)) => paged_attn.forward(
                    &q,
                    &k,
                    &v,
                    attention_mask,
                    Some(key_cache),
                    Some(value_cache),
                    input_metadata,
                    &self.sdpa_params,
                    Some(flash_params),
                )?,
                None => {
                    let input_metadata = PagedAttentionInputMetadata::dummy(q.device())?;
                    assert!(attention_mask.is_some());
                    paged_attn.forward(
                        &q,
                        &k,
                        &v,
                        attention_mask,
                        None,
                        None,
                        &input_metadata,
                        &self.sdpa_params,
                        Some(flash_params),
                    )?
                }
            },
            _ => {
                let (k, v) = kv_cache.append(&k, &v)?;

                Sdpa.run_attention(
                    &q,
                    &k,
                    &v,
                    attention_mask,
                    Some(flash_params),
                    &self.sdpa_params,
                )?
            }
        };

        if let Some(t) = self.q_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        attn_output = if attention_mask.is_some() {
            attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?
        } else {
            attn_output.reshape((b_sz, q_len, ()))?
        };
        let mut res = MatMul.qmethod_matmul(&attn_output, &*self.o_proj)?;
        if self.q_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

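// Gated feed-forward expert: w2(act(w1(x)) * w3(x)), with w1 as the gate
// projection, w3 as the up projection, and w2 as the row-parallel down
// projection.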
#[derive(Clone)]
struct Mlp {
    w1: Arc<dyn QuantMethod>,
    w2: Arc<dyn QuantMethod>,
    w3: Arc<dyn QuantMethod>,
    act_fn: Activation,
}

impl Mlp {
    fn new(cfg: &Config, vb: ShardedVarBuilder, comm: &Arc<mistralrs_quant::Comm>) -> Result<Self> {
        let hidden_size = cfg.hidden_size;
        let i_size = cfg.intermediate_size;

        let w1 = ColumnParallelLayer::new(
            hidden_size,
            i_size,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("w1"),
        )?;
        let w2 = RowParallelLayer::new(
            i_size,
            hidden_size,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("w2"),
        )?;
        let w3 = ColumnParallelLayer::new(
            hidden_size,
            i_size,
            &cfg.quantization_config,
            false,
            comm,
            vb.pp("w3"),
        )?;

        Ok(Self {
            w1,
            w2,
            w3,
            act_fn: cfg.hidden_act,
        })
    }

    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.w1.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut current_hidden_states =
            MatMul.qmethod_matmul(&xs, &*self.w1)?.apply(&self.act_fn)?;
        let rhs = MatMul.qmethod_matmul(&xs, &*self.w3)?;
        current_hidden_states = current_hidden_states.broadcast_mul(&rhs)?;
        let mut res = MatMul.qmethod_matmul(&current_hidden_states, &*self.w2)?;
        if self.w1.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

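// Sparse mixture-of-experts block: a dense router (`gate`) scores every token
// and SparseMixer-style top-2 selection dispatches each token to two experts.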
struct MoeMlp {
    gate: candle_nn::Linear,
    experts: Vec<Mlp>,
    router_jitter_noise: f64,
    num_experts: usize,
}

impl MoeMlp {
    fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        layer_device: Device,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let num_experts = cfg.num_local_experts;
        let gate = layers::linear_no_bias(
            cfg.hidden_size,
            num_experts,
            vb.pp("gate").set_device(layer_device),
        )?;

        let experts_vb = vb.pp("experts");
        let mut experts = Vec::with_capacity(num_experts);
        for i in 0..num_experts {
            experts.push(Mlp::new(cfg, experts_vb.pp(i), comm)?);
        }

        Ok(Self {
            gate,
            experts,
            router_jitter_noise: cfg.router_jitter_noise,
            num_experts,
        })
    }

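    // Top-2 expert selection: pick the argmax expert, mask out logits whose
    // normalized gap to that maximum exceeds 2 * jitter_eps, softmax the
    // surviving logits to get the expert weight, then mask the selected expert
    // to -inf and repeat once for the second expert. Returns (multipliers,
    // selected expert indices), each with a trailing dimension of 2.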
    fn sparsemixer(&self, scores: &Tensor, jitter_eps: f64) -> Result<(Tensor, Tensor)> {
        let selected_experts = scores.argmax_keepdim(D::Minus1)?;

        let mask_logits_threshold = scores.gather(&selected_experts, D::Minus1)?;
        let factor = scores.abs()?.broadcast_minimum(&mask_logits_threshold)?;
        let mask_logits_threshold = mask_logits_threshold
            .broadcast_sub(scores)?
            .broadcast_div(&factor)?
            .gt(2. * jitter_eps)?;

        let masked_gates = masked_fill(scores, &mask_logits_threshold, f64::NEG_INFINITY)?;

        let masked_gates = candle_nn::ops::softmax_last_dim(&masked_gates)?;
        let multiplier = masked_gates.gather(&selected_experts, D::Minus1)?;

        let masked_scores = scores.scatter_add(
            &selected_experts
                .broadcast_as(scores.shape())?
                .contiguous()?,
            &(scores.ones_like()? * f64::NEG_INFINITY)?,
            D::Minus1,
        )?;

        let selected_experts_top2 = masked_scores.argmax_keepdim(D::Minus1)?;

        let mask_logits_threshold = masked_scores.gather(&selected_experts_top2, D::Minus1)?;
        let factor = scores.abs()?.broadcast_minimum(&mask_logits_threshold)?;
        let mask_logits_threshold = mask_logits_threshold
            .broadcast_sub(scores)?
            .broadcast_div(&factor)?
            .gt(2. * jitter_eps)?;

        let masked_gates_top2 =
            masked_fill(&masked_scores, &mask_logits_threshold, f64::NEG_INFINITY)?;
        let masked_gates_top2 = candle_nn::ops::softmax_last_dim(&masked_gates_top2)?;
        let multiplier_top2 = masked_gates_top2.gather(&selected_experts_top2, D::Minus1)?;

        let multiplier = Tensor::cat(&[multiplier, multiplier_top2], D::Minus1)?;
        let selected_experts = Tensor::cat(&[selected_experts, selected_experts_top2], D::Minus1)?;

        Ok((multiplier, selected_experts))
    }

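    // Routes tokens to experts. Router logits are computed on the gate's
    // device, but the routing bookkeeping (sparsemixer, one-hot masks, index
    // math) runs on the CPU; each selected token slice is moved back to the
    // expert's device for the MLP, and the weighted expert outputs are
    // accumulated per token before reshaping back to (bs, seq, hidden).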
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let (bs, seq, hidden) = xs.dims3()?;
        let xs = xs.reshape(((), hidden))?;
        let xs_dev = xs.device();
        let xs = xs.to_device(&Device::Cpu)?;

        let router_logits = self
            .gate
            .forward(&xs.to_device(xs_dev)?)?
            .to_device(&Device::Cpu)?;
        let (routing_weights, selected_experts) = self.sparsemixer(
            &router_logits.to_device(&Device::Cpu)?,
            self.router_jitter_noise,
        )?;

        let mut final_hidden_states = Tensor::zeros((bs * seq, hidden), xs.dtype(), xs.device())?;

        // One-hot encode the selected experts; after the permute the mask is
        // shaped (num_experts, top_k, tokens).
        let experts_mask =
            candle_nn::encoding::one_hot(selected_experts, self.num_experts, 1u8, 0u8)?
                .permute((2, 1, 0))?;

        for expert_idx in 0..self.num_experts {
            let expert = &self.experts[expert_idx];
            let expert_mask = experts_mask.i(expert_idx)?;
            assert_eq!(expert_mask.rank(), 2);
            let nonzero_mask = expert_mask.contiguous()?.nonzero()?;
            let idx = nonzero_mask.i((.., 0))?;
            let top_x = nonzero_mask.i((.., 1))?;

            if top_x.dim(0)? == 0 {
                continue;
            }

            let current_state = xs.index_select(&top_x, 0)?.reshape(((), hidden))?;
            let current_routing_weights = routing_weights
                .index_select(&top_x, 0)?
                .gather(&idx.unsqueeze(1)?.contiguous()?, 1)?;
            let exp_out = expert
                .forward(&current_state.to_device(xs_dev)?)?
                .to_device(&Device::Cpu)?;

            let current_hidden_states = exp_out.broadcast_mul(&current_routing_weights)?;

            final_hidden_states = final_hidden_states.index_add(
                &top_x.contiguous()?,
                &current_hidden_states.to_dtype(xs.dtype())?,
                0,
            )?;
        }

        final_hidden_states
            .reshape((bs, seq, hidden))?
            .to_device(xs_dev)
    }
}

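// A single transformer decoder layer: pre-norm attention followed by the
// pre-norm sparse-MoE feed-forward block, each with a residual connection.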
struct DecoderLayer {
    self_attn: Attention,
    mlp: MoeMlp,
    input_layernorm: LayerNorm,
    post_attention_layernorm: LayerNorm,
}

impl DecoderLayer {
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<PhiRotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        paged_attn: Option<PagedAttention>,
        real_device: Device,
        comm: &Arc<mistralrs_quant::Comm>,
    ) -> Result<Self> {
        let self_attn = Attention::new(
            rotary_emb,
            cfg,
            mapper.set_device(layer_idx, vb.pp("self_attn"), loading_isq),
            paged_attn,
            comm,
        )?;
        let mlp = MoeMlp::new(
            cfg,
            mapper.set_device(layer_idx, vb.pp("block_sparse_moe"), loading_isq),
            mapper
                .device_for(layer_idx, false)
                .cloned()
                .unwrap_or(real_device),
            comm,
        )?;
        let input_layernorm = layer_norm(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let post_attention_layernorm = layer_norm(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        Ok(Self {
            self_attn,
            mlp,
            input_layernorm,
            post_attention_layernorm,
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
        kv_cache: &mut KvCache,
        metadata: Option<((Tensor, Tensor), &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = self.input_layernorm.forward(xs)?;
        let xs = self.self_attn.forward(
            &xs,
            attention_mask,
            seqlen_offsets,
            position_ids,
            kv_cache,
            metadata,
            flash_params,
        )?;
        let xs = (xs + residual)?;
        let residual = &xs;
        let xs = self
            .mlp
            .forward(&xs.apply(&self.post_attention_layernorm)?)?;
        residual + xs
    }
}

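// The full decoder-only model: token embedding, a stack of MoE decoder layers
// (device-mapped per layer), a final norm, and the LM head.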
pub struct Model {
    embed_tokens: candle_nn::Embedding,
    layers: Vec<DecoderLayer>,
    norm: LayerNorm,
    lm_head: Arc<dyn QuantMethod>,
    device: Device,
    cache: EitherCache,
    max_seq_len: usize,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    sliding_window: Option<usize>,
    cfg: ModelConfigMetadata,
}

impl Model {
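    // Loads all weights through the sharded `VarBuilder`, honoring the device
    // mapper for per-layer placement, ISQ, and per-layer tensor-parallel
    // communicators. One rotary embedding is cached per device location so
    // layers mapped to the same device can share it.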
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        _is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Self> {
        if let Some(ref quant_cfg) = &cfg.quantization_config {
            tracing::info!(
                "Using {} quantization: {}.",
                quant_cfg.name(),
                quant_cfg.get_bits_name(&vb)
            );
        }
        let mapper = normal_loading_metadata.mapper;
        let vb_m = vb.pp("model");

        let embed_tokens = layers::embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
            &cfg.quantization_config,
        )?;
        let mut ropes = HashMap::new();
        for layer_idx in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(PhiRotaryEmbedding::new(vb.dtype(), cfg.clone(), device)?),
            );
        }
        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        let vb_l = vb_m.pp("layers");
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        ) {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            let paged_attn = match &attention_mechanism {
                AttentionImplementation::Eager => None,
                AttentionImplementation::PagedAttention => {
                    Some(PagedAttention::new(cfg.head_dim(), device, None)?)
                }
            };
            let comm = mapper.get_comm_for(layer_idx)?;
            let layer = DecoderLayer::new(
                rotary_emb.clone(),
                cfg,
                vb_l.pp(layer_idx),
                &*mapper,
                layer_idx,
                normal_loading_metadata.loading_isq,
                paged_attn,
                normal_loading_metadata.real_device.clone(),
                &comm,
            )?;
            layers.push(layer)
        }
        let norm = layer_norm(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb_m.pp("norm"), false),
        )?;
        let lm_head = if !cfg.tie_word_embeddings {
            ReplicatedLayer::new(
                cfg.hidden_size,
                cfg.vocab_size,
                &None,
                cfg.lm_head_bias,
                mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
            )?
        } else {
            unreachable!()
        };
        Ok(Self {
            embed_tokens,
            layers,
            norm,
            lm_head,
            device: normal_loading_metadata.real_device,
            cache: EitherCache::Normal(NormalCache::new_sliding(
                cfg.num_hidden_layers,
                cfg.max_position_embeddings,
                cfg.sliding_window,
            )),
            max_seq_len: cfg.max_position_embeddings,
            sliding_window: cfg.sliding_window,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_attn_heads: cfg.num_attention_heads / mapper.get_comm_for(0)?.world_size(),
                num_kv_heads: (cfg.num_key_value_heads / mapper.get_comm_for(0)?.world_size())
                    .max(1),
                sliding_window: cfg.sliding_window,
                k_head_dim: cfg.head_dim(),
                v_head_dim: cfg.head_dim(),
            },
            mapper,
        })
    }

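    // End-to-end forward pass: embed tokens, build the (sliding-window) causal
    // mask, run each decoder layer on its mapped device, apply the final norm,
    // and compute logits only for the requested context positions.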
    pub fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
        context_lens: Vec<(usize, usize)>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut xs = self.embed_tokens.forward(input_ids)?;
        let cache = &mut self.cache.normal().0;
        let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
            input_ids,
            metadata
                .as_ref()
                .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
                .unwrap_or(cache as &dyn PastKvLenCache),
            self.sliding_window,
            xs.dtype(),
            self.cfg.num_attn_heads,
        )?;
        let attention_mask = attention_mask.filter(|_| {
            metadata
                .as_ref()
                .map(|(_, meta)| meta.is_first_prompt_chunk)
                .unwrap_or(true)
        });

        for (i, layer) in self.layers.iter().enumerate() {
            xs = self.mapper.map(xs, i)?;
            xs = layer.forward(
                &xs,
                attention_mask
                    .as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                position_ids,
                &mut cache[i],
                metadata
                    .as_ref()
                    .map(|(kv_cache, metadata)| (kv_cache[i].clone(), *metadata)),
                flash_params,
            )?
        }
        let xs = xs.to_device(&self.device)?;
        let mut xs = xs.apply(&self.norm)?;
        if let Some(t) = self.lm_head.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        extract_logits(&MatMul.qmethod_matmul(&xs, &*self.lm_head)?, context_lens)
    }
}

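// In-situ quantization (ISQ) hooks: `get_layers` exposes the quantizable
// linear layers (attention projections, expert MLPs, LM head), while the
// `residual_tensors` methods expose the embedding and norm weights that are
// left unquantized.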
impl IsqModel for Model {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((&mut self.lm_head, None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            tensors.push((&mut layer.self_attn.q_proj, Some(i)));
            tensors.push((&mut layer.self_attn.k_proj, Some(i)));
            tensors.push((&mut layer.self_attn.v_proj, Some(i)));
            tensors.push((&mut layer.self_attn.o_proj, Some(i)));
            for expert in &mut layer.mlp.experts {
                tensors.push((&mut expert.w1, Some(i)));
                tensors.push((&mut expert.w2, Some(i)));
                tensors.push((&mut expert.w3, Some(i)));
            }
        }
        (tensors, &*self.mapper)
    }
    fn get_layers_moe_experts_only(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((&mut self.lm_head, None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            for expert in &mut layer.mlp.experts {
                tensors.push((&mut expert.w1, Some(i)));
                tensors.push((&mut expert.w2, Some(i)));
                tensors.push((&mut expert.w3, Some(i)));
            }
        }
        (tensors, &*self.mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        let uvb = UnVarBuilder::new();

        let uvb_m = uvb.pp("model");
        uvb_m.pp("embed_tokens").add(&self.embed_tokens);
        uvb_m.pp("norm").add(&self.norm);

        for (layer_idx, layer) in self.layers.iter().enumerate() {
            let uvb_l = uvb_m.pp("layers").pp(layer_idx);
            uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
            uvb_l
                .pp("post_attention_layernorm")
                .add(&layer.post_attention_layernorm);
        }

        uvb.to_safetensors()
    }

    fn residual_tensors_moe_experts_only(&self) -> Option<Vec<(String, Tensor)>> {
        let uvb = UnVarBuilder::new();

        let uvb_m = uvb.pp("model");
        uvb_m.pp("embed_tokens").add(&self.embed_tokens);
        uvb_m.pp("norm").add(&self.norm);

        for (layer_idx, layer) in self.layers.iter().enumerate() {
            let uvb_l = uvb_m.pp("layers").pp(layer_idx);
            uvb_l.pp("input_layernorm").add(&layer.input_layernorm);
            uvb_l
                .pp("post_attention_layernorm")
                .add(&layer.post_attention_layernorm);

            let uvb_attn = uvb_l.pp("self_attn");
            uvb_attn.pp("q_proj").add(&layer.self_attn.q_proj);
            uvb_attn.pp("k_proj").add(&layer.self_attn.k_proj);
            uvb_attn.pp("v_proj").add(&layer.self_attn.v_proj);
            uvb_attn.pp("o_proj").add(&layer.self_attn.o_proj);
        }

        Some(uvb.to_safetensors())
    }
}

impl NormalModel for Model {
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.forward(
            input_ids,
            seqlen_offsets,
            &position_ids,
            context_lens,
            metadata,
            flash_params,
        )
    }
    fn xlora_forward(
        &self,
        _input_ids: &Tensor,
        _input_ids_full: &Tensor,
        _seqlen_offsets: &[usize],
        _seqlen_offsets_full: &[usize],
        _no_kv_cache: bool,
        _non_granular_state: &Option<crate::xlora_models::NonGranularState>,
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _flash_params: &FlashParams,
        _flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        unimplemented!()
    }
    fn cache(&self) -> &EitherCache {
        &self.cache
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        &mut self.cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn is_xlora(&self) -> bool {
        false
    }
    fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
}

impl AnyMoeBaseModelMixin for Model {}