#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use crate::{
    amoe::AnyMoeBaseModelMixin,
    attention::SdpaParams,
    layers::{self, Activation, RotaryEmbedding, Sdpa},
    lora::{linear_no_bias, LinearLayerLike, LoraConfig, Ordering},
    paged_attention::ModelConfigMetadata,
    pipeline::{
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel, NormalLoadingMetadata,
    },
    utils::progress::NiceProgressBar,
};
use candle_core::{DType, Device, Module, Result, Tensor};
use mistralrs_quant::{QuantMethod, ShardedVarBuilder};
use std::{collections::HashMap, sync::Arc};
use tqdm::Iter;
use tracing::info;

use crate::{
    device_map::DeviceMapper,
    layers::{CausalMasker, RmsNorm},
    models::mixtral::Config,
    pipeline::{extract_logits, Cache, NormalModel},
};

use super::{classifier::XLoraClassifier, NonGranularState, ScalingsMaker, XLoraConfig};

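/// Multi-head self-attention with LoRA-aware Q/K/V/O projections. Key/value heads may be
/// fewer than query heads (grouped-query attention), and an optional sliding window bounds
/// how far back each query can attend.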
struct Attention {
    q_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    k_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    v_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    o_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    rotary_emb: Arc<RotaryEmbedding>,
    sliding_window: Option<usize>,
    sdpa_params: SdpaParams,
}

impl Attention {
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<RotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        let head_dim = hidden_sz / num_heads;
        let q_proj = linear_no_bias(
            hidden_sz,
            num_heads * head_dim,
            mapper.set_device(layer_idx, vb.pp("q_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("q_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let k_proj = linear_no_bias(
            hidden_sz,
            num_kv_heads * head_dim,
            mapper.set_device(layer_idx, vb.pp("k_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("k_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let v_proj = linear_no_bias(
            hidden_sz,
            num_kv_heads * head_dim,
            mapper.set_device(layer_idx, vb.pp("v_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("v_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let o_proj = linear_no_bias(
            num_heads * head_dim,
            hidden_sz,
            mapper.set_device(layer_idx, vb.pp("o_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("o_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            num_heads,
            num_kv_heads,
            head_dim,
            rotary_emb,
            sliding_window: cfg.sliding_window,
            sdpa_params: SdpaParams {
                n_kv_groups: num_heads / num_kv_heads,
                use_flash_attn: cfg.use_flash_attn,
                softcap: None,
                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                sliding_window: cfg.sliding_window,
            },
        })
    }

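    /// Project the hidden states through the LoRA-aware Q/K/V linears, apply rotary
    /// embeddings, update the KV cache, run scaled dot-product attention, and project the
    /// result back through `o_proj`. Activations are temporarily cast to the quantized
    /// activation dtype when the projections are quantized.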
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut Option<(Tensor, Tensor)>,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.q_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut q = self.q_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let mut k = self.k_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let mut v = self.v_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.q_proj.quantized_act_type().is_some() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

        let (q, k, v) = if q_len != 1 {
            let q = q
                .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, k) = self.rotary_emb.forward(&q, &k, seqlen_offsets)?;

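        // Append the new K/V to this layer's cache, respecting the optional sliding window;
        // the helper also returns the attention mask to use for this step.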
        let (k, v, attn_mask) = Cache::update_kv_cache_sliding_window(
            kv_cache,
            k,
            v,
            attention_mask,
            self.sliding_window,
            false,
        )?;

        let mut attn_output = Sdpa.run_attention(
            &q,
            &k,
            &v,
            attn_mask.as_ref(),
            Some(flash_params),
            &self.sdpa_params,
        )?;

        if let Some(t) = self.q_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        let mut res = self.o_proj.lora_forward(
            &attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.q_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

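/// A single Mixtral expert: a gated feed-forward block computing `w2(act_fn(w1(x)) * w3(x))`,
/// with each projection wrapped as a LoRA-aware linear layer.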
#[derive(Clone)]
struct BlockSparseTop2MLP {
    w1: Arc<dyn LinearLayerLike + Send + Sync>,
    w2: Arc<dyn LinearLayerLike + Send + Sync>,
    w3: Arc<dyn LinearLayerLike + Send + Sync>,
    act_fn: Activation,
}

impl BlockSparseTop2MLP {
    #[allow(clippy::too_many_arguments)]
    fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let intermediate_sz = cfg.intermediate_size;
        let w1 = linear_no_bias(
            hidden_sz,
            intermediate_sz,
            mapper.set_device(layer_idx, vb.pp("w1"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("w1"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let w2 = linear_no_bias(
            intermediate_sz,
            hidden_sz,
            mapper.set_device(layer_idx, vb.pp("w2"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("w2"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let w3 = linear_no_bias(
            hidden_sz,
            intermediate_sz,
            mapper.set_device(layer_idx, vb.pp("w3"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("w3"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        Ok(Self {
            w1,
            w2,
            w3,
            act_fn: cfg.hidden_act,
        })
    }

    fn forward(
        &self,
        xs: &Tensor,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
    ) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.w1.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let lhs = self
            .w1
            .lora_forward(
                &xs,
                scalings.clone(),
                global_scaling_weight,
                is_scaling_pass,
            )?
            .apply(&self.act_fn)?;
        let rhs = self.w3.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let mut res = self.w2.lora_forward(
            &(lhs * rhs)?,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.w1.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

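/// Sparse mixture-of-experts block: a LoRA-aware `gate` scores every expert per token, and
/// only the top `num_experts_per_tok` experts are evaluated, weighted, and summed.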
#[derive(Clone)]
struct SparseMoeBlock {
    gate: Arc<dyn LinearLayerLike + Send + Sync>,
    experts: Vec<BlockSparseTop2MLP>,
    num_experts_per_tok: usize,
}

impl SparseMoeBlock {
    #[allow(clippy::too_many_arguments)]
    fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let gate = linear_no_bias(
            cfg.hidden_size,
            cfg.num_local_experts,
            mapper.set_device(layer_idx, vb.pp("gate"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("gate"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let mut experts = Vec::with_capacity(cfg.num_local_experts);
        let vb = vb.pp("experts");
        for idx in 0..cfg.num_local_experts {
            let expert = BlockSparseTop2MLP::new(
                cfg,
                vb.pp(idx),
                lora_config,
                count,
                ord,
                mapper,
                layer_idx,
                loading_isq,
                preload_adapters,
            )?;
            experts.push(expert)
        }
        Ok(SparseMoeBlock {
            gate,
            experts,
            num_experts_per_tok: cfg.num_experts_per_tok,
        })
    }

    fn forward(
        &self,
        xs: &Tensor,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
    ) -> Result<Tensor> {
        let (b_size, seq_len, hidden_dim) = xs.dims3()?;
        let xs = xs.reshape(((), hidden_dim))?;

        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.gate.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut router_logits = self.gate.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.gate.quantized_act_type().is_some() {
            router_logits = router_logits.to_dtype(original_dtype)?;
        }

        let routing_weights = candle_nn::ops::softmax_last_dim(&router_logits)?;

        let routing_weights = routing_weights.to_dtype(DType::F32)?.to_vec2::<f32>()?;

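        // For each token, sort the experts by routing weight (descending) and keep the top
        // `num_experts_per_tok`. `top_x[e]` collects the token (row) indices routed to expert
        // `e`, and `selected_rws[e]` the matching weights, renormalized to sum to 1 over the
        // selected experts.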
        let mut top_x = vec![vec![]; self.experts.len()];
        let mut selected_rws = vec![vec![]; self.experts.len()];
        for (row_idx, rw) in routing_weights.iter().enumerate() {
            let mut dst = (0..rw.len() as u32).collect::<Vec<u32>>();
            dst.sort_by(|&i, &j| rw[j as usize].total_cmp(&rw[i as usize]));
            let mut sum_routing_weights = 0f32;
            for &expert_idx in dst.iter().take(self.num_experts_per_tok) {
                let expert_idx = expert_idx as usize;
                let routing_weight = rw[expert_idx];
                sum_routing_weights += routing_weight;
                top_x[expert_idx].push(row_idx as u32);
            }
            for &expert_idx in dst.iter().take(self.num_experts_per_tok) {
                let expert_idx = expert_idx as usize;
                let routing_weight = rw[expert_idx];
                selected_rws[expert_idx].push(routing_weight / sum_routing_weights)
            }
        }

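        // Run each expert only on the tokens routed to it: gather those rows, apply the
        // expert MLP, scale by the renormalized routing weights, and scatter-add the result
        // back into the output at the original row positions.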
        let mut ys = xs.zeros_like()?;
        for (expert_idx, expert_layer) in self.experts.iter().enumerate() {
            let top_x = &top_x[expert_idx];
            if top_x.is_empty() {
                continue;
            }
            let top_x = Tensor::new(top_x.as_slice(), xs.device())?;
            let selected_rws =
                Tensor::new(selected_rws[expert_idx].as_slice(), xs.device())?.reshape(((), 1))?;
            let current_state = xs.index_select(&top_x, 0)?.reshape(((), hidden_dim))?;
            let current_hidden_states = expert_layer.forward(
                &current_state,
                scalings.clone(),
                global_scaling_weight,
                is_scaling_pass,
            )?;
            let current_hidden_states = current_hidden_states.broadcast_mul(&selected_rws)?;
            ys = ys.index_add(&top_x, &current_hidden_states, 0)?;
        }

        let ys = ys.reshape((b_size, seq_len, hidden_dim))?;
        Ok(ys)
    }
}

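/// One pre-norm transformer block: RMS-normed self-attention and RMS-normed sparse MoE
/// feed-forward, each wrapped in a residual connection.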
struct DecoderLayer {
    self_attn: Attention,
    block_sparse_moe: SparseMoeBlock,
    input_layernorm: RmsNorm,
    post_attention_layernorm: RmsNorm,
}

impl DecoderLayer {
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<RotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let self_attn = Attention::new(
            rotary_emb,
            cfg,
            vb.pp("self_attn"),
            lora_config,
            count,
            ord,
            mapper,
            layer_idx,
            loading_isq,
            preload_adapters,
        )?;
        let block_sparse_moe = SparseMoeBlock::new(
            cfg,
            vb.pp("block_sparse_moe"),
            lora_config,
            count,
            ord,
            mapper,
            layer_idx,
            loading_isq,
            preload_adapters,
        )?;
        let input_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let post_attention_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        Ok(Self {
            self_attn,
            block_sparse_moe,
            input_layernorm,
            post_attention_layernorm,
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut Option<(Tensor, Tensor)>,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = self.input_layernorm.forward(xs)?;
        let xs = self.self_attn.forward(
            &xs,
            attention_mask,
            seqlen_offsets,
            kv_cache,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
            flash_params,
        )?;
        let xs = (xs + residual)?;
        let residual = &xs;
        let xs = self
            .block_sparse_moe
            .forward(
                &xs.apply(&self.post_attention_layernorm)?,
                scalings.clone(),
                global_scaling_weight,
                is_scaling_pass,
            )?
            .to_dtype(residual.dtype())?;
        residual + xs
    }
}

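/// X-LoRA Mixtral model: token embeddings, a stack of MoE decoder layers, a final RMS norm,
/// and the LM head, plus an optional `XLoraClassifier` that produces the adapter scalings.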
pub struct XLoraModel {
    embed_tokens: candle_nn::Embedding,
    layers: Vec<DecoderLayer>,
    norm: RmsNorm,
    lm_head: Arc<dyn LinearLayerLike + Send + Sync>,
    sliding_window: Option<usize>,
    device: Device,
    cache: EitherCache,
    dtype: DType,
    max_seq_len: usize,
    xlora_classifier: Option<XLoraClassifier>,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    cfg: ModelConfigMetadata,
}

impl XLoraModel {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        if let Some(ref quant_cfg) = &cfg.quantization_config {
            tracing::info!(
                "Using {} quantization: {}.",
                quant_cfg.name(),
                quant_cfg.get_bits_name(&vb)
            );
        }
        let mapper = normal_loading_metadata.mapper;
        let vb_m = vb.pp("model");

        let embed_tokens = layers::embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
            &cfg.quantization_config,
        )?;
        let head_dim = cfg.hidden_size / cfg.num_attention_heads;
        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        let vb_l = vb_m.pp("layers");
        let mut ropes = HashMap::new();
        for layer_idx in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(RotaryEmbedding::new(
                    cfg.rope_theta as f32,
                    head_dim,
                    cfg.max_position_embeddings,
                    device,
                    is_gptx,
                    vb_m.dtype(),
                )?),
            );
        }

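        // Build the decoder layers. `count` accumulates the number of LoRA-wrapped linear
        // layers created so far; it is later passed to the X-LoRA classifier.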
        let mut count = 0;
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        ) {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            let layer = DecoderLayer::new(
                rotary_emb.clone(),
                cfg,
                vb_l.pp(layer_idx),
                lora_config,
                &mut count,
                &xlora_ordering,
                &*mapper,
                layer_idx,
                normal_loading_metadata.loading_isq,
                preload_adapters,
            )?;
            layers.push(layer)
        }
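        // Plain LoRA (no X-LoRA classifier and no preloaded adapters): fold the adapter
        // weights into the base weights up front.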
        if xlora_config.is_none() && preload_adapters.is_none() {
            info!("Merging LoRA adapters.");
            for layer in layers.iter_mut().tqdm() {
                Arc::get_mut(&mut layer.self_attn.k_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.self_attn.o_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.self_attn.q_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.self_attn.v_proj)
                    .unwrap()
                    .merge_weights()?;

                Arc::get_mut(&mut layer.block_sparse_moe.gate)
                    .unwrap()
                    .merge_weights()?;
                for expert in layer.block_sparse_moe.experts.iter_mut() {
                    Arc::get_mut(&mut expert.w1).unwrap().merge_weights()?;
                    Arc::get_mut(&mut expert.w2).unwrap().merge_weights()?;
                    Arc::get_mut(&mut expert.w3).unwrap().merge_weights()?;
                }
            }
        }
        let norm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb_m.pp("norm"), false),
        )?;
        let lm_head = linear_no_bias(
            cfg.hidden_size,
            cfg.vocab_size,
            mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
            mapper.set_nm_device(vb.pp("lm_head"), false),
            lora_config,
            &mut count,
            &xlora_ordering,
            preload_adapters,
        )?;
        if xlora_config.is_some() && lm_head.is_lora() {
            candle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA.");
        }
        Ok(Self {
            embed_tokens,
            layers,
            norm,
            lm_head,
            sliding_window: cfg.sliding_window,
            device: normal_loading_metadata.real_device,
            dtype: vb.dtype(),
            cache: EitherCache::Full(Cache::new(cfg.num_hidden_layers, false)),
            max_seq_len: cfg.max_position_embeddings,
            xlora_classifier: xlora_config.map(|xlora_config| {
                XLoraClassifier::new(xlora_config, count, lora_config.len(), vb, false).unwrap()
            }),
            mapper,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_kv_heads: cfg.num_key_value_heads,
                num_attn_heads: cfg.num_attention_heads,
                sliding_window: cfg.sliding_window,
                k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
                v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            },
        })
    }

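    /// Run the decoder stack. `is_full_pass` selects the dedicated X-LoRA cache instead of
    /// the regular KV cache (clearing it first when `no_kv_cache` is set), and
    /// `is_scaling_pass` is threaded through to every LoRA-aware layer.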
    #[allow(clippy::too_many_arguments)]
    fn inner_forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Option<Tensor>,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut cache = if is_full_pass {
            if no_kv_cache {
                let mut new_cache = Vec::new();
                for _ in 0..self.cache.full().xlora_lock().len() {
                    new_cache.push(None);
                }

                self.cache.full().xlora_lock().clone_from(&new_cache);
            }
            self.cache.full().xlora_lock()
        } else {
            self.cache.full().lock()
        };
        let mut xs = self.embed_tokens.forward(input_ids)?;
        let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
            input_ids,
            &*cache,
            self.sliding_window,
            xs.dtype(),
            self.cfg.num_attn_heads,
        )?;
        for (i, layer) in self.layers.iter().enumerate() {
            xs = self.mapper.map(xs, i)?;
            xs = layer.forward(
                &xs,
                attention_mask
                    .as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                &mut cache[i],
                scalings.clone(),
                self.xlora_classifier
                    .as_ref()
                    .map(|classifier| classifier.get_global_scaling_weight())
                    .unwrap_or(1.0),
                is_scaling_pass,
                flash_params,
            )?
        }
        let xs = xs.to_device(&self.device)?;
        xs.apply(&self.norm)
    }

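    /// X-LoRA inference is a two-pass scheme: the classifier first produces adapter scalings
    /// via `get_scalings`, then the model runs again with those scalings applied. Without a
    /// classifier this reduces to a single ordinary LoRA forward pass.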
    #[allow(clippy::too_many_arguments)]
    pub fn forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        if self.xlora_classifier.is_some() {
            let scalings = self.get_scalings(
                input_ids,
                input_ids_full,
                seqlen_offsets,
                seqlen_offsets_full,
                no_kv_cache,
                non_granular_state,
                &vec![usize::MAX; context_lens.len()],
                flash_params,
                flash_params_full,
            )?;

            if no_kv_cache {
                let mut res = self
                    .inner_forward(
                        input_ids_full,
                        seqlen_offsets_full,
                        Some(scalings),
                        true,
                        no_kv_cache,
                        None,
                        flash_params_full,
                    )?
                    .contiguous()?;
                if let Some(t) = self.lm_head.quantized_act_type() {
                    res = res.to_dtype(t)?;
                }
                extract_logits(
                    &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                    context_lens,
                )
            } else {
                let mut res = self
                    .inner_forward(
                        input_ids,
                        seqlen_offsets,
                        Some(scalings),
                        true,
                        no_kv_cache,
                        None,
                        flash_params,
                    )?
                    .contiguous()?;
                if let Some(t) = self.lm_head.quantized_act_type() {
                    res = res.to_dtype(t)?;
                }
                extract_logits(
                    &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                    context_lens,
                )
            }
        } else {
            let mut res = self
                .inner_forward(
                    input_ids,
                    seqlen_offsets,
                    None,
                    false,
                    no_kv_cache,
                    None,
                    flash_params,
                )?
                .contiguous()?;
            if let Some(t) = self.lm_head.quantized_act_type() {
                res = res.to_dtype(t)?;
            }
            extract_logits(
                &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                context_lens,
            )
        }
    }
}

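// In-situ quantization support: expose every quantizable linear (the LM head plus, per layer,
// the attention projections, the MoE gate, and each expert's w1/w2/w3).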
impl IsqModel for XLoraModel {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((Arc::get_mut(&mut self.lm_head).unwrap().quant_inner(), None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.q_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.k_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.v_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.o_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.block_sparse_moe.gate)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            for expert in &mut layer.block_sparse_moe.experts {
                tensors.push((Arc::get_mut(&mut expert.w1).unwrap().quant_inner(), Some(i)));
                tensors.push((Arc::get_mut(&mut expert.w2).unwrap().quant_inner(), Some(i)));
                tensors.push((Arc::get_mut(&mut expert.w3).unwrap().quant_inner(), Some(i)));
            }
        }
        (tensors, &*self.mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        panic!("Cannot generate UQFF for an adapter model.")
    }
}

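// X-LoRA models are driven through `xlora_forward`; the plain `forward` entry point is never
// called for them.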
impl NormalModel for XLoraModel {
    fn forward(
        &self,
        _input_ids: &Tensor,
        _seqlen_offsets: &[usize],
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        _flash_params: &FlashParams,
    ) -> Result<Tensor> {
        unreachable!()
    }
    fn xlora_forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<crate::xlora_models::NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        self.forward(
            input_ids,
            input_ids_full,
            seqlen_offsets,
            seqlen_offsets_full,
            no_kv_cache,
            non_granular_state,
            context_lens,
            flash_params,
            flash_params_full,
        )
    }
    fn cache(&self) -> &EitherCache {
        &self.cache
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        &mut self.cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn is_xlora(&self) -> bool {
        true
    }
    fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
}

impl ScalingsMaker for XLoraModel {
    fn dtype(&self) -> DType {
        self.dtype
    }
    fn get_cache(&self) -> &EitherCache {
        &self.cache
    }
    fn get_classifier(&self) -> &XLoraClassifier {
        self.xlora_classifier.as_ref().unwrap()
    }
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Tensor,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        _context_lens: &[usize],
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.inner_forward(
            input_ids,
            seqlen_offsets,
            Some(scalings),
            is_full_pass,
            no_kv_cache,
            is_scaling_pass,
            flash_params,
        )
    }
}

impl AnyMoeBaseModelMixin for XLoraModel {}