#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use crate::{
    amoe::AnyMoeBaseModelMixin,
    attention::SdpaParams,
    layers::{self, RotaryEmbedding, Sdpa},
    lora::{linear_no_bias, LinearLayerLike, LoraConfig, Ordering},
    paged_attention::ModelConfigMetadata,
    pipeline::{
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel, NormalLoadingMetadata,
    },
    utils::progress::NiceProgressBar,
};
use candle_core::{DType, Device, Module, Result, Tensor};
use mistralrs_quant::{QuantMethod, ShardedVarBuilder};
use std::{collections::HashMap, sync::Arc};
use tqdm::Iter;
use tracing::info;

use crate::{
    device_map::DeviceMapper,
    layers::{Activation, CausalMasker, RmsNorm},
    models::mistral::Config,
    pipeline::{extract_logits, Cache, NormalModel},
};

use super::{classifier::XLoraClassifier, config::XLoraConfig, NonGranularState, ScalingsMaker};

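/// Gated feed-forward block (`down_proj(act_fn(gate_proj(x)) * up_proj(x))`) built from
/// LoRA-aware linear layers so per-adapter X-LoRA scalings can be applied in the forward pass.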
#[derive(Clone)]
#[allow(clippy::upper_case_acronyms)]
struct MLP {
    gate_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    up_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    down_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    act_fn: Activation,
}

impl MLP {
    #[allow(clippy::too_many_arguments)]
    fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let intermediate_sz = cfg.intermediate_size;
        let gate_proj = linear_no_bias(
            hidden_sz,
            intermediate_sz,
            mapper.set_device(layer_idx, vb.pp("gate_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("gate_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let up_proj = linear_no_bias(
            hidden_sz,
            intermediate_sz,
            mapper.set_device(layer_idx, vb.pp("up_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("up_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let down_proj = linear_no_bias(
            intermediate_sz,
            hidden_sz,
            mapper.set_device(layer_idx, vb.pp("down_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("down_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        Ok(Self {
            gate_proj,
            up_proj,
            down_proj,
            act_fn: cfg.hidden_act,
        })
    }

    fn forward(
        &self,
        xs: &Tensor,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
    ) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.gate_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let lhs = self
            .gate_proj
            .lora_forward(
                &xs,
                scalings.clone(),
                global_scaling_weight,
                is_scaling_pass,
            )?
            .apply(&self.act_fn)?;
        let rhs = self.up_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let mut res = self.down_proj.lora_forward(
            &(lhs * rhs)?,
            scalings,
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.gate_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

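/// Multi-head self-attention with grouped-query KV heads, rotary embeddings, an optional
/// sliding window, and LoRA-aware q/k/v/o projections.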
struct Attention {
    q_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    k_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    v_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    o_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    rotary_emb: Arc<RotaryEmbedding>,
    sliding_window: Option<usize>,
    sdpa_params: SdpaParams,
}

impl Attention {
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<RotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        let head_dim = cfg.head_dim();
        let q_proj = linear_no_bias(
            hidden_sz,
            num_heads * head_dim,
            mapper.set_device(layer_idx, vb.pp("q_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("q_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let k_proj = linear_no_bias(
            hidden_sz,
            num_kv_heads * head_dim,
            mapper.set_device(layer_idx, vb.pp("k_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("k_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let v_proj = linear_no_bias(
            hidden_sz,
            num_kv_heads * head_dim,
            mapper.set_device(layer_idx, vb.pp("v_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("v_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let o_proj = linear_no_bias(
            num_heads * head_dim,
            hidden_sz,
            mapper.set_device(layer_idx, vb.pp("o_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("o_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            num_heads,
            num_kv_heads,
            head_dim,
            rotary_emb,
            sliding_window: cfg.sliding_window,
            sdpa_params: SdpaParams {
                n_kv_groups: num_heads / num_kv_heads,
                use_flash_attn: cfg.use_flash_attn,
                softcap: None,
                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                sliding_window: cfg.sliding_window,
            },
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut Option<(Tensor, Tensor)>,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.q_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut q = self.q_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let mut k = self.k_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let mut v = self.v_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.q_proj.quantized_act_type().is_some() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

        let (q, k, v) = if q_len != 1 {
            let q = q
                .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, k) = self.rotary_emb.forward(&q, &k, seqlen_offsets)?;

        let (k, v, attn_mask) = Cache::update_kv_cache_sliding_window(
            kv_cache,
            k,
            v,
            attention_mask,
            self.sliding_window,
            false,
        )?;

        let mut attn_output = Sdpa.run_attention(
            &q,
            &k,
            &v,
            attn_mask.as_ref(),
            Some(flash_params),
            &self.sdpa_params,
        )?;

        if let Some(t) = self.q_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        let mut res = self.o_proj.lora_forward(
            &attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.q_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

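/// One transformer decoder block: RMSNorm -> self-attention -> residual add, then
/// RMSNorm -> MLP -> residual add.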
struct DecoderLayer {
    self_attn: Attention,
    mlp: MLP,
    input_layernorm: RmsNorm,
    post_attention_layernorm: RmsNorm,
}

impl DecoderLayer {
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<RotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let self_attn = Attention::new(
            rotary_emb,
            cfg,
            vb.pp("self_attn"),
            lora_config,
            count,
            ord,
            mapper,
            layer_idx,
            loading_isq,
            preload_adapters,
        )?;
        let mlp = MLP::new(
            cfg,
            vb.pp("mlp"),
            lora_config,
            count,
            ord,
            mapper,
            layer_idx,
            loading_isq,
            preload_adapters,
        )?;
        let input_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let post_attention_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        Ok(Self {
            self_attn,
            mlp,
            input_layernorm,
            post_attention_layernorm,
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut Option<(Tensor, Tensor)>,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = self.input_layernorm.forward(xs)?;
        let xs = self.self_attn.forward(
            &xs,
            attention_mask,
            seqlen_offsets,
            kv_cache,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
            flash_params,
        )?;
        let xs = (xs + residual)?;
        let residual = &xs;
        let xs = self.mlp.forward(
            &xs.apply(&self.post_attention_layernorm)?,
            scalings,
            global_scaling_weight,
            is_scaling_pass,
        )?;
        residual + xs
    }
}

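/// Mistral decoder stack with X-LoRA support: token embeddings, repeated decoder layers,
/// a final RMSNorm, a LoRA-aware `lm_head`, and an optional [`XLoraClassifier`] that
/// produces the per-token adapter scalings.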
pub struct XLoraModel {
    embed_tokens: candle_nn::Embedding,
    layers: Vec<DecoderLayer>,
    norm: RmsNorm,
    lm_head: Arc<dyn LinearLayerLike + Send + Sync>,
    sliding_window: Option<usize>,
    dtype: DType,
    device: Device,
    cache: EitherCache,
    max_seq_len: usize,
    xlora_classifier: Option<XLoraClassifier>,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    cfg: ModelConfigMetadata,
}

impl XLoraModel {
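    /// Builds the model from a `ShardedVarBuilder`, wiring LoRA adapters into every linear
    /// projection. When neither an X-LoRA classifier nor preloaded adapters are configured,
    /// the LoRA weights are merged into the base weights immediately after loading.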
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        if let Some(ref quant_cfg) = &cfg.quantization_config {
            tracing::info!(
                "Using {} quantization: {}.",
                quant_cfg.quant_method.to_string(),
                quant_cfg.get_bits_name(&vb)
            );
        }
        let mapper = normal_loading_metadata.mapper;

        let vb_m = vb.pp("model");
        let embed_tokens = layers::embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
        )?;
        let head_dim = cfg.head_dim();
        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        let vb_l = vb_m.pp("layers");
        let mut ropes = HashMap::new();
        for layer_idx in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(RotaryEmbedding::new(
                    cfg.rope_theta as f32,
                    head_dim,
                    cfg.max_position_embeddings,
                    device,
                    is_gptx,
                    vb_m.dtype(),
                )?),
            );
        }
        let mut count = 0;
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        ) {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            let layer = DecoderLayer::new(
                rotary_emb.clone(),
                cfg,
                vb_l.pp(layer_idx),
                lora_config,
                &mut count,
                &xlora_ordering,
                &*mapper,
                layer_idx,
                normal_loading_metadata.loading_isq,
                preload_adapters,
            )?;
            layers.push(layer)
        }
        if xlora_config.is_none() && preload_adapters.is_none() {
            info!("Merging LoRA adapters.");
            for layer in layers.iter_mut().tqdm() {
                Arc::get_mut(&mut layer.self_attn.k_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.self_attn.o_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.self_attn.q_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.self_attn.v_proj)
                    .unwrap()
                    .merge_weights()?;

                Arc::get_mut(&mut layer.mlp.down_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.mlp.gate_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.mlp.up_proj)
                    .unwrap()
                    .merge_weights()?;
            }
        }
        let norm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb_m.pp("norm"), false),
        )?;
        let lm_head = linear_no_bias(
            cfg.hidden_size,
            cfg.vocab_size,
            mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
            mapper.set_nm_device(vb.pp("lm_head"), false),
            lora_config,
            &mut count,
            &xlora_ordering,
            preload_adapters,
        )?;
        if xlora_config.is_some() && lm_head.is_lora() {
            candle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA.");
        }
        Ok(Self {
            embed_tokens,
            layers,
            norm,
            lm_head,
            sliding_window: cfg.sliding_window,
            device: normal_loading_metadata.real_device,
            dtype: vb.dtype(),
            cache: EitherCache::Full(Cache::new(cfg.num_hidden_layers, true)),
            max_seq_len: cfg.max_position_embeddings,
            xlora_classifier: xlora_config.map(|xlora_config| {
                XLoraClassifier::new(xlora_config, count, lora_config.len(), vb, false).unwrap()
            }),
            mapper,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_kv_heads: cfg.num_key_value_heads,
                num_attn_heads: cfg.num_attention_heads,
                sliding_window: cfg.sliding_window,
                k_head_dim: cfg.head_dim(),
                v_head_dim: cfg.head_dim(),
            },
        })
    }

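    /// Embeds the input ids, builds the (sliding-window) causal mask, and runs every decoder
    /// layer followed by the final RMSNorm. Full passes use the dedicated X-LoRA KV cache
    /// (cleared first when `no_kv_cache` is set); otherwise the regular cache is used.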
    #[allow(clippy::too_many_arguments)]
    fn inner_forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Option<Tensor>,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut cache = if is_full_pass {
            if no_kv_cache {
                let mut new_cache = Vec::new();
                for _ in 0..self.cache.full().xlora_lock().len() {
                    new_cache.push(None);
                }

                self.cache.full().xlora_lock().clone_from(&new_cache);
            }
            self.cache.full().xlora_lock()
        } else {
            self.cache.full().lock()
        };
        let mut xs = self.embed_tokens.forward(input_ids)?;
        let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
            input_ids,
            &*cache,
            self.sliding_window,
            xs.dtype(),
            self.cfg.num_attn_heads,
        )?;
        for (i, layer) in self.layers.iter().enumerate() {
            xs = self.mapper.map(xs, i)?;
            xs = layer.forward(
                &xs,
                attention_mask
                    .as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                &mut cache[i],
                scalings.clone(),
                self.xlora_classifier
                    .as_ref()
                    .map(|classifier| classifier.get_global_scaling_weight())
                    .unwrap_or(1.0),
                is_scaling_pass,
                flash_params,
            )?
        }
        let xs = xs.to_device(&self.device)?;
        xs.apply(&self.norm)
    }

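    /// Top-level X-LoRA forward pass. With a classifier configured, adapter scalings are
    /// computed first via `get_scalings` and then applied during a second, scaled pass
    /// (over the full prompt when the KV cache is disabled) before the `lm_head` projection.
    /// Without a classifier, a single unscaled pass is run.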
    #[allow(clippy::too_many_arguments)]
    pub fn forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        if self.xlora_classifier.is_some() {
            let scalings = self.get_scalings(
                input_ids,
                input_ids_full,
                seqlen_offsets,
                seqlen_offsets_full,
                no_kv_cache,
                non_granular_state,
                &vec![usize::MAX; context_lens.len()],
                flash_params,
                flash_params_full,
            )?;

            if no_kv_cache {
                let mut res = self
                    .inner_forward(
                        input_ids_full,
                        seqlen_offsets_full,
                        Some(scalings),
                        true,
                        no_kv_cache,
                        None,
                        flash_params_full,
                    )?
                    .contiguous()?;
                if let Some(t) = self.lm_head.quantized_act_type() {
                    res = res.to_dtype(t)?;
                }
                extract_logits(
                    &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                    context_lens,
                )
            } else {
                let mut res = self
                    .inner_forward(
                        input_ids,
                        seqlen_offsets,
                        Some(scalings),
                        true,
                        no_kv_cache,
                        None,
                        flash_params,
                    )?
                    .contiguous()?;
                if let Some(t) = self.lm_head.quantized_act_type() {
                    res = res.to_dtype(t)?;
                }
                extract_logits(
                    &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                    context_lens,
                )
            }
        } else {
            let mut res = self
                .inner_forward(
                    input_ids,
                    seqlen_offsets,
                    None,
                    false,
                    no_kv_cache,
                    None,
                    flash_params,
                )?
                .contiguous()?;
            if let Some(t) = self.lm_head.quantized_act_type() {
                res = res.to_dtype(t)?;
            }
            extract_logits(
                &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                context_lens,
            )
        }
    }
}

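// ISQ support: expose `lm_head` and each layer's attention/MLP projections (with their
// layer indices) as quantizable tensors. UQFF generation is unsupported for adapter models.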
impl IsqModel for XLoraModel {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((Arc::get_mut(&mut self.lm_head).unwrap().quant_inner(), None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.q_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.k_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.v_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.o_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.mlp.gate_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.mlp.up_proj).unwrap().quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.mlp.down_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
        }
        (tensors, &*self.mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        panic!("Cannot generate UQFF for an adapter model.")
    }
}

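// NormalModel wiring: the plain `forward` entry point is unreachable for X-LoRA; inference
// goes through `xlora_forward`, which delegates to `XLoraModel::forward` above.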
impl NormalModel for XLoraModel {
    fn forward(
        &self,
        _input_ids: &Tensor,
        _seqlen_offsets: &[usize],
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        _flash_params: &FlashParams,
    ) -> Result<Tensor> {
        unreachable!()
    }
    fn xlora_forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<crate::xlora_models::NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        self.forward(
            input_ids,
            input_ids_full,
            seqlen_offsets,
            seqlen_offsets_full,
            no_kv_cache,
            non_granular_state,
            context_lens,
            flash_params,
            flash_params_full,
        )
    }
    fn cache(&self) -> &EitherCache {
        &self.cache
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        &mut self.cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn is_xlora(&self) -> bool {
        true
    }
    fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }
    fn activate_adapters(&mut self, adapter_names: Vec<String>) -> Result<usize> {
        if self.xlora_classifier.is_some() {
            candle_core::bail!("Adapter activation is not supported for X-LoRA models as the adapter set must remain the same.");
        }
        let mut sum = 0;
        for layer in self.layers.iter_mut() {
            sum += Arc::get_mut(&mut layer.self_attn.k_proj)
                .unwrap()
                .activate(&adapter_names)?;
            sum += Arc::get_mut(&mut layer.self_attn.o_proj)
                .unwrap()
                .activate(&adapter_names)?;
            sum += Arc::get_mut(&mut layer.self_attn.q_proj)
                .unwrap()
                .activate(&adapter_names)?;
            sum += Arc::get_mut(&mut layer.self_attn.v_proj)
                .unwrap()
                .activate(&adapter_names)?;

            sum += Arc::get_mut(&mut layer.mlp.down_proj)
                .unwrap()
                .activate(&adapter_names)?;
            sum += Arc::get_mut(&mut layer.mlp.gate_proj)
                .unwrap()
                .activate(&adapter_names)?;
            sum += Arc::get_mut(&mut layer.mlp.up_proj)
                .unwrap()
                .activate(&adapter_names)?;
        }
        Ok(sum)
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
}

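// ScalingsMaker lets the X-LoRA classifier drive scaling passes through `inner_forward`.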
impl ScalingsMaker for XLoraModel {
    fn dtype(&self) -> DType {
        self.dtype
    }
    fn get_cache(&self) -> &EitherCache {
        &self.cache
    }
    fn get_classifier(&self) -> &XLoraClassifier {
        self.xlora_classifier.as_ref().unwrap()
    }
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Tensor,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        _context_lens: &[usize],
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.inner_forward(
            input_ids,
            seqlen_offsets,
            Some(scalings),
            is_full_pass,
            no_kv_cache,
            is_scaling_pass,
            flash_params,
        )
    }
}

impl AnyMoeBaseModelMixin for XLoraModel {}