#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use crate::{
    amoe::AnyMoeBaseModelMixin,
    attention::SdpaParams,
    layers::{self, Activation, RotaryEmbedding, Sdpa},
    lora::{linear_no_bias, LinearLayerLike, LoraConfig, Ordering},
    paged_attention::ModelConfigMetadata,
    pipeline::{
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel, NormalLoadingMetadata,
    },
    utils::progress::NiceProgressBar,
};
use candle_core::{DType, Device, Module, Result, Tensor};
use mistralrs_quant::{QuantMethod, ShardedVarBuilder};
use std::{collections::HashMap, sync::Arc};
use tqdm::Iter;
use tracing::info;

use crate::{
    device_map::DeviceMapper,
    layers::{CausalMasker, RmsNorm},
    models::mixtral::Config,
    pipeline::{extract_logits, Cache, NormalModel},
};

use super::{classifier::XLoraClassifier, NonGranularState, ScalingsMaker, XLoraConfig};

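/// Grouped-query self-attention with LoRA-wrapped Q/K/V/O projections
/// (`LinearLayerLike`), rotary position embeddings, and an optional
/// sliding-window KV cache.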
struct Attention {
    q_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    k_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    v_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    o_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    rotary_emb: Arc<RotaryEmbedding>,
    sliding_window: Option<usize>,
    sdpa_params: SdpaParams,
}

impl Attention {
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<RotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        let head_dim = hidden_sz / num_heads;
        let q_proj = linear_no_bias(
            hidden_sz,
            num_heads * head_dim,
            mapper.set_device(layer_idx, vb.pp("q_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("q_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let k_proj = linear_no_bias(
            hidden_sz,
            num_kv_heads * head_dim,
            mapper.set_device(layer_idx, vb.pp("k_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("k_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let v_proj = linear_no_bias(
            hidden_sz,
            num_kv_heads * head_dim,
            mapper.set_device(layer_idx, vb.pp("v_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("v_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let o_proj = linear_no_bias(
            num_heads * head_dim,
            hidden_sz,
            mapper.set_device(layer_idx, vb.pp("o_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("o_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            num_heads,
            num_kv_heads,
            head_dim,
            rotary_emb,
            sliding_window: cfg.sliding_window,
            sdpa_params: SdpaParams {
                n_kv_groups: num_heads / num_kv_heads,
                softcap: None,
                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                sliding_window: cfg.sliding_window,
            },
        })
    }

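    // Projects `xs` into Q/K/V (casting to the quantized activation dtype when
    // one is configured), applies RoPE, updates the sliding-window KV cache,
    // runs SDPA, and projects back through `o_proj`. X-LoRA scalings are
    // threaded through every `lora_forward` call.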
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut Option<(Tensor, Tensor)>,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.q_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut q = self.q_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let mut k = self.k_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let mut v = self.v_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.q_proj.quantized_act_type().is_some() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

        let (q, k, v) = if q_len != 1 {
            let q = q
                .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, k) = self.rotary_emb.forward(&q, &k, seqlen_offsets)?;

        let (k, v, attn_mask) = Cache::update_kv_cache_sliding_window(
            kv_cache,
            k,
            v,
            attention_mask,
            self.sliding_window,
            false,
        )?;

        let mut attn_output = Sdpa.run_attention(
            &q,
            &k,
            &v,
            attn_mask.as_ref(),
            Some(flash_params),
            &self.sdpa_params,
        )?;

        if let Some(t) = self.q_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        let mut res = self.o_proj.lora_forward(
            &attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.q_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

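/// A single Mixtral expert: a SwiGLU-style MLP computing `w2(act(w1(x)) * w3(x))`,
/// where each projection is a LoRA-wrapped linear layer.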
#[derive(Clone)]
struct BlockSparseTop2MLP {
    w1: Arc<dyn LinearLayerLike + Send + Sync>,
    w2: Arc<dyn LinearLayerLike + Send + Sync>,
    w3: Arc<dyn LinearLayerLike + Send + Sync>,
    act_fn: Activation,
}

impl BlockSparseTop2MLP {
    #[allow(clippy::too_many_arguments)]
    fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let intermediate_sz = cfg.intermediate_size;
        let w1 = linear_no_bias(
            hidden_sz,
            intermediate_sz,
            mapper.set_device(layer_idx, vb.pp("w1"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("w1"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let w2 = linear_no_bias(
            intermediate_sz,
            hidden_sz,
            mapper.set_device(layer_idx, vb.pp("w2"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("w2"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let w3 = linear_no_bias(
            hidden_sz,
            intermediate_sz,
            mapper.set_device(layer_idx, vb.pp("w3"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("w3"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        Ok(Self {
            w1,
            w2,
            w3,
            act_fn: cfg.hidden_act,
        })
    }

    fn forward(
        &self,
        xs: &Tensor,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
    ) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.w1.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let lhs = self
            .w1
            .lora_forward(
                &xs,
                scalings.clone(),
                global_scaling_weight,
                is_scaling_pass,
            )?
            .apply(&self.act_fn)?;
        let rhs = self.w3.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let mut res = self.w2.lora_forward(
            &(lhs * rhs)?,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.w1.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

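/// The Mixtral sparse MoE block: a router (`gate`) scores every expert for each
/// token, and each token is processed by its top `num_experts_per_tok` experts.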
#[derive(Clone)]
struct SparseMoeBlock {
    gate: Arc<dyn LinearLayerLike + Send + Sync>,
    experts: Vec<BlockSparseTop2MLP>,
    num_experts_per_tok: usize,
}

impl SparseMoeBlock {
    #[allow(clippy::too_many_arguments)]
    fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let gate = linear_no_bias(
            cfg.hidden_size,
            cfg.num_local_experts,
            mapper.set_device(layer_idx, vb.pp("gate"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("gate"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let mut experts = Vec::with_capacity(cfg.num_local_experts);
        let vb = vb.pp("experts");
        for idx in 0..cfg.num_local_experts {
            let expert = BlockSparseTop2MLP::new(
                cfg,
                vb.pp(idx),
                lora_config,
                count,
                ord,
                mapper,
                layer_idx,
                loading_isq,
                preload_adapters,
            )?;
            experts.push(expert)
        }
        Ok(SparseMoeBlock {
            gate,
            experts,
            num_experts_per_tok: cfg.num_experts_per_tok,
        })
    }

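    // Token routing: softmax over the gate logits, keep the top
    // `num_experts_per_tok` experts per token, renormalize their weights to sum
    // to 1, then gather each expert's tokens with `index_select`, run the
    // expert, and scatter the weighted outputs back with `index_add`.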
    fn forward(
        &self,
        xs: &Tensor,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
    ) -> Result<Tensor> {
        let (b_size, seq_len, hidden_dim) = xs.dims3()?;
        let xs = xs.reshape(((), hidden_dim))?;

        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.gate.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut router_logits = self.gate.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.gate.quantized_act_type().is_some() {
            router_logits = router_logits.to_dtype(original_dtype)?;
        }

        let routing_weights = candle_nn::ops::softmax_last_dim(&router_logits)?;

        let routing_weights = routing_weights.to_dtype(DType::F32)?.to_vec2::<f32>()?;

        let mut top_x = vec![vec![]; self.experts.len()];
        let mut selected_rws = vec![vec![]; self.experts.len()];
        for (row_idx, rw) in routing_weights.iter().enumerate() {
            let mut dst = (0..rw.len() as u32).collect::<Vec<u32>>();
            dst.sort_by(|&i, &j| rw[j as usize].total_cmp(&rw[i as usize]));
            let mut sum_routing_weights = 0f32;
            for &expert_idx in dst.iter().take(self.num_experts_per_tok) {
                let expert_idx = expert_idx as usize;
                let routing_weight = rw[expert_idx];
                sum_routing_weights += routing_weight;
                top_x[expert_idx].push(row_idx as u32);
            }
            for &expert_idx in dst.iter().take(self.num_experts_per_tok) {
                let expert_idx = expert_idx as usize;
                let routing_weight = rw[expert_idx];
                selected_rws[expert_idx].push(routing_weight / sum_routing_weights)
            }
        }

        let mut ys = xs.zeros_like()?;
        for (expert_idx, expert_layer) in self.experts.iter().enumerate() {
            let top_x = &top_x[expert_idx];
            if top_x.is_empty() {
                continue;
            }
            let top_x = Tensor::new(top_x.as_slice(), xs.device())?;
            let selected_rws =
                Tensor::new(selected_rws[expert_idx].as_slice(), xs.device())?.reshape(((), 1))?;
            let current_state = xs.index_select(&top_x, 0)?.reshape(((), hidden_dim))?;
            let current_hidden_states = expert_layer.forward(
                &current_state,
                scalings.clone(),
                global_scaling_weight,
                is_scaling_pass,
            )?;
            let current_hidden_states = current_hidden_states.broadcast_mul(&selected_rws)?;
            ys = ys.index_add(&top_x, &current_hidden_states, 0)?;
        }

        let ys = ys.reshape((b_size, seq_len, hidden_dim))?;
        Ok(ys)
    }
}

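/// One Mixtral transformer layer: RMSNorm -> self-attention -> residual add,
/// then RMSNorm -> sparse MoE block -> residual add.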
struct DecoderLayer {
    self_attn: Attention,
    block_sparse_moe: SparseMoeBlock,
    input_layernorm: RmsNorm,
    post_attention_layernorm: RmsNorm,
}

impl DecoderLayer {
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<RotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let self_attn = Attention::new(
            rotary_emb,
            cfg,
            vb.pp("self_attn"),
            lora_config,
            count,
            ord,
            mapper,
            layer_idx,
            loading_isq,
            preload_adapters,
        )?;
        let block_sparse_moe = SparseMoeBlock::new(
            cfg,
            vb.pp("block_sparse_moe"),
            lora_config,
            count,
            ord,
            mapper,
            layer_idx,
            loading_isq,
            preload_adapters,
        )?;
        let input_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let post_attention_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        Ok(Self {
            self_attn,
            block_sparse_moe,
            input_layernorm,
            post_attention_layernorm,
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut Option<(Tensor, Tensor)>,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = self.input_layernorm.forward(xs)?;
        let xs = self.self_attn.forward(
            &xs,
            attention_mask,
            seqlen_offsets,
            kv_cache,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
            flash_params,
        )?;
        let xs = (xs + residual)?;
        let residual = &xs;
        let xs = self
            .block_sparse_moe
            .forward(
                &xs.apply(&self.post_attention_layernorm)?,
                scalings.clone(),
                global_scaling_weight,
                is_scaling_pass,
            )?
            .to_dtype(residual.dtype())?;
        residual + xs
    }
}

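/// Mixtral model with LoRA/X-LoRA adapters. When an `XLoraClassifier` is
/// present, it produces adapter scalings that are threaded through every
/// LoRA-wrapped projection in the network.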
pub struct XLoraModel {
    embed_tokens: candle_nn::Embedding,
    layers: Vec<DecoderLayer>,
    norm: RmsNorm,
    lm_head: Arc<dyn LinearLayerLike + Send + Sync>,
    sliding_window: Option<usize>,
    device: Device,
    cache: EitherCache,
    dtype: DType,
    max_seq_len: usize,
    xlora_classifier: Option<XLoraClassifier>,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    cfg: ModelConfigMetadata,
}

impl XLoraModel {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        if let Some(ref quant_cfg) = &cfg.quantization_config {
            tracing::info!(
                "Using {} quantization: {}.",
                quant_cfg.name(),
                quant_cfg.get_bits_name(&vb)
            );
        }
        let mapper = normal_loading_metadata.mapper;
        let vb_m = vb.pp("model");

        let embed_tokens = layers::embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
            &cfg.quantization_config,
        )?;
        let head_dim = cfg.hidden_size / cfg.num_attention_heads;
        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        let vb_l = vb_m.pp("layers");
        let mut ropes = HashMap::new();
        for layer_idx in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(RotaryEmbedding::new(
                    cfg.rope_theta as f32,
                    head_dim,
                    cfg.max_position_embeddings,
                    device,
                    is_gptx,
                    vb_m.dtype(),
                )?),
            );
        }

        let mut count = 0;
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        ) {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            let layer = DecoderLayer::new(
                rotary_emb.clone(),
                cfg,
                vb_l.pp(layer_idx),
                lora_config,
                &mut count,
                &xlora_ordering,
                &*mapper,
                layer_idx,
                normal_loading_metadata.loading_isq,
                preload_adapters,
            )?;
            layers.push(layer)
        }
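        // With plain LoRA (no X-LoRA classifier and no preloaded adapters), the
        // adapter deltas are folded into the base weights once at load time,
        // presumably to avoid per-token adapter overhead during inference.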
        if xlora_config.is_none() && preload_adapters.is_none() {
            info!("Merging LoRA adapters.");
            for layer in layers.iter_mut().tqdm() {
                Arc::get_mut(&mut layer.self_attn.k_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.self_attn.o_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.self_attn.q_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.self_attn.v_proj)
                    .unwrap()
                    .merge_weights()?;

                Arc::get_mut(&mut layer.block_sparse_moe.gate)
                    .unwrap()
                    .merge_weights()?;
                for expert in layer.block_sparse_moe.experts.iter_mut() {
                    Arc::get_mut(&mut expert.w1).unwrap().merge_weights()?;
                    Arc::get_mut(&mut expert.w2).unwrap().merge_weights()?;
                    Arc::get_mut(&mut expert.w3).unwrap().merge_weights()?;
                }
            }
        }
        let norm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb_m.pp("norm"), false),
        )?;
        let lm_head = linear_no_bias(
            cfg.hidden_size,
            cfg.vocab_size,
            mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
            mapper.set_nm_device(vb.pp("lm_head"), false),
            lora_config,
            &mut count,
            &xlora_ordering,
            preload_adapters,
        )?;
        if xlora_config.is_some() && lm_head.is_lora() {
            candle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA.");
        }
        Ok(Self {
            embed_tokens,
            layers,
            norm,
            lm_head,
            sliding_window: cfg.sliding_window,
            device: normal_loading_metadata.real_device,
            dtype: vb.dtype(),
            cache: EitherCache::Full(Cache::new(cfg.num_hidden_layers, false)),
            max_seq_len: cfg.max_position_embeddings,
            xlora_classifier: xlora_config.map(|xlora_config| {
                XLoraClassifier::new(xlora_config, count, lora_config.len(), vb, false).unwrap()
            }),
            mapper,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_kv_heads: cfg.num_key_value_heads,
                num_attn_heads: cfg.num_attention_heads,
                sliding_window: cfg.sliding_window,
                k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
                v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            },
        })
    }

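    // Core decoder pass. `is_full_pass` selects the dedicated X-LoRA cache
    // (cleared when `no_kv_cache` is set) instead of the normal KV cache, which
    // appears intended to keep the scaling pass and the full pass from sharing
    // cache state.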
    #[allow(clippy::too_many_arguments)]
    fn inner_forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Option<Tensor>,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut cache = if is_full_pass {
            if no_kv_cache {
                let mut new_cache = Vec::new();
                for _ in 0..self.cache.full().xlora_lock().len() {
                    new_cache.push(None);
                }

                self.cache.full().xlora_lock().clone_from(&new_cache);
            }
            self.cache.full().xlora_lock()
        } else {
            self.cache.full().lock()
        };
        let mut xs = self.embed_tokens.forward(input_ids)?;
        let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
            input_ids,
            &*cache,
            self.sliding_window,
            xs.dtype(),
            self.cfg.num_attn_heads,
        )?;
        for (i, layer) in self.layers.iter().enumerate() {
            xs = self.mapper.map(xs, i)?;
            xs = layer.forward(
                &xs,
                attention_mask
                    .as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                &mut cache[i],
                scalings.clone(),
                self.xlora_classifier
                    .as_ref()
                    .map(|classifier| classifier.get_global_scaling_weight())
                    .unwrap_or(1.0),
                is_scaling_pass,
                flash_params,
            )?
        }
        let xs = xs.to_device(&self.device)?;
        xs.apply(&self.norm)
    }

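    // With an X-LoRA classifier, inference is two passes: `get_scalings` runs a
    // scaling pass to produce the adapter scalings, then the full forward pass
    // reuses them. Without a classifier this falls back to a single plain LoRA
    // forward pass.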
    #[allow(clippy::too_many_arguments)]
    pub fn forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        if self.xlora_classifier.is_some() {
            let scalings = self.get_scalings(
                input_ids,
                input_ids_full,
                seqlen_offsets,
                seqlen_offsets_full,
                no_kv_cache,
                non_granular_state,
                &vec![usize::MAX; context_lens.len()],
                flash_params,
                flash_params_full,
            )?;

            if no_kv_cache {
                let mut res = self
                    .inner_forward(
                        input_ids_full,
                        seqlen_offsets_full,
                        Some(scalings),
                        true,
                        no_kv_cache,
                        None,
                        flash_params_full,
                    )?
                    .contiguous()?;
                if let Some(t) = self.lm_head.quantized_act_type() {
                    res = res.to_dtype(t)?;
                }
                extract_logits(
                    &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                    context_lens,
                )
            } else {
                let mut res = self
                    .inner_forward(
                        input_ids,
                        seqlen_offsets,
                        Some(scalings),
                        true,
                        no_kv_cache,
                        None,
                        flash_params,
                    )?
                    .contiguous()?;
                if let Some(t) = self.lm_head.quantized_act_type() {
                    res = res.to_dtype(t)?;
                }
                extract_logits(
                    &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                    context_lens,
                )
            }
        } else {
            let mut res = self
                .inner_forward(
                    input_ids,
                    seqlen_offsets,
                    None,
                    false,
                    no_kv_cache,
                    None,
                    flash_params,
                )?
                .contiguous()?;
            if let Some(t) = self.lm_head.quantized_act_type() {
                res = res.to_dtype(t)?;
            }
            extract_logits(
                &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                context_lens,
            )
        }
    }
}

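// In-situ quantization support: expose every quantizable linear (lm_head,
// attention projections, MoE gate and expert weights) together with its layer
// index so they can be requantized in place.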
impl IsqModel for XLoraModel {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((Arc::get_mut(&mut self.lm_head).unwrap().quant_inner(), None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.q_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.k_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.v_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.o_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.block_sparse_moe.gate)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            for expert in &mut layer.block_sparse_moe.experts {
                tensors.push((Arc::get_mut(&mut expert.w1).unwrap().quant_inner(), Some(i)));
                tensors.push((Arc::get_mut(&mut expert.w2).unwrap().quant_inner(), Some(i)));
                tensors.push((Arc::get_mut(&mut expert.w3).unwrap().quant_inner(), Some(i)));
            }
        }
        (tensors, &*self.mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        panic!("Cannot generate UQFF for an adapter model.")
    }
}

impl NormalModel for XLoraModel {
    fn forward(
        &self,
        _input_ids: &Tensor,
        _seqlen_offsets: &[usize],
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        _flash_params: &FlashParams,
    ) -> Result<Tensor> {
        unreachable!()
    }
    fn xlora_forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<crate::xlora_models::NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        self.forward(
            input_ids,
            input_ids_full,
            seqlen_offsets,
            seqlen_offsets_full,
            no_kv_cache,
            non_granular_state,
            context_lens,
            flash_params,
            flash_params_full,
        )
    }
    fn cache(&self) -> &EitherCache {
        &self.cache
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        &mut self.cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn is_xlora(&self) -> bool {
        true
    }
    fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
}

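// ScalingsMaker is how the shared X-LoRA machinery drives the scaling pass for
// this model; its forward simply delegates to `inner_forward`.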
impl ScalingsMaker for XLoraModel {
    fn dtype(&self) -> DType {
        self.dtype
    }
    fn get_cache(&self) -> &EitherCache {
        &self.cache
    }
    fn get_classifier(&self) -> &XLoraClassifier {
        self.xlora_classifier.as_ref().unwrap()
    }
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Tensor,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        _context_lens: &[usize],
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.inner_forward(
            input_ids,
            seqlen_offsets,
            Some(scalings),
            is_full_pass,
            no_kv_cache,
            is_scaling_pass,
            flash_params,
        )
    }
}

impl AnyMoeBaseModelMixin for XLoraModel {}