#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use crate::{
    amoe::AnyMoeBaseModelMixin,
    attention::SdpaParams,
    layers::{self, Activation, RotaryEmbedding, Sdpa},
    lora::{linear_no_bias, LinearLayerLike, LoraConfig, Ordering},
    paged_attention::ModelConfigMetadata,
    pipeline::{
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel, NormalLoadingMetadata,
    },
    utils::progress::NiceProgressBar,
};
use candle_core::{DType, Device, Module, Result, Tensor};
use mistralrs_quant::{QuantMethod, ShardedVarBuilder};
use std::{collections::HashMap, sync::Arc};
use tqdm::Iter;
use tracing::info;

use crate::{
    device_map::DeviceMapper,
    layers::{CausalMasker, RmsNorm},
    models::mixtral::Config,
    pipeline::{extract_logits, Cache, NormalModel},
};

use super::{classifier::XLoraClassifier, NonGranularState, ScalingsMaker, XLoraConfig};

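/// Multi-head self-attention whose q/k/v/o projections are LoRA-wrapped
/// (`LinearLayerLike`), using rotary position embeddings and an optional
/// sliding-window KV cache.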
struct Attention {
    q_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    k_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    v_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    o_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    rotary_emb: Arc<RotaryEmbedding>,
    sliding_window: Option<usize>,
    sdpa_params: SdpaParams,
}

impl Attention {
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<RotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        let head_dim = hidden_sz / num_heads;
        let q_proj = linear_no_bias(
            hidden_sz,
            num_heads * head_dim,
            mapper.set_device(layer_idx, vb.pp("q_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("q_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let k_proj = linear_no_bias(
            hidden_sz,
            num_kv_heads * head_dim,
            mapper.set_device(layer_idx, vb.pp("k_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("k_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let v_proj = linear_no_bias(
            hidden_sz,
            num_kv_heads * head_dim,
            mapper.set_device(layer_idx, vb.pp("v_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("v_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let o_proj = linear_no_bias(
            num_heads * head_dim,
            hidden_sz,
            mapper.set_device(layer_idx, vb.pp("o_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("o_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            num_heads,
            num_kv_heads,
            head_dim,
            rotary_emb,
            sliding_window: cfg.sliding_window,
            sdpa_params: SdpaParams {
                n_kv_groups: num_heads / num_kv_heads,
                softcap: None,
                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                sliding_window: cfg.sliding_window,
            },
        })
    }

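    /// Projects `xs` through the LoRA-wrapped q/k/v projections, applies RoPE,
    /// updates the (optionally sliding-window) KV cache, runs SDPA, and
    /// projects back through `o_proj`. Activations are temporarily cast to the
    /// quantized activation dtype when the projections report one.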
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut Option<(Tensor, Tensor)>,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.q_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut q = self.q_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let mut k = self.k_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let mut v = self.v_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.q_proj.quantized_act_type().is_some() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

        let (q, k, v) = if q_len != 1 {
            let q = q
                .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, k) = self.rotary_emb.forward(&q, &k, seqlen_offsets)?;

        let (k, v, attn_mask) = Cache::update_kv_cache_sliding_window(
            kv_cache,
            k,
            v,
            attention_mask,
            self.sliding_window,
        )?;

        let mut attn_output = Sdpa.run_attention(
            &q,
            &k,
            &v,
            attn_mask.as_ref(),
            Some(flash_params),
            &self.sdpa_params,
        )?;

        if let Some(t) = self.q_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        let mut res = self.o_proj.lora_forward(
            &attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.q_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

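/// A single Mixtral expert: a gated MLP computing `w2(act_fn(w1(x)) * w3(x))`,
/// with each projection LoRA-wrapped.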
#[derive(Clone)]
struct BlockSparseTop2MLP {
    w1: Arc<dyn LinearLayerLike + Send + Sync>,
    w2: Arc<dyn LinearLayerLike + Send + Sync>,
    w3: Arc<dyn LinearLayerLike + Send + Sync>,
    act_fn: Activation,
}

impl BlockSparseTop2MLP {
    #[allow(clippy::too_many_arguments)]
    fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let intermediate_sz = cfg.intermediate_size;
        let w1 = linear_no_bias(
            hidden_sz,
            intermediate_sz,
            mapper.set_device(layer_idx, vb.pp("w1"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("w1"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let w2 = linear_no_bias(
            intermediate_sz,
            hidden_sz,
            mapper.set_device(layer_idx, vb.pp("w2"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("w2"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let w3 = linear_no_bias(
            hidden_sz,
            intermediate_sz,
            mapper.set_device(layer_idx, vb.pp("w3"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("w3"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        Ok(Self {
            w1,
            w2,
            w3,
            act_fn: cfg.hidden_act,
        })
    }

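    /// Gated-MLP forward pass: `w2(act_fn(w1(xs)) * w3(xs))`, with the same
    /// quantized-activation dtype handling as the attention projections.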
    fn forward(
        &self,
        xs: &Tensor,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
    ) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.w1.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let lhs = self
            .w1
            .lora_forward(
                &xs,
                scalings.clone(),
                global_scaling_weight,
                is_scaling_pass,
            )?
            .apply(&self.act_fn)?;
        let rhs = self.w3.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let mut res = self.w2.lora_forward(
            &(lhs * rhs)?,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.w1.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

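/// Mixtral sparse MoE block: a LoRA-wrapped router (`gate`) selects
/// `num_experts_per_tok` experts per token and combines their outputs with
/// normalized routing weights.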
#[derive(Clone)]
struct SparseMoeBlock {
    gate: Arc<dyn LinearLayerLike + Send + Sync>,
    experts: Vec<BlockSparseTop2MLP>,
    num_experts_per_tok: usize,
}

impl SparseMoeBlock {
    #[allow(clippy::too_many_arguments)]
    fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let gate = linear_no_bias(
            cfg.hidden_size,
            cfg.num_local_experts,
            mapper.set_device(layer_idx, vb.pp("gate"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("gate"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let mut experts = Vec::with_capacity(cfg.num_local_experts);
        let vb = vb.pp("experts");
        for idx in 0..cfg.num_local_experts {
            let expert = BlockSparseTop2MLP::new(
                cfg,
                vb.pp(idx),
                lora_config,
                count,
                ord,
                mapper,
                layer_idx,
                loading_isq,
                preload_adapters,
            )?;
            experts.push(expert)
        }
        Ok(SparseMoeBlock {
            gate,
            experts,
            num_experts_per_tok: cfg.num_experts_per_tok,
        })
    }

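    /// Routes each token: computes softmax router probabilities, picks the top
    /// `num_experts_per_tok` experts per row on the host, renormalizes their
    /// weights, then runs each expert only on the rows assigned to it and
    /// scatters the weighted outputs back with `index_add`.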
    fn forward(
        &self,
        xs: &Tensor,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
    ) -> Result<Tensor> {
        let (b_size, seq_len, hidden_dim) = xs.dims3()?;
        let xs = xs.reshape(((), hidden_dim))?;

        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.gate.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut router_logits = self.gate.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.gate.quantized_act_type().is_some() {
            router_logits = router_logits.to_dtype(original_dtype)?;
        }

        let routing_weights = candle_nn::ops::softmax_last_dim(&router_logits)?;

        let routing_weights = routing_weights.to_dtype(DType::F32)?.to_vec2::<f32>()?;

        let mut top_x = vec![vec![]; self.experts.len()];
        let mut selected_rws = vec![vec![]; self.experts.len()];
        for (row_idx, rw) in routing_weights.iter().enumerate() {
            let mut dst = (0..rw.len() as u32).collect::<Vec<u32>>();
            dst.sort_by(|&i, &j| rw[j as usize].total_cmp(&rw[i as usize]));
            let mut sum_routing_weights = 0f32;
            for &expert_idx in dst.iter().take(self.num_experts_per_tok) {
                let expert_idx = expert_idx as usize;
                let routing_weight = rw[expert_idx];
                sum_routing_weights += routing_weight;
                top_x[expert_idx].push(row_idx as u32);
            }
            for &expert_idx in dst.iter().take(self.num_experts_per_tok) {
                let expert_idx = expert_idx as usize;
                let routing_weight = rw[expert_idx];
                selected_rws[expert_idx].push(routing_weight / sum_routing_weights)
            }
        }

        let mut ys = xs.zeros_like()?;
        for (expert_idx, expert_layer) in self.experts.iter().enumerate() {
            let top_x = &top_x[expert_idx];
            if top_x.is_empty() {
                continue;
            }
            let top_x = Tensor::new(top_x.as_slice(), xs.device())?;
            let selected_rws =
                Tensor::new(selected_rws[expert_idx].as_slice(), xs.device())?.reshape(((), 1))?;
            let current_state = xs.index_select(&top_x, 0)?.reshape(((), hidden_dim))?;
            let current_hidden_states = expert_layer.forward(
                &current_state,
                scalings.clone(),
                global_scaling_weight,
                is_scaling_pass,
            )?;
            let current_hidden_states = current_hidden_states.broadcast_mul(&selected_rws)?;
            ys = ys.index_add(&top_x, &current_hidden_states, 0)?;
        }

        let ys = ys.reshape((b_size, seq_len, hidden_dim))?;
        Ok(ys)
    }
}

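/// One transformer decoder layer: RMSNorm -> self-attention -> residual add,
/// then RMSNorm -> sparse MoE block -> residual add.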
struct DecoderLayer {
    self_attn: Attention,
    block_sparse_moe: SparseMoeBlock,
    input_layernorm: RmsNorm,
    post_attention_layernorm: RmsNorm,
}

impl DecoderLayer {
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<RotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let self_attn = Attention::new(
            rotary_emb,
            cfg,
            vb.pp("self_attn"),
            lora_config,
            count,
            ord,
            mapper,
            layer_idx,
            loading_isq,
            preload_adapters,
        )?;
        let block_sparse_moe = SparseMoeBlock::new(
            cfg,
            vb.pp("block_sparse_moe"),
            lora_config,
            count,
            ord,
            mapper,
            layer_idx,
            loading_isq,
            preload_adapters,
        )?;
        let input_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let post_attention_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        Ok(Self {
            self_attn,
            block_sparse_moe,
            input_layernorm,
            post_attention_layernorm,
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut Option<(Tensor, Tensor)>,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = self.input_layernorm.forward(xs)?;
        let xs = self.self_attn.forward(
            &xs,
            attention_mask,
            seqlen_offsets,
            kv_cache,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
            flash_params,
        )?;
        let xs = (xs + residual)?;
        let residual = &xs;
        let xs = self
            .block_sparse_moe
            .forward(
                &xs.apply(&self.post_attention_layernorm)?,
                scalings.clone(),
                global_scaling_weight,
                is_scaling_pass,
            )?
            .to_dtype(residual.dtype())?;
        residual + xs
    }
}

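/// X-LoRA variant of the Mixtral model: a stack of MoE decoder layers with
/// LoRA-wrapped linear layers, plus an optional `XLoraClassifier` that
/// produces the adapter scalings used during the second forward pass.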
pub struct XLoraModel {
    embed_tokens: candle_nn::Embedding,
    layers: Vec<DecoderLayer>,
    norm: RmsNorm,
    lm_head: Arc<dyn LinearLayerLike + Send + Sync>,
    sliding_window: Option<usize>,
    device: Device,
    cache: EitherCache,
    dtype: DType,
    max_seq_len: usize,
    xlora_classifier: Option<XLoraClassifier>,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    cfg: ModelConfigMetadata,
}

impl XLoraModel {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        if let Some(ref quant_cfg) = &cfg.quantization_config {
            tracing::info!(
                "Using {} quantization: {}.",
                quant_cfg.name(),
                quant_cfg.get_bits_name(&vb)
            );
        }
        let mapper = normal_loading_metadata.mapper;
        let vb_m = vb.pp("model");

        let embed_tokens = layers::embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
            &cfg.quantization_config,
        )?;
        let head_dim = cfg.hidden_size / cfg.num_attention_heads;
        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        let vb_l = vb_m.pp("layers");
        let mut ropes = HashMap::new();
        for layer_idx in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(RotaryEmbedding::new(
                    cfg.rope_theta as f32,
                    head_dim,
                    cfg.max_position_embeddings,
                    device,
                    is_gptx,
                    vb_m.dtype(),
                )?),
            );
        }

        let mut count = 0;
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        ) {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            let layer = DecoderLayer::new(
                rotary_emb.clone(),
                cfg,
                vb_l.pp(layer_idx),
                lora_config,
                &mut count,
                &xlora_ordering,
                &*mapper,
                layer_idx,
                normal_loading_metadata.loading_isq,
                preload_adapters,
            )?;
            layers.push(layer)
        }
        if xlora_config.is_none() && preload_adapters.is_none() {
            info!("Merging LoRA adapters.");
            for layer in layers.iter_mut().tqdm() {
                Arc::get_mut(&mut layer.self_attn.k_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.self_attn.o_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.self_attn.q_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.self_attn.v_proj)
                    .unwrap()
                    .merge_weights()?;

                Arc::get_mut(&mut layer.block_sparse_moe.gate)
                    .unwrap()
                    .merge_weights()?;
                for expert in layer.block_sparse_moe.experts.iter_mut() {
                    Arc::get_mut(&mut expert.w1).unwrap().merge_weights()?;
                    Arc::get_mut(&mut expert.w2).unwrap().merge_weights()?;
                    Arc::get_mut(&mut expert.w3).unwrap().merge_weights()?;
                }
            }
        }
        let norm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb_m.pp("norm"), false),
        )?;
        let lm_head = linear_no_bias(
            cfg.hidden_size,
            cfg.vocab_size,
            mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
            mapper.set_nm_device(vb.pp("lm_head"), false),
            lora_config,
            &mut count,
            &xlora_ordering,
            preload_adapters,
        )?;
        if xlora_config.is_some() && lm_head.is_lora() {
            candle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA.");
        }
        Ok(Self {
            embed_tokens,
            layers,
            norm,
            lm_head,
            sliding_window: cfg.sliding_window,
            device: normal_loading_metadata.real_device,
            dtype: vb.dtype(),
            cache: EitherCache::Full(Cache::new(cfg.num_hidden_layers, false)),
            max_seq_len: cfg.max_position_embeddings,
            xlora_classifier: xlora_config.map(|xlora_config| {
                XLoraClassifier::new(xlora_config, count, lora_config.len(), vb, false).unwrap()
            }),
            mapper,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_kv_heads: cfg.num_key_value_heads,
                num_attn_heads: cfg.num_attention_heads,
                sliding_window: cfg.sliding_window,
                k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
                v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            },
        })
    }

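    /// Runs the decoder stack once. `is_full_pass` selects the X-LoRA cache
    /// (clearing it first when `no_kv_cache` is set) instead of the regular KV
    /// cache; `scalings` and `is_scaling_pass` are threaded through to the
    /// LoRA layers.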
    #[allow(clippy::too_many_arguments)]
    fn inner_forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Option<Tensor>,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut cache = if is_full_pass {
            if no_kv_cache {
                let mut new_cache = Vec::new();
                for _ in 0..self.cache.full().xlora_lock().len() {
                    new_cache.push(None);
                }

                self.cache.full().xlora_lock().clone_from(&new_cache);
            }
            self.cache.full().xlora_lock()
        } else {
            self.cache.full().lock()
        };
        let mut xs = self.embed_tokens.forward(input_ids)?;
        let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
            input_ids,
            &*cache,
            self.sliding_window,
            xs.dtype(),
            self.cfg.num_attn_heads,
        )?;
        for (i, layer) in self.layers.iter().enumerate() {
            xs = self.mapper.map(xs, i)?;
            xs = layer.forward(
                &xs,
                attention_mask
                    .as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                &mut cache[i],
                scalings.clone(),
                self.xlora_classifier
                    .as_ref()
                    .map(|classifier| classifier.get_global_scaling_weight())
                    .unwrap_or(1.0),
                is_scaling_pass,
                flash_params,
            )?
        }
        let xs = xs.to_device(&self.device)?;
        xs.apply(&self.norm)
    }

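    /// Full X-LoRA forward. With a classifier configured, this first obtains
    /// adapter scalings via `get_scalings` (the scaling pass), then runs the
    /// real pass with those scalings; otherwise it behaves like a plain LoRA
    /// forward. Logits are extracted at the positions given by `context_lens`.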
    #[allow(clippy::too_many_arguments)]
    pub fn forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        if self.xlora_classifier.is_some() {
            let scalings = self.get_scalings(
                input_ids,
                input_ids_full,
                seqlen_offsets,
                seqlen_offsets_full,
                no_kv_cache,
                non_granular_state,
                &vec![usize::MAX; context_lens.len()],
                flash_params,
                flash_params_full,
            )?;

            if no_kv_cache {
                let mut res = self
                    .inner_forward(
                        input_ids_full,
                        seqlen_offsets_full,
                        Some(scalings),
                        true,
                        no_kv_cache,
                        None,
                        flash_params_full,
                    )?
                    .contiguous()?;
                if let Some(t) = self.lm_head.quantized_act_type() {
                    res = res.to_dtype(t)?;
                }
                extract_logits(
                    &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                    context_lens,
                )
            } else {
                let mut res = self
                    .inner_forward(
                        input_ids,
                        seqlen_offsets,
                        Some(scalings),
                        true,
                        no_kv_cache,
                        None,
                        flash_params,
                    )?
                    .contiguous()?;
                if let Some(t) = self.lm_head.quantized_act_type() {
                    res = res.to_dtype(t)?;
                }
                extract_logits(
                    &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                    context_lens,
                )
            }
        } else {
            let mut res = self
                .inner_forward(
                    input_ids,
                    seqlen_offsets,
                    None,
                    false,
                    no_kv_cache,
                    None,
                    flash_params,
                )?
                .contiguous()?;
            if let Some(t) = self.lm_head.quantized_act_type() {
                res = res.to_dtype(t)?;
            }
            extract_logits(
                &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                context_lens,
            )
        }
    }
}

impl IsqModel for XLoraModel {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((Arc::get_mut(&mut self.lm_head).unwrap().quant_inner(), None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.q_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.k_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.v_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.o_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.block_sparse_moe.gate)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            for expert in &mut layer.block_sparse_moe.experts {
                tensors.push((Arc::get_mut(&mut expert.w1).unwrap().quant_inner(), Some(i)));
                tensors.push((Arc::get_mut(&mut expert.w2).unwrap().quant_inner(), Some(i)));
                tensors.push((Arc::get_mut(&mut expert.w3).unwrap().quant_inner(), Some(i)));
            }
        }
        (tensors, &*self.mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        panic!("Cannot generate UQFF for an adapter model.")
    }
}

impl NormalModel for XLoraModel {
    fn forward(
        &self,
        _input_ids: &Tensor,
        _seqlen_offsets: &[usize],
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        _flash_params: &FlashParams,
    ) -> Result<Tensor> {
        unreachable!()
    }
    fn xlora_forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<crate::xlora_models::NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        self.forward(
            input_ids,
            input_ids_full,
            seqlen_offsets,
            seqlen_offsets_full,
            no_kv_cache,
            non_granular_state,
            context_lens,
            flash_params,
            flash_params_full,
        )
    }
    fn cache(&self) -> &EitherCache {
        &self.cache
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        &mut self.cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn is_xlora(&self) -> bool {
        true
    }
    fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
}

impl ScalingsMaker for XLoraModel {
    fn dtype(&self) -> DType {
        self.dtype
    }
    fn get_cache(&self) -> &EitherCache {
        &self.cache
    }
    fn get_classifier(&self) -> &XLoraClassifier {
        self.xlora_classifier.as_ref().unwrap()
    }
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Tensor,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        _context_lens: &[usize],
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.inner_forward(
            input_ids,
            seqlen_offsets,
            Some(scalings),
            is_full_pass,
            no_kv_cache,
            is_scaling_pass,
            flash_params,
        )
    }
}

impl AnyMoeBaseModelMixin for XLoraModel {}