1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use crate::{
4 amoe::AnyMoeBaseModelMixin,
5 attention::SdpaParams,
6 layers::{self, Activation, RotaryEmbedding, Sdpa},
7 lora::{linear_no_bias, LinearLayerLike, LoraConfig, Ordering},
8 paged_attention::ModelConfigMetadata,
9 pipeline::{
10 text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
11 EitherCache, IsqModel, NormalLoadingMetadata,
12 },
13 utils::progress::NiceProgressBar,
14};
15use candle_core::{DType, Device, Module, Result, Tensor};
19use mistralrs_quant::{QuantMethod, ShardedVarBuilder};
20use std::{collections::HashMap, sync::Arc};
21use tqdm::Iter;
22use tracing::info;
23
24use crate::{
25 device_map::DeviceMapper,
26 layers::{CausalMasker, RmsNorm},
27 models::mixtral::Config,
28 pipeline::{extract_logits, Cache, NormalModel},
29};
30
31use super::{classifier::XLoraClassifier, NonGranularState, ScalingsMaker, XLoraConfig};
32
/// Multi-head self-attention for one decoder layer, with every projection
/// wrapped as a LoRA-capable linear (`LinearLayerLike`).
struct Attention {
    q_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    k_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    v_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    o_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    num_heads: usize,
    // Number of key/value heads (GQA: may be fewer than `num_heads`).
    num_kv_heads: usize,
    head_dim: usize,
    // Shared per-device RoPE tables (see `XLoraModel::new`).
    rotary_emb: Arc<RotaryEmbedding>,
    // Sliding-window attention width, if the config sets one.
    sliding_window: Option<usize>,
    sdpa_params: SdpaParams,
}
45
impl Attention {
    /// Loads the q/k/v/o projections for decoder layer `layer_idx`.
    ///
    /// Each projection hands two `VarBuilder`s to `linear_no_bias`: one with
    /// `loading_isq` applied and one without — presumably the ISQ-target vs.
    /// plain device placement; confirm against `lora::linear_no_bias`.
    /// `count` is a running adapter-layer counter threaded through the whole
    /// model load.
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<RotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        // Head dim is derived from hidden size, not read from the config.
        let head_dim = hidden_sz / num_heads;
        let q_proj = linear_no_bias(
            hidden_sz,
            num_heads * head_dim,
            mapper.set_device(layer_idx, vb.pp("q_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("q_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let k_proj = linear_no_bias(
            hidden_sz,
            num_kv_heads * head_dim,
            mapper.set_device(layer_idx, vb.pp("k_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("k_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let v_proj = linear_no_bias(
            hidden_sz,
            num_kv_heads * head_dim,
            mapper.set_device(layer_idx, vb.pp("v_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("v_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let o_proj = linear_no_bias(
            num_heads * head_dim,
            hidden_sz,
            mapper.set_device(layer_idx, vb.pp("o_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("o_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            num_heads,
            num_kv_heads,
            head_dim,
            rotary_emb,
            sliding_window: cfg.sliding_window,
            sdpa_params: SdpaParams {
                // GQA: number of query heads sharing each KV head.
                n_kv_groups: num_heads / num_kv_heads,
                use_flash_attn: cfg.use_flash_attn,
                softcap: None,
                // Standard 1/sqrt(head_dim) attention scaling.
                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                sliding_window: cfg.sliding_window,
            },
        })
    }

    /// Runs self-attention over `xs` of shape `(b_sz, q_len, hidden)`,
    /// updating `kv_cache` in place and returning `(b_sz, q_len, hidden)`.
    ///
    /// `scalings`/`global_scaling_weight`/`is_scaling_pass` are forwarded to
    /// every LoRA projection (X-LoRA adapter mixing).
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut Option<(Tensor, Tensor)>,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        // If the projections are quantized, run them in their activation
        // dtype and cast the results back afterwards.
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.q_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut q = self.q_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let mut k = self.k_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let mut v = self.v_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.q_proj.quantized_act_type().is_some() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

        // Split heads into (b, heads, q_len, head_dim). For single-token
        // decode (q_len == 1) the transpose is a no-op, so reshape directly.
        let (q, k, v) = if q_len != 1 {
            let q = q
                .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            (q, k, v)
        };

        // Rotary position embedding at the per-sequence offsets.
        let (q, k) = self.rotary_emb.forward(&q, &k, seqlen_offsets)?;

        // Append k/v to the cache, trimming to the sliding window; the mask
        // may be adjusted to match the (possibly truncated) cache length.
        let (k, v, attn_mask) = Cache::update_kv_cache_sliding_window(
            kv_cache,
            k,
            v,
            attention_mask,
            self.sliding_window,
            false,
        )?;

        let mut attn_output = Sdpa.run_attention(
            &q,
            &k,
            &v,
            attn_mask.as_ref(),
            Some(flash_params),
            &self.sdpa_params,
        )?;

        // Merge heads back to (b, q_len, hidden) and project out, with the
        // same quantized-dtype round-trip as the input projections.
        if let Some(t) = self.q_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        let mut res = self.o_proj.lora_forward(
            &attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.q_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}
220
/// One MoE expert: a gated MLP (`w2(act(w1 x) * w3 x)`) built from
/// LoRA-capable linears.
#[derive(Clone)]
struct BlockSparseTop2MLP {
    w1: Arc<dyn LinearLayerLike + Send + Sync>,
    w2: Arc<dyn LinearLayerLike + Send + Sync>,
    w3: Arc<dyn LinearLayerLike + Send + Sync>,
    act_fn: Activation,
}
228
229impl BlockSparseTop2MLP {
230 #[allow(clippy::too_many_arguments)]
231 fn new(
232 cfg: &Config,
233 vb: ShardedVarBuilder,
234 lora_config: &[((String, String), LoraConfig)],
235 count: &mut usize,
236 ord: &Ordering,
237 mapper: &dyn DeviceMapper,
238 layer_idx: usize,
239 loading_isq: bool,
240 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
241 ) -> Result<Self> {
242 let hidden_sz = cfg.hidden_size;
243 let intermediate_sz = cfg.intermediate_size;
244 let w1 = linear_no_bias(
245 hidden_sz,
246 intermediate_sz,
247 mapper.set_device(layer_idx, vb.pp("w1"), loading_isq),
248 mapper.set_device(layer_idx, vb.pp("w1"), false),
249 lora_config,
250 count,
251 ord,
252 preload_adapters,
253 )?;
254 let w2 = linear_no_bias(
255 intermediate_sz,
256 hidden_sz,
257 mapper.set_device(layer_idx, vb.pp("w2"), loading_isq),
258 mapper.set_device(layer_idx, vb.pp("w2"), false),
259 lora_config,
260 count,
261 ord,
262 preload_adapters,
263 )?;
264 let w3 = linear_no_bias(
265 hidden_sz,
266 intermediate_sz,
267 mapper.set_device(layer_idx, vb.pp("w3"), loading_isq),
268 mapper.set_device(layer_idx, vb.pp("w3"), false),
269 lora_config,
270 count,
271 ord,
272 preload_adapters,
273 )?;
274 Ok(Self {
275 w1,
276 w2,
277 w3,
278 act_fn: cfg.hidden_act,
279 })
280 }
281
282 fn forward(
283 &self,
284 xs: &Tensor,
285 scalings: Option<Tensor>,
286 global_scaling_weight: f64,
287 is_scaling_pass: Option<f64>,
288 ) -> Result<Tensor> {
289 let original_dtype = xs.dtype();
290 let mut xs = xs.clone();
291 if let Some(t) = self.w1.quantized_act_type() {
292 xs = xs.to_dtype(t)?;
293 }
294 let lhs = self
295 .w1
296 .lora_forward(
297 &xs,
298 scalings.clone(),
299 global_scaling_weight,
300 is_scaling_pass,
301 )?
302 .apply(&self.act_fn)?;
303 let rhs = self.w3.lora_forward(
304 &xs,
305 scalings.clone(),
306 global_scaling_weight,
307 is_scaling_pass,
308 )?;
309 let mut res = self.w2.lora_forward(
310 &(lhs * rhs)?,
311 scalings.clone(),
312 global_scaling_weight,
313 is_scaling_pass,
314 )?;
315 if self.w1.quantized_act_type().is_some() {
316 res = res.to_dtype(original_dtype)?;
317 }
318 Ok(res)
319 }
320}
321
/// Mixtral sparse-MoE block: a router (`gate`) plus `num_local_experts`
/// experts, of which `num_experts_per_tok` are active per token.
#[derive(Clone)]
struct SparseMoeBlock {
    gate: Arc<dyn LinearLayerLike + Send + Sync>,
    experts: Vec<BlockSparseTop2MLP>,
    num_experts_per_tok: usize,
}
328
impl SparseMoeBlock {
    /// Loads the router and all experts for `layer_idx`; expert weights live
    /// under the `experts.{idx}` prefix.
    #[allow(clippy::too_many_arguments)]
    fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        // Router: hidden -> one logit per local expert.
        let gate = linear_no_bias(
            cfg.hidden_size,
            cfg.num_local_experts,
            mapper.set_device(layer_idx, vb.pp("gate"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("gate"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let mut experts = Vec::with_capacity(cfg.num_local_experts);
        let vb = vb.pp("experts");
        for idx in 0..cfg.num_local_experts {
            let expert = BlockSparseTop2MLP::new(
                cfg,
                vb.pp(idx),
                lora_config,
                count,
                ord,
                mapper,
                layer_idx,
                loading_isq,
                preload_adapters,
            )?;
            experts.push(expert)
        }
        Ok(SparseMoeBlock {
            gate,
            experts,
            num_experts_per_tok: cfg.num_experts_per_tok,
        })
    }

    /// Routes each token to its top-`num_experts_per_tok` experts and
    /// combines the expert outputs weighted by the normalized router softmax.
    ///
    /// Input/output shape: `(b, seq, hidden)`. Routing decisions are made on
    /// the host (`to_vec2`), then experts run batched over their assigned
    /// token rows via `index_select` / `index_add`.
    fn forward(
        &self,
        xs: &Tensor,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
    ) -> Result<Tensor> {
        // Flatten tokens to rows: (b * seq, hidden).
        let (b_size, seq_len, hidden_dim) = xs.dims3()?;
        let xs = xs.reshape(((), hidden_dim))?;

        // Quantized-dtype round-trip for the router, as elsewhere.
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.gate.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut router_logits = self.gate.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.gate.quantized_act_type().is_some() {
            router_logits = router_logits.to_dtype(original_dtype)?;
        }

        let routing_weights = candle_nn::ops::softmax_last_dim(&router_logits)?;

        // Pull routing weights to the host to do top-k selection per row.
        let routing_weights = routing_weights.to_dtype(DType::F32)?.to_vec2::<f32>()?;

        // top_x[e]: row indices assigned to expert e;
        // selected_rws[e]: matching per-row weights, renormalized over the
        // chosen top-k so they sum to 1 for each token.
        let mut top_x = vec![vec![]; self.experts.len()];
        let mut selected_rws = vec![vec![]; self.experts.len()];
        for (row_idx, rw) in routing_weights.iter().enumerate() {
            // Sort expert indices by descending routing weight.
            let mut dst = (0..rw.len() as u32).collect::<Vec<u32>>();
            dst.sort_by(|&i, &j| rw[j as usize].total_cmp(&rw[i as usize]));
            // First pass: record assignments and accumulate the top-k mass.
            let mut sum_routing_weights = 0f32;
            for &expert_idx in dst.iter().take(self.num_experts_per_tok) {
                let expert_idx = expert_idx as usize;
                let routing_weight = rw[expert_idx];
                sum_routing_weights += routing_weight;
                top_x[expert_idx].push(row_idx as u32);
            }
            // Second pass: store weights normalized by that top-k mass.
            for &expert_idx in dst.iter().take(self.num_experts_per_tok) {
                let expert_idx = expert_idx as usize;
                let routing_weight = rw[expert_idx];
                selected_rws[expert_idx].push(routing_weight / sum_routing_weights)
            }
        }

        // Accumulate expert outputs into a zero tensor shaped like the rows.
        let mut ys = xs.zeros_like()?;
        for (expert_idx, expert_layer) in self.experts.iter().enumerate() {
            let top_x = &top_x[expert_idx];
            // Skip experts with no assigned tokens this step.
            if top_x.is_empty() {
                continue;
            }
            let top_x = Tensor::new(top_x.as_slice(), xs.device())?;
            let selected_rws =
                Tensor::new(selected_rws[expert_idx].as_slice(), xs.device())?.reshape(((), 1))?;
            // Gather this expert's rows, run the expert, scale by the
            // normalized routing weight, then scatter-add back by row index.
            let current_state = xs.index_select(&top_x, 0)?.reshape(((), hidden_dim))?;
            let current_hidden_states = expert_layer.forward(
                &current_state,
                scalings.clone(),
                global_scaling_weight,
                is_scaling_pass,
            )?;
            let current_hidden_states = current_hidden_states.broadcast_mul(&selected_rws)?;
            ys = ys.index_add(&top_x, &current_hidden_states, 0)?;
        }

        // Restore (b, seq, hidden).
        let ys = ys.reshape((b_size, seq_len, hidden_dim))?;
        Ok(ys)
    }
}
458
/// One Mixtral decoder layer: pre-norm attention followed by a pre-norm
/// sparse-MoE block, each with a residual connection.
struct DecoderLayer {
    self_attn: Attention,
    block_sparse_moe: SparseMoeBlock,
    input_layernorm: RmsNorm,
    post_attention_layernorm: RmsNorm,
}
465
466impl DecoderLayer {
467 #[allow(clippy::too_many_arguments)]
468 fn new(
469 rotary_emb: Arc<RotaryEmbedding>,
470 cfg: &Config,
471 vb: ShardedVarBuilder,
472 lora_config: &[((String, String), LoraConfig)],
473 count: &mut usize,
474 ord: &Ordering,
475 mapper: &dyn DeviceMapper,
476 layer_idx: usize,
477 loading_isq: bool,
478 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
479 ) -> Result<Self> {
480 let self_attn = Attention::new(
481 rotary_emb,
482 cfg,
483 vb.pp("self_attn"),
484 lora_config,
485 count,
486 ord,
487 mapper,
488 layer_idx,
489 loading_isq,
490 preload_adapters,
491 )?;
492 let block_sparse_moe = SparseMoeBlock::new(
493 cfg,
494 vb.pp("block_sparse_moe"),
495 lora_config,
496 count,
497 ord,
498 mapper,
499 layer_idx,
500 loading_isq,
501 preload_adapters,
502 )?;
503 let input_layernorm = RmsNorm::new(
504 cfg.hidden_size,
505 cfg.rms_norm_eps,
506 mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
507 )?;
508 let post_attention_layernorm = RmsNorm::new(
509 cfg.hidden_size,
510 cfg.rms_norm_eps,
511 mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
512 )?;
513 Ok(Self {
514 self_attn,
515 block_sparse_moe,
516 input_layernorm,
517 post_attention_layernorm,
518 })
519 }
520
521 #[allow(clippy::too_many_arguments)]
522 fn forward(
523 &self,
524 xs: &Tensor,
525 attention_mask: Option<&Tensor>,
526 seqlen_offsets: &[usize],
527 kv_cache: &mut Option<(Tensor, Tensor)>,
528 scalings: Option<Tensor>,
529 global_scaling_weight: f64,
530 is_scaling_pass: Option<f64>,
531 flash_params: &FlashParams,
532 ) -> Result<Tensor> {
533 let residual = xs;
534 let xs = self.input_layernorm.forward(xs)?;
535 let xs = self.self_attn.forward(
536 &xs,
537 attention_mask,
538 seqlen_offsets,
539 kv_cache,
540 scalings.clone(),
541 global_scaling_weight,
542 is_scaling_pass,
543 flash_params,
544 )?;
545 let xs = (xs + residual)?;
546 let residual = &xs;
547 let xs = self
548 .block_sparse_moe
549 .forward(
550 &xs.apply(&self.post_attention_layernorm)?,
551 scalings.clone(),
552 global_scaling_weight,
553 is_scaling_pass,
554 )?
555 .to_dtype(residual.dtype())?;
556 residual + xs
557 }
558}
559
/// Mixtral model with X-LoRA (mixture-of-adapters) support.
pub struct XLoraModel {
    embed_tokens: candle_nn::Embedding,
    layers: Vec<DecoderLayer>,
    norm: RmsNorm,
    lm_head: Arc<dyn LinearLayerLike + Send + Sync>,
    sliding_window: Option<usize>,
    device: Device,
    cache: EitherCache,
    dtype: DType,
    max_seq_len: usize,
    // `None` for plain LoRA (adapters are merged at load time); `Some`
    // enables the two-pass X-LoRA scaling flow in `forward`.
    xlora_classifier: Option<XLoraClassifier>,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    cfg: ModelConfigMetadata,
}
574
575impl XLoraModel {
576 #[allow(clippy::too_many_arguments)]
577 pub fn new(
578 cfg: &Config,
579 vb: ShardedVarBuilder,
580 lora_config: &[((String, String), LoraConfig)],
581 xlora_config: Option<XLoraConfig>,
582 xlora_ordering: Ordering,
583 is_gptx: bool,
584 normal_loading_metadata: NormalLoadingMetadata,
585 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
586 ) -> Result<Self> {
587 if let Some(ref quant_cfg) = &cfg.quantization_config {
588 tracing::info!(
589 "Using {} quantization: {}.",
590 quant_cfg.quant_method.to_string(),
591 quant_cfg.get_bits_name(&vb)
592 );
593 }
594 let mapper = normal_loading_metadata.mapper;
595 let vb_m = vb.pp("model");
596
597 let embed_tokens = layers::embedding(
598 cfg.vocab_size,
599 cfg.hidden_size,
600 mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
601 )?;
602 let head_dim = cfg.hidden_size / cfg.num_attention_heads;
603 let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
604 let vb_l = vb_m.pp("layers");
605 let mut ropes = HashMap::new();
606 for layer_idx in 0..cfg.num_hidden_layers {
607 let device = mapper
608 .device_for(layer_idx, false)
609 .unwrap_or(&normal_loading_metadata.real_device);
610 ropes.insert(
611 device.location(),
612 Arc::new(RotaryEmbedding::new(
613 cfg.rope_theta as f32,
614 head_dim,
615 cfg.max_position_embeddings,
616 device,
617 is_gptx,
618 vb_m.dtype(),
619 )?),
620 );
621 }
622
623 let mut count = 0;
624 for layer_idx in NiceProgressBar::<_, 'b'>(
625 0..cfg.num_hidden_layers,
626 "Loading repeating layers",
627 &normal_loading_metadata.multi_progress,
628 ) {
629 let device = mapper
630 .device_for(layer_idx, false)
631 .unwrap_or(&normal_loading_metadata.real_device);
632 let rotary_emb = ropes
633 .get(&device.location())
634 .expect("No RoPE for device location!")
635 .clone();
636 let layer = DecoderLayer::new(
637 rotary_emb.clone(),
638 cfg,
639 vb_l.pp(layer_idx),
640 lora_config,
641 &mut count,
642 &xlora_ordering,
643 &*mapper,
644 layer_idx,
645 normal_loading_metadata.loading_isq,
646 preload_adapters,
647 )?;
648 layers.push(layer)
649 }
650 if xlora_config.is_none() && preload_adapters.is_none() {
651 info!("Merging LoRA adapters.");
653 for layer in layers.iter_mut().tqdm() {
654 Arc::get_mut(&mut layer.self_attn.k_proj)
655 .unwrap()
656 .merge_weights()?;
657 Arc::get_mut(&mut layer.self_attn.o_proj)
658 .unwrap()
659 .merge_weights()?;
660 Arc::get_mut(&mut layer.self_attn.q_proj)
661 .unwrap()
662 .merge_weights()?;
663 Arc::get_mut(&mut layer.self_attn.v_proj)
664 .unwrap()
665 .merge_weights()?;
666
667 Arc::get_mut(&mut layer.block_sparse_moe.gate)
668 .unwrap()
669 .merge_weights()?;
670 for expert in layer.block_sparse_moe.experts.iter_mut() {
671 Arc::get_mut(&mut expert.w1).unwrap().merge_weights()?;
672 Arc::get_mut(&mut expert.w2).unwrap().merge_weights()?;
673 Arc::get_mut(&mut expert.w3).unwrap().merge_weights()?;
674 }
675 }
676 }
677 let norm = RmsNorm::new(
678 cfg.hidden_size,
679 cfg.rms_norm_eps,
680 mapper.set_nm_device(vb_m.pp("norm"), false),
681 )?;
682 let lm_head = linear_no_bias(
683 cfg.hidden_size,
684 cfg.vocab_size,
685 mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
686 mapper.set_nm_device(vb.pp("lm_head"), false),
687 lora_config,
688 &mut count,
689 &xlora_ordering,
690 preload_adapters,
691 )?;
692 if xlora_config.is_some() && lm_head.is_lora() {
693 candle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA.");
695 }
696 Ok(Self {
697 embed_tokens,
698 layers,
699 norm,
700 lm_head,
701 sliding_window: cfg.sliding_window,
702 device: normal_loading_metadata.real_device,
703 dtype: vb.dtype(),
704 cache: EitherCache::Full(Cache::new(cfg.num_hidden_layers, false)),
705 max_seq_len: cfg.max_position_embeddings,
706 xlora_classifier: xlora_config.map(|xlora_config| {
707 XLoraClassifier::new(xlora_config, count, lora_config.len(), vb, false).unwrap()
708 }),
709 mapper,
710 cfg: ModelConfigMetadata {
711 max_seq_len: cfg.max_position_embeddings,
712 num_layers: cfg.num_hidden_layers,
713 hidden_size: cfg.hidden_size,
714 num_kv_heads: cfg.num_key_value_heads,
715 num_attn_heads: cfg.num_attention_heads,
716 sliding_window: cfg.sliding_window,
717 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
718 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
719 },
720 })
721 }
722
723 #[allow(clippy::too_many_arguments)]
724 fn inner_forward(
725 &self,
726 input_ids: &Tensor,
727 seqlen_offsets: &[usize],
728 scalings: Option<Tensor>,
729 is_full_pass: bool,
730 no_kv_cache: bool,
731 is_scaling_pass: Option<f64>,
732 flash_params: &FlashParams,
733 ) -> Result<Tensor> {
734 let mut cache = if is_full_pass {
735 if no_kv_cache {
736 let mut new_cache = Vec::new();
737 for _ in 0..self.cache.full().xlora_lock().len() {
738 new_cache.push(None);
739 }
740
741 self.cache.full().xlora_lock().clone_from(&new_cache);
742 }
743 self.cache.full().xlora_lock()
744 } else {
745 self.cache.full().lock()
746 };
747 let mut xs = self.embed_tokens.forward(input_ids)?;
748 let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
749 input_ids,
750 &*cache,
751 self.sliding_window,
752 xs.dtype(),
753 self.cfg.num_attn_heads,
754 )?;
755 for (i, layer) in self.layers.iter().enumerate() {
756 xs = self.mapper.map(xs, i)?;
757 xs = layer.forward(
758 &xs,
759 attention_mask
760 .as_ref()
761 .map(|m| m.to_device(xs.device()).unwrap())
762 .as_ref(),
763 seqlen_offsets,
764 &mut cache[i],
765 scalings.clone(),
766 self.xlora_classifier
767 .as_ref()
768 .map(|classifier| classifier.get_global_scaling_weight())
769 .unwrap_or(1.0),
770 is_scaling_pass,
771 flash_params,
772 )?
773 }
774 let xs = xs.to_device(&self.device)?;
775 xs.apply(&self.norm)
776 }
777
778 #[allow(clippy::too_many_arguments)]
779 pub fn forward(
780 &self,
781 input_ids: &Tensor,
782 input_ids_full: &Tensor,
783 seqlen_offsets: &[usize],
784 seqlen_offsets_full: &[usize],
785 no_kv_cache: bool,
786 non_granular_state: &Option<NonGranularState>,
787 context_lens: Vec<(usize, usize)>,
788 flash_params: &FlashParams,
789 flash_params_full: &FlashParams,
790 ) -> Result<Tensor> {
791 if self.xlora_classifier.is_some() {
792 let scalings = self.get_scalings(
793 input_ids,
794 input_ids_full,
795 seqlen_offsets,
796 seqlen_offsets_full,
797 no_kv_cache,
798 non_granular_state,
799 &vec![usize::MAX; context_lens.len()],
800 flash_params,
801 flash_params_full,
802 )?;
803
804 if no_kv_cache {
805 let mut res = self
806 .inner_forward(
807 input_ids_full,
808 seqlen_offsets_full,
809 Some(scalings),
810 true,
811 no_kv_cache,
812 None,
813 flash_params_full,
814 )?
815 .contiguous()?;
816 if let Some(t) = self.lm_head.quantized_act_type() {
817 res = res.to_dtype(t)?;
818 }
819 extract_logits(
820 &self.lm_head.lora_forward(&res, None, 1.0, None)?,
821 context_lens,
822 )
823 } else {
824 let mut res = self
826 .inner_forward(
827 input_ids,
828 seqlen_offsets,
829 Some(scalings),
830 true,
831 no_kv_cache,
832 None,
833 flash_params,
834 )?
835 .contiguous()?;
836 if let Some(t) = self.lm_head.quantized_act_type() {
837 res = res.to_dtype(t)?;
838 }
839 extract_logits(
840 &self.lm_head.lora_forward(&res, None, 1.0, None)?,
841 context_lens,
842 )
843 }
844 } else {
845 let mut res = self
846 .inner_forward(
847 input_ids,
848 seqlen_offsets,
849 None,
850 false,
851 no_kv_cache,
852 None,
853 flash_params,
854 )?
855 .contiguous()?;
856 if let Some(t) = self.lm_head.quantized_act_type() {
857 res = res.to_dtype(t)?;
858 }
859 extract_logits(
860 &self.lm_head.lora_forward(&res, None, 1.0, None)?,
861 context_lens,
862 )
863 }
864 }
865}
866
867impl IsqModel for XLoraModel {
868 fn get_layers(
869 &mut self,
870 ) -> (
871 Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
872 &dyn DeviceMapper,
873 ) {
874 let mut tensors = Vec::new();
875 tensors.push((Arc::get_mut(&mut self.lm_head).unwrap().quant_inner(), None));
876 for (i, layer) in self.layers.iter_mut().enumerate() {
877 tensors.push((
878 Arc::get_mut(&mut layer.self_attn.q_proj)
879 .unwrap()
880 .quant_inner(),
881 Some(i),
882 ));
883 tensors.push((
884 Arc::get_mut(&mut layer.self_attn.k_proj)
885 .unwrap()
886 .quant_inner(),
887 Some(i),
888 ));
889 tensors.push((
890 Arc::get_mut(&mut layer.self_attn.v_proj)
891 .unwrap()
892 .quant_inner(),
893 Some(i),
894 ));
895 tensors.push((
896 Arc::get_mut(&mut layer.self_attn.o_proj)
897 .unwrap()
898 .quant_inner(),
899 Some(i),
900 ));
901 tensors.push((
902 Arc::get_mut(&mut layer.block_sparse_moe.gate)
903 .unwrap()
904 .quant_inner(),
905 Some(i),
906 ));
907 for expert in &mut layer.block_sparse_moe.experts {
908 tensors.push((Arc::get_mut(&mut expert.w1).unwrap().quant_inner(), Some(i)));
909 tensors.push((Arc::get_mut(&mut expert.w2).unwrap().quant_inner(), Some(i)));
910 tensors.push((Arc::get_mut(&mut expert.w3).unwrap().quant_inner(), Some(i)));
911 }
912 }
913 (tensors, &*self.mapper)
914 }
915
916 fn residual_tensors(&self) -> Vec<(String, Tensor)> {
917 panic!("Cannot generate UQFF for an adapter model.")
918 }
919}
920
921impl NormalModel for XLoraModel {
922 fn forward(
923 &self,
924 _input_ids: &Tensor,
925 _seqlen_offsets: &[usize],
926 _context_lens: Vec<(usize, usize)>,
927 _position_ids: Vec<usize>,
928 _metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
929 _flash_params: &FlashParams,
930 ) -> Result<Tensor> {
931 unreachable!()
932 }
933 fn xlora_forward(
934 &self,
935 input_ids: &Tensor,
936 input_ids_full: &Tensor,
937 seqlen_offsets: &[usize],
938 seqlen_offsets_full: &[usize],
939 no_kv_cache: bool,
940 non_granular_state: &Option<crate::xlora_models::NonGranularState>,
941 context_lens: Vec<(usize, usize)>,
942 _position_ids: Vec<usize>,
943 flash_params: &FlashParams,
944 flash_params_full: &FlashParams,
945 ) -> Result<Tensor> {
946 self.forward(
947 input_ids,
948 input_ids_full,
949 seqlen_offsets,
950 seqlen_offsets_full,
951 no_kv_cache,
952 non_granular_state,
953 context_lens,
954 flash_params,
955 flash_params_full,
956 )
957 }
958 fn cache(&self) -> &EitherCache {
959 &self.cache
960 }
961 fn cache_mut(&mut self) -> &mut EitherCache {
962 &mut self.cache
963 }
964 fn device(&self) -> &Device {
965 &self.device
966 }
967 fn is_xlora(&self) -> bool {
968 true
969 }
970 fn max_seq_len(&self) -> usize {
971 self.max_seq_len
972 }
973 fn activate_adapters(&mut self, adapter_names: Vec<String>) -> Result<usize> {
974 if self.xlora_classifier.is_some() {
975 candle_core::bail!("Adapter activation is not supported for X-LoRA models as the adapter set must remain the same.");
976 }
977 let mut sum = 0;
978 for layer in self.layers.iter_mut() {
979 sum += Arc::get_mut(&mut layer.self_attn.k_proj)
980 .unwrap()
981 .activate(&adapter_names)?;
982 sum += Arc::get_mut(&mut layer.self_attn.o_proj)
983 .unwrap()
984 .activate(&adapter_names)?;
985 sum += Arc::get_mut(&mut layer.self_attn.q_proj)
986 .unwrap()
987 .activate(&adapter_names)?;
988 sum += Arc::get_mut(&mut layer.self_attn.v_proj)
989 .unwrap()
990 .activate(&adapter_names)?;
991
992 sum += Arc::get_mut(&mut layer.block_sparse_moe.gate)
993 .unwrap()
994 .activate(&adapter_names)?;
995 for expert in &mut layer.block_sparse_moe.experts {
996 sum += Arc::get_mut(&mut expert.w1)
997 .unwrap()
998 .activate(&adapter_names)?;
999 sum += Arc::get_mut(&mut expert.w2)
1000 .unwrap()
1001 .activate(&adapter_names)?;
1002 sum += Arc::get_mut(&mut expert.w3)
1003 .unwrap()
1004 .activate(&adapter_names)?;
1005 }
1006 }
1007 Ok(sum)
1008 }
1009 fn config(&self) -> &ModelConfigMetadata {
1010 &self.cfg
1011 }
1012}
1013
1014impl ScalingsMaker for XLoraModel {
1015 fn dtype(&self) -> DType {
1016 self.dtype
1017 }
1018 fn get_cache(&self) -> &EitherCache {
1019 &self.cache
1020 }
1021 fn get_classifier(&self) -> &XLoraClassifier {
1022 self.xlora_classifier.as_ref().unwrap()
1023 }
1024 fn forward(
1025 &self,
1026 input_ids: &Tensor,
1027 seqlen_offsets: &[usize],
1028 scalings: Tensor,
1029 is_full_pass: bool,
1030 no_kv_cache: bool,
1031 is_scaling_pass: Option<f64>,
1032 _context_lens: &[usize],
1033 flash_params: &FlashParams,
1034 ) -> Result<Tensor> {
1035 self.inner_forward(
1036 input_ids,
1037 seqlen_offsets,
1038 Some(scalings),
1039 is_full_pass,
1040 no_kv_cache,
1041 is_scaling_pass,
1042 flash_params,
1043 )
1044 }
1045}
1046
// Uses the trait's default (no-op) AnyMoE behavior for this adapter model.
impl AnyMoeBaseModelMixin for XLoraModel {}