#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use crate::{
    amoe::AnyMoeBaseModelMixin,
    attention::SdpaParams,
    layers::{Llama3RotaryEmbedding, Sdpa},
    lora::{linear_no_bias as linear, LinearLayerLike, LoraConfig, Ordering},
    paged_attention::ModelConfigMetadata,
    pipeline::{
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel,
    },
    utils::progress::NiceProgressBar,
};
use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{Embedding, Module};
use mistralrs_quant::{QuantMethod, ShardedVarBuilder};
use std::{collections::HashMap, sync::Arc};
use tqdm::Iter;
use tracing::info;

use crate::{
    device_map::DeviceMapper,
    layers::{embedding, CausalMasker, RmsNorm},
    models::llama::Config,
    pipeline::{self, extract_logits, LayerCaches, NormalLoadingMetadata, NormalModel},
};

use super::{classifier::XLoraClassifier, NonGranularState, ScalingsMaker, XLoraConfig};

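/// Causal multi-head self-attention whose Q/K/V/O projections are `LinearLayerLike`
/// trait objects, so LoRA/X-LoRA adapter scalings can be applied on every forward pass.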
struct CausalSelfAttention {
    q_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    k_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    v_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    o_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    num_attention_heads: usize,
    num_key_value_heads: usize,
    head_dim: usize,
    rotary_emb: Arc<Llama3RotaryEmbedding>,
    max_seq_len: usize,
    sdpa_params: SdpaParams,
}

impl CausalSelfAttention {
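    /// Attention for one decoder layer: project Q/K/V through the (possibly LoRA-scaled)
    /// projections, apply RoPE, update the KV cache entry for `block_idx`, run SDPA, and
    /// apply the output projection.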
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        x: &Tensor,
        mask: &Option<Tensor>,
        seqlen_offsets: &[usize],
        block_idx: usize,
        kv_cache: &mut LayerCaches,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, seq_len, hidden_size) = x.dims3()?;

        let original_dtype = x.dtype();
        let mut x = x.clone();
        if let Some(t) = self.q_proj.quantized_act_type() {
            x = x.to_dtype(t)?;
        }
        let mut q = self.q_proj.lora_forward(
            &x,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let mut k = self.k_proj.lora_forward(
            &x,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let mut v = self.v_proj.lora_forward(
            &x,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.q_proj.quantized_act_type().is_some() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

        let (q, k, v) = if seq_len != 1 {
            let q = q
                .reshape((b_sz, seq_len, self.num_attention_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.num_attention_heads, seq_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.num_key_value_heads, seq_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.num_key_value_heads, seq_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, k) = self.rotary_emb.forward(&q, &k, seqlen_offsets)?;

        let (k, v) =
            crate::pipeline::Cache::update_kv_cache(&mut kv_cache[block_idx], k, v, false)?;

        let y = Sdpa.run_attention(
            &q,
            &k,
            &v,
            mask.clone().as_ref(),
            Some(flash_params),
            &self.sdpa_params,
        )?;

        // Merge the heads back into the hidden dimension once before the output projection.
        let mut y = y.transpose(1, 2)?.reshape((b_sz, seq_len, hidden_size))?;
        if let Some(t) = self.q_proj.quantized_act_type() {
            y = y.to_dtype(t)?;
        }
        let mut res = self.o_proj.lora_forward(
            &y,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.q_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }

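    /// Load the q/k/v/o projections for `layer_idx` from `vb`, attaching any LoRA
    /// adapters described by `lora_config` and `preload_adapters`.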
    #[allow(clippy::too_many_arguments)]
    fn load(
        vb: ShardedVarBuilder,
        cfg: &Config,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        rope: Arc<Llama3RotaryEmbedding>,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let size_in = cfg.hidden_size;
        let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
        let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
        let q_proj = linear(
            size_in,
            size_q,
            mapper.set_device(layer_idx, vb.pp("q_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("q_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let k_proj = linear(
            size_in,
            size_kv,
            mapper.set_device(layer_idx, vb.pp("k_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("k_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let v_proj = linear(
            size_in,
            size_kv,
            mapper.set_device(layer_idx, vb.pp("v_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("v_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let o_proj = linear(
            size_q,
            size_in,
            mapper.set_device(layer_idx, vb.pp("o_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("o_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            num_attention_heads: cfg.num_attention_heads,
            num_key_value_heads: cfg.num_key_value_heads,
            head_dim: cfg.hidden_size / cfg.num_attention_heads,
            rotary_emb: rope,
            max_seq_len: cfg.max_position_embeddings,
            sdpa_params: SdpaParams {
                n_kv_groups: cfg.num_attention_heads / cfg.num_key_value_heads,
                use_flash_attn: cfg.use_flash_attn,
                softcap: None,
                softmax_scale: 1.0 / ((cfg.hidden_size / cfg.num_attention_heads) as f32).sqrt(),
                sliding_window: None,
            },
        })
    }
}

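/// SwiGLU feed-forward block: `c_proj(silu(c_fc1(x)) * c_fc2(x))`. The fields map to the
/// checkpoint's `gate_proj`/`up_proj`/`down_proj` weights and are LoRA-aware.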
#[derive(Clone)]
struct Mlp {
    c_fc1: Arc<dyn LinearLayerLike + Send + Sync>,
    c_fc2: Arc<dyn LinearLayerLike + Send + Sync>,
    c_proj: Arc<dyn LinearLayerLike + Send + Sync>,
}

impl Mlp {
    fn forward(
        &self,
        x: &Tensor,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
    ) -> Result<Tensor> {
        let original_dtype = x.dtype();
        let mut x = x.clone();
        if let Some(t) = self.c_fc1.quantized_act_type() {
            x = x.to_dtype(t)?;
        }
        let x = (candle_nn::ops::silu(&self.c_fc1.lora_forward(
            &x,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?)? * self.c_fc2.lora_forward(
            &x,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?)?;
        let mut res = self.c_proj.lora_forward(
            &x,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.c_fc1.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }

    #[allow(clippy::too_many_arguments)]
    fn load(
        vb: ShardedVarBuilder,
        cfg: &Config,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let h_size = cfg.hidden_size;
        let i_size = cfg.intermediate_size;
        let c_fc1 = linear(
            h_size,
            i_size,
            mapper.set_device(layer_idx, vb.pp("gate_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("gate_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let c_fc2 = linear(
            h_size,
            i_size,
            mapper.set_device(layer_idx, vb.pp("up_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("up_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let c_proj = linear(
            i_size,
            h_size,
            mapper.set_device(layer_idx, vb.pp("down_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("down_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        Ok(Self {
            c_fc1,
            c_fc2,
            c_proj,
        })
    }
}

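/// One decoder layer: pre-norm attention and pre-norm MLP, each with a residual connection.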
struct Block {
    rms_1: RmsNorm,
    attn: CausalSelfAttention,
    rms_2: RmsNorm,
    mlp: Mlp,
}

impl Block {
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        x: &Tensor,
        mask: &Option<Tensor>,
        seqlen_offsets: &[usize],
        block_idx: usize,
        kv_cache: &mut LayerCaches,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = x;
        let x = self.rms_1.forward(x)?;
        let x = (self.attn.forward(
            &x,
            mask,
            seqlen_offsets,
            block_idx,
            kv_cache,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
            flash_params,
        )? + residual)?;
        let residual = &x;
        let x = (self.mlp.forward(
            &self.rms_2.forward(&x)?,
            scalings,
            global_scaling_weight,
            is_scaling_pass,
        )? + residual)?;
        Ok(x)
    }

    #[allow(clippy::too_many_arguments)]
    fn load(
        vb: ShardedVarBuilder,
        cfg: &Config,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        rope: Arc<Llama3RotaryEmbedding>,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let attn = CausalSelfAttention::load(
            vb.pp("self_attn"),
            cfg,
            lora_config,
            count,
            ord,
            mapper,
            layer_idx,
            loading_isq,
            rope,
            preload_adapters,
        )?;
        let mlp = Mlp::load(
            vb.pp("mlp"),
            cfg,
            lora_config,
            count,
            ord,
            mapper,
            layer_idx,
            loading_isq,
            preload_adapters,
        )?;
        let rms_1 = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let rms_2 = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        Ok(Self {
            rms_1,
            attn,
            rms_2,
            mlp,
        })
    }
}

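/// Llama model with LoRA/X-LoRA adapters. When an `XLoraClassifier` is present,
/// generation runs the two-pass X-LoRA scheme: a scaling pass that predicts per-adapter
/// weights, followed by the real forward pass that applies them; without a classifier it
/// behaves as a plain (optionally weight-merged) LoRA model.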
pub struct XLoraLlama {
    wte: Embedding,
    blocks: Vec<Block>,
    ln_f: RmsNorm,
    lm_head: Arc<dyn LinearLayerLike + Send + Sync>,
    kv_cache: pipeline::EitherCache,
    device: Device,
    xlora_classifier: Option<XLoraClassifier>,
    dtype: DType,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    cfg: ModelConfigMetadata,
}

impl XLoraLlama {
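    /// Backbone pass shared by the scaling and full passes: embed the tokens, run every
    /// block with the given adapter `scalings`, and apply the final RMSNorm. When
    /// `is_full_pass` is set, the dedicated X-LoRA KV cache is used (and reset if
    /// `no_kv_cache`); otherwise the regular KV cache is used.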
    #[allow(clippy::too_many_arguments)]
    fn inner_forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Option<Tensor>,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut x = self.wte.forward(input_ids)?;
        let mut cache = if is_full_pass {
            if no_kv_cache {
                let mut new_cache = Vec::new();
                for _ in 0..self.kv_cache.full().xlora_lock().len() {
                    new_cache.push(None);
                }

                self.kv_cache.full().xlora_lock().clone_from(&new_cache);
            }
            self.kv_cache.full().xlora_lock()
        } else {
            self.kv_cache.full().lock()
        };
        let mask = CausalMasker.make_causal_mask_matrix(
            input_ids,
            &*cache,
            x.dtype(),
            self.cfg.num_attn_heads,
        )?;
        for (block_idx, block) in self.blocks.iter().enumerate() {
            x = self.mapper.map(x, block_idx)?;
            x = block.forward(
                &x,
                &mask.clone().map(|m| m.to_device(x.device()).unwrap()),
                seqlen_offsets,
                block_idx,
                &mut cache,
                scalings.clone(),
                self.xlora_classifier
                    .as_ref()
                    .map(|classifier| classifier.get_global_scaling_weight())
                    .unwrap_or(1.0),
                is_scaling_pass,
                flash_params,
            )?;
        }
        let x = x.to_device(&self.device)?;
        self.ln_f.forward(&x)
    }

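    /// Full generation step. With an X-LoRA classifier, first compute adapter scalings via
    /// `get_scalings`, then run the scaled forward pass; otherwise fall back to a plain
    /// LoRA forward. Logits for the requested positions are extracted with `extract_logits`.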
    #[allow(clippy::too_many_arguments)]
    pub fn forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        if self.xlora_classifier.is_some() {
            let scalings = self.get_scalings(
                input_ids,
                input_ids_full,
                seqlen_offsets,
                seqlen_offsets_full,
                no_kv_cache,
                non_granular_state,
                &vec![usize::MAX; context_lens.len()],
                flash_params,
                flash_params_full,
            )?;

            if no_kv_cache {
                let mut res = self
                    .inner_forward(
                        input_ids_full,
                        seqlen_offsets_full,
                        Some(scalings),
                        true,
                        no_kv_cache,
                        None,
                        flash_params_full,
                    )?
                    .contiguous()?;
                if let Some(t) = self.lm_head.quantized_act_type() {
                    res = res.to_dtype(t)?;
                }
                extract_logits(
                    &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                    context_lens,
                )
            } else {
                let mut res = self
                    .inner_forward(
                        input_ids,
                        seqlen_offsets,
                        Some(scalings),
                        true,
                        no_kv_cache,
                        None,
                        flash_params,
                    )?
                    .contiguous()?;
                if let Some(t) = self.lm_head.quantized_act_type() {
                    res = res.to_dtype(t)?;
                }
                extract_logits(
                    &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                    context_lens,
                )
            }
        } else {
            let mut res = self
                .inner_forward(
                    input_ids,
                    seqlen_offsets,
                    None,
                    false,
                    no_kv_cache,
                    None,
                    flash_params,
                )?
                .contiguous()?;
            if let Some(t) = self.lm_head.quantized_act_type() {
                res = res.to_dtype(t)?;
            }
            extract_logits(
                &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                context_lens,
            )
        }
    }

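    /// Construct the model from sharded weights, attaching LoRA adapters according to
    /// `lora_config` and `xlora_ordering`. When there is no X-LoRA classifier and no
    /// preloaded adapters, the LoRA weights are merged into the base layers.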
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        if let Some(ref quant_cfg) = &cfg.quantization_config {
            tracing::info!(
                "Using {} quantization: {}.",
                quant_cfg.quant_method.to_string(),
                quant_cfg.get_bits_name(&vb)
            );
        }
        let mapper = normal_loading_metadata.mapper;
        let dtype = vb.dtype();
        let mut count = 0;

        let wte = embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb.pp("model.embed_tokens"), false),
        )?;
        let lm_head = linear(
            cfg.hidden_size,
            cfg.vocab_size,
            mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
            mapper.set_nm_device(vb.pp("lm_head"), false),
            lora_config,
            &mut count,
            &xlora_ordering,
            preload_adapters,
        )?;
        if xlora_config.is_some() && lm_head.is_lora() {
            candle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA.");
        }
        let ln_f = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb.pp("model.norm"), false),
        )?;
        let mut ropes = HashMap::new();
        for i in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(i, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(Llama3RotaryEmbedding::new_llama3(
                    vb.dtype(),
                    cfg,
                    device,
                    is_gptx,
                )?),
            );
        }
        let mut blocks: Vec<_> = NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        )
        .into_iter()
        .map(|i| {
            let device = mapper
                .device_for(i, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            Block::load(
                vb.pp(format!("model.layers.{i}")),
                cfg,
                lora_config,
                &mut count,
                &xlora_ordering,
                &*mapper,
                i,
                normal_loading_metadata.loading_isq,
                rotary_emb,
                preload_adapters,
            )
            .expect("Failed to load block.")
        })
        .collect();
        if xlora_config.is_none() && preload_adapters.is_none() {
            info!("Merging LoRA adapters.");
            for layer in blocks.iter_mut().tqdm() {
                Arc::get_mut(&mut layer.attn.k_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.attn.o_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.attn.q_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.attn.v_proj)
                    .unwrap()
                    .merge_weights()?;

                Arc::get_mut(&mut layer.mlp.c_fc1)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.mlp.c_fc2)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.mlp.c_proj)
                    .unwrap()
                    .merge_weights()?;
            }
        }

        Ok(Self {
            wte,
            blocks,
            ln_f,
            lm_head,
            kv_cache: EitherCache::Full(pipeline::Cache::new(cfg.num_hidden_layers, true)),
            device: normal_loading_metadata.real_device,
            xlora_classifier: xlora_config.map(|xlora_config| {
                XLoraClassifier::new(xlora_config, count, lora_config.len(), vb, false).unwrap()
            }),
            dtype,
            mapper,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_kv_heads: cfg.num_key_value_heads,
                num_attn_heads: cfg.num_attention_heads,
                sliding_window: None,
                k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
                v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            },
        })
    }
}

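// ISQ (in-situ quantization) support: exposes the inner quantizable weight of `lm_head`
// and of every per-layer projection, tagged with its layer index for device mapping.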
impl IsqModel for XLoraLlama {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((Arc::get_mut(&mut self.lm_head).unwrap().quant_inner(), None));
        for (i, layer) in self.blocks.iter_mut().enumerate() {
            tensors.push((
                Arc::get_mut(&mut layer.attn.q_proj).unwrap().quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.attn.k_proj).unwrap().quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.attn.v_proj).unwrap().quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.attn.o_proj).unwrap().quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.mlp.c_fc1).unwrap().quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.mlp.c_fc2).unwrap().quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.mlp.c_proj).unwrap().quant_inner(),
                Some(i),
            ));
        }
        (tensors, &*self.mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        panic!("Cannot generate UQFF for an adapter model.")
    }
}

impl NormalModel for XLoraLlama {
    fn forward(
        &self,
        _input_ids: &Tensor,
        _seqlen_offsets: &[usize],
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        _flash_params: &FlashParams,
    ) -> Result<Tensor> {
        unreachable!()
    }
    fn xlora_forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<crate::xlora_models::NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        self.forward(
            input_ids,
            input_ids_full,
            seqlen_offsets,
            seqlen_offsets_full,
            no_kv_cache,
            non_granular_state,
            context_lens,
            flash_params,
            flash_params_full,
        )
    }
    fn cache(&self) -> &super::EitherCache {
        &self.kv_cache
    }
    fn cache_mut(&mut self) -> &mut super::EitherCache {
        &mut self.kv_cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn is_xlora(&self) -> bool {
        true
    }
    fn max_seq_len(&self) -> usize {
        self.blocks[0].attn.max_seq_len
    }
    fn activate_adapters(&mut self, adapter_names: Vec<String>) -> Result<usize> {
        if self.xlora_classifier.is_some() {
            candle_core::bail!("Adapter activation is not supported for X-LoRA models as the adapter set must remain the same.");
        }
        let mut sum = 0;
        for layer in self.blocks.iter_mut() {
            sum += Arc::get_mut(&mut layer.attn.k_proj)
                .unwrap()
                .activate(&adapter_names)?;
            sum += Arc::get_mut(&mut layer.attn.o_proj)
                .unwrap()
                .activate(&adapter_names)?;
            sum += Arc::get_mut(&mut layer.attn.q_proj)
                .unwrap()
                .activate(&adapter_names)?;
            sum += Arc::get_mut(&mut layer.attn.v_proj)
                .unwrap()
                .activate(&adapter_names)?;

            sum += Arc::get_mut(&mut layer.mlp.c_fc1)
                .unwrap()
                .activate(&adapter_names)?;
            sum += Arc::get_mut(&mut layer.mlp.c_fc2)
                .unwrap()
                .activate(&adapter_names)?;
            sum += Arc::get_mut(&mut layer.mlp.c_proj)
                .unwrap()
                .activate(&adapter_names)?;
        }
        Ok(sum)
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
}

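// X-LoRA scaling-pass hooks: `forward` here simply delegates to `inner_forward`, producing
// the hidden states the classifier uses when predicting per-adapter scalings.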
impl ScalingsMaker for XLoraLlama {
    fn dtype(&self) -> DType {
        self.dtype
    }
    fn get_cache(&self) -> &pipeline::EitherCache {
        &self.kv_cache
    }
    fn get_classifier(&self) -> &XLoraClassifier {
        self.xlora_classifier.as_ref().unwrap()
    }
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Tensor,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        _context_lens: &[usize],
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.inner_forward(
            input_ids,
            seqlen_offsets,
            Some(scalings),
            is_full_pass,
            no_kv_cache,
            is_scaling_pass,
            flash_params,
        )
    }
}

impl AnyMoeBaseModelMixin for XLoraLlama {}