#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::{collections::HashMap, sync::Arc};

use crate::{
    amoe::AnyMoeBaseModelMixin,
    attention::SdpaParams,
    layers::{self, Activation, RmsNorm, RotaryEmbedding, Sdpa},
    lora::{linear_b as linear, LinearLayerLike, LoraConfig, Ordering},
    paged_attention::ModelConfigMetadata,
    pipeline::{
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel, NormalLoadingMetadata,
    },
    utils::progress::NiceProgressBar,
};
use candle_core::{DType, Device, Module, Result, Tensor};
use mistralrs_quant::{QuantMethod, ShardedVarBuilder};
use tqdm::Iter;
use tracing::info;

use crate::{
    device_map::DeviceMapper,
    layers::CausalMasker,
    models::gemma::Config,
    pipeline::{extract_logits, Cache, NormalModel},
};

use super::{classifier::XLoraClassifier, NonGranularState, ScalingsMaker, XLoraConfig};

fn default_max_position_embeddings() -> usize {
    4096
}

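/// Gated feed-forward block with LoRA-aware projections, combined as
/// `down_proj(act_fn(gate_proj(x)) * up_proj(x))`.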
#[derive(Clone)]
#[allow(clippy::upper_case_acronyms)]
struct MLP {
    gate_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    up_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    down_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    act_fn: Activation,
}

impl MLP {
    #[allow(clippy::too_many_arguments)]
    fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let intermediate_sz = cfg.intermediate_size;
        let gate_proj = linear(
            hidden_sz,
            intermediate_sz,
            false,
            mapper.set_device(layer_idx, vb.pp("gate_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("gate_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let up_proj = linear(
            hidden_sz,
            intermediate_sz,
            false,
            mapper.set_device(layer_idx, vb.pp("up_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("up_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let down_proj = linear(
            intermediate_sz,
            hidden_sz,
            false,
            mapper.set_device(layer_idx, vb.pp("down_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("down_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        Ok(Self {
            gate_proj,
            up_proj,
            down_proj,
            act_fn: cfg.hidden_act()?,
        })
    }

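    /// Applies the gated MLP; inputs are cast to the quantized activation dtype when the
    /// projections are quantized and cast back afterwards, with the optional X-LoRA scalings
    /// threaded through every projection.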
    fn forward(
        &self,
        xs: &Tensor,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
    ) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.gate_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let lhs = self
            .gate_proj
            .lora_forward(
                &xs,
                scalings.clone(),
                global_scaling_weight,
                is_scaling_pass,
            )?
            .apply(&self.act_fn)?;
        let rhs = self.up_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let mut res = self.down_proj.lora_forward(
            &(lhs * rhs)?,
            scalings,
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.gate_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

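/// Multi-head self-attention with LoRA-aware projections, rotary position embeddings and
/// grouped-query KV heads.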
struct Attention {
    q_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    k_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    v_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    o_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    rotary_emb: Arc<RotaryEmbedding>,
    sdpa_params: SdpaParams,
}

impl Attention {
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<RotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        let head_dim = cfg.head_dim;
        let bias = cfg.attention_bias;
        let q_proj = linear(
            hidden_sz,
            num_heads * head_dim,
            bias,
            mapper.set_device(layer_idx, vb.pp("q_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("q_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let k_proj = linear(
            hidden_sz,
            num_kv_heads * head_dim,
            bias,
            mapper.set_device(layer_idx, vb.pp("k_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("k_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let v_proj = linear(
            hidden_sz,
            num_kv_heads * head_dim,
            bias,
            mapper.set_device(layer_idx, vb.pp("v_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("v_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let o_proj = linear(
            num_heads * head_dim,
            hidden_sz,
            bias,
            mapper.set_device(layer_idx, vb.pp("o_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("o_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            num_heads,
            num_kv_heads,
            head_dim,
            rotary_emb,
            sdpa_params: SdpaParams {
                n_kv_groups: num_heads / num_kv_heads,
                use_flash_attn: cfg.use_flash_attn,
                softcap: None,
                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                sliding_window: None,
            },
        })
    }

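    /// Projects Q/K/V, applies RoPE, updates this layer's KV cache, runs scaled dot-product
    /// attention, and maps the result back through `o_proj`.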
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut Option<(Tensor, Tensor)>,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.q_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut q = self.q_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let mut k = self.k_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let mut v = self.v_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.q_proj.quantized_act_type().is_some() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

        let (q, k, v) = if q_len != 1 {
            let q = q
                .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, k) = self.rotary_emb.forward(&q, &k, seqlen_offsets)?;

        let (k, v) = Cache::update_kv_cache(kv_cache, k, v, false)?;

        let mut attn_output = Sdpa.run_attention(
            &q,
            &k,
            &v,
            attention_mask,
            Some(flash_params),
            &self.sdpa_params,
        )?;

        if let Some(t) = self.q_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        let mut res = self.o_proj.lora_forward(
            &attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.q_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

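/// One Gemma decoder layer: pre-norm self-attention and a pre-norm gated MLP, each with a
/// residual connection.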
struct DecoderLayer {
    self_attn: Attention,
    mlp: MLP,
    input_layernorm: RmsNorm,
    post_attention_layernorm: RmsNorm,
}

impl DecoderLayer {
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<RotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let self_attn = Attention::new(
            rotary_emb,
            cfg,
            vb.pp("self_attn"),
            lora_config,
            count,
            ord,
            mapper,
            layer_idx,
            loading_isq,
            preload_adapters,
        )?;
        let mlp = MLP::new(
            cfg,
            vb.pp("mlp"),
            lora_config,
            count,
            ord,
            mapper,
            layer_idx,
            loading_isq,
            preload_adapters,
        )?;
        let input_layernorm = RmsNorm::new_gemma(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let post_attention_layernorm = RmsNorm::new_gemma(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        Ok(Self {
            self_attn,
            mlp,
            input_layernorm,
            post_attention_layernorm,
        })
    }

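    /// Runs one decoder block, threading the optional X-LoRA scalings and this layer's KV cache
    /// through both sub-layers.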
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut Option<(Tensor, Tensor)>,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = self.input_layernorm.forward(xs)?;
        let xs = self.self_attn.forward(
            &xs,
            attention_mask,
            seqlen_offsets,
            kv_cache,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
            flash_params,
        )?;
        let xs = (xs + residual)?;
        let residual = &xs;
        let xs = self.mlp.forward(
            &xs.apply(&self.post_attention_layernorm)?,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        residual + xs
    }
}

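/// Gemma decoder-only model with X-LoRA support: the token embedding, the stack of decoder
/// layers, the final norm and LM head, plus an optional [`XLoraClassifier`] that produces
/// per-token adapter scalings.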
pub struct XLoraModel {
    embed_tokens: candle_nn::Embedding,
    layers: Vec<DecoderLayer>,
    norm: RmsNorm,
    lm_head: Arc<dyn LinearLayerLike + Send + Sync>,
    dtype: DType,
    hidden_size: usize,
    device: Device,
    cache: EitherCache,
    max_seq_len: usize,
    xlora_classifier: Option<XLoraClassifier>,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    cfg: ModelConfigMetadata,
}

impl XLoraModel {
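    /// Builds the model from a [`ShardedVarBuilder`], wiring LoRA adapters into every linear
    /// layer and optionally attaching an X-LoRA classifier.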
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        if let Some(ref quant_cfg) = &cfg.quantization_config {
            tracing::info!(
                "Using {} quantization: {}.",
                quant_cfg.name(),
                quant_cfg.get_bits_name(&vb)
            );
        }
        let mapper = normal_loading_metadata.mapper;
        let vb_m = vb.pp("model");

        let embed_tokens = layers::embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
            &cfg.quantization_config,
        )?;
        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        let vb_l = vb_m.pp("layers");
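        // Pre-compute one RotaryEmbedding per target device so that all layers mapped to the
        // same device can share it.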
        let mut ropes = HashMap::new();
        for layer_idx in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(RotaryEmbedding::new(
                    cfg.rope_theta as f32,
                    cfg.head_dim,
                    cfg.max_position_embeddings,
                    device,
                    is_gptx,
                    vb.dtype(),
                )?),
            );
        }

        let mut count = 0;
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        ) {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            let layer = DecoderLayer::new(
                rotary_emb.clone(),
                cfg,
                vb_l.pp(layer_idx),
                lora_config,
                &mut count,
                &xlora_ordering,
                &*mapper,
                layer_idx,
                normal_loading_metadata.loading_isq,
                preload_adapters,
            )?;
            layers.push(layer)
        }
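        // Plain LoRA (no X-LoRA classifier and no preloaded adapters): merge the adapter
        // weights directly into the base weights.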
        if xlora_config.is_none() && preload_adapters.is_none() {
            info!("Merging LoRA adapters.");

            for layer in layers.iter_mut().tqdm() {
                Arc::get_mut(&mut layer.self_attn.k_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.self_attn.o_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.self_attn.q_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.self_attn.v_proj)
                    .unwrap()
                    .merge_weights()?;

                Arc::get_mut(&mut layer.mlp.down_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.mlp.gate_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.mlp.up_proj)
                    .unwrap()
                    .merge_weights()?;
            }
        }
        let norm = RmsNorm::new_gemma(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb_m.pp("norm"), false),
        )?;
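        // Gemma ties the LM head to the token embedding matrix, so the head is loaded from the
        // `embed_tokens` weights.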
        let lm_head = linear(
            embed_tokens.embeddings().dim(1)?,
            embed_tokens.embeddings().dim(0)?,
            false,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), normal_loading_metadata.loading_isq),
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
            lora_config,
            &mut count,
            &xlora_ordering,
            preload_adapters,
        )?;
        if xlora_config.is_some() && lm_head.is_lora() {
            candle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA.");
        }

        Ok(Self {
            embed_tokens,
            layers,
            norm,
            lm_head,
            device: normal_loading_metadata.real_device,
            dtype: vb.dtype(),
            hidden_size: cfg.hidden_size,
            cache: EitherCache::Full(Cache::new(cfg.num_hidden_layers, true)),
            max_seq_len: default_max_position_embeddings(),
            xlora_classifier: xlora_config.map(|xlora_config| {
                XLoraClassifier::new(xlora_config, count, lora_config.len(), vb, false).unwrap()
            }),
            mapper,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_kv_heads: cfg.num_key_value_heads,
                num_attn_heads: cfg.num_attention_heads,
                sliding_window: None,
                k_head_dim: cfg.head_dim,
                v_head_dim: cfg.head_dim,
            },
        })
    }

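    /// Shared forward pass over the decoder stack. `scalings` and `is_scaling_pass` are only set
    /// when an X-LoRA classifier is present; `is_full_pass` selects which KV cache to use (the
    /// dedicated X-LoRA cache for the full pass, the regular cache for the scaling pass).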
    #[allow(clippy::too_many_arguments)]
    fn inner_forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Option<Tensor>,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut cache = if is_full_pass {
            if no_kv_cache {
                let mut new_cache = Vec::new();
                for _ in 0..self.cache.full().xlora_lock().len() {
                    new_cache.push(None);
                }

                self.cache.full().xlora_lock().clone_from(&new_cache);
            }
            self.cache.full().xlora_lock()
        } else {
            self.cache.full().lock()
        };
        let xs = self.embed_tokens.forward(input_ids)?;
        let attention_mask = CausalMasker.make_causal_mask_matrix(
            input_ids,
            &*cache,
            xs.dtype(),
            self.cfg.num_attn_heads,
        )?;
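        // Gemma scales the token embeddings by sqrt(hidden_size).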
        let mut xs = (xs * (self.hidden_size as f64).sqrt())?;
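        // Run the decoder stack, moving the hidden state to each layer's mapped device first.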
        for (i, layer) in self.layers.iter().enumerate() {
            xs = self.mapper.map(xs, i)?;
            xs = layer.forward(
                &xs,
                attention_mask
                    .as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                &mut cache[i],
                scalings.clone(),
                self.xlora_classifier
                    .as_ref()
                    .map(|classifier| classifier.get_global_scaling_weight())
                    .unwrap_or(1.0),
                is_scaling_pass,
                flash_params,
            )?
        }
        let xs = xs.to_device(&self.device)?;
        xs.apply(&self.norm)
    }

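    /// Full X-LoRA forward. When a classifier is present, per-token adapter scalings are first
    /// computed with `get_scalings` and then applied during the full pass; otherwise this
    /// behaves like a plain LoRA forward pass.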
    #[allow(clippy::too_many_arguments)]
    pub fn forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        if self.xlora_classifier.is_some() {
            let scalings = self.get_scalings(
                input_ids,
                input_ids_full,
                seqlen_offsets,
                seqlen_offsets_full,
                no_kv_cache,
                non_granular_state,
                &vec![usize::MAX; context_lens.len()],
                flash_params,
                flash_params_full,
            )?;

            if no_kv_cache {
                let mut res = self
                    .inner_forward(
                        input_ids_full,
                        seqlen_offsets_full,
                        Some(scalings),
                        true,
                        no_kv_cache,
                        None,
                        flash_params_full,
                    )?
                    .contiguous()?;
                if let Some(t) = self.lm_head.quantized_act_type() {
                    res = res.to_dtype(t)?;
                }
                extract_logits(
                    &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                    context_lens,
                )
            } else {
                let mut res = self
                    .inner_forward(
                        input_ids,
                        seqlen_offsets,
                        Some(scalings),
                        true,
                        no_kv_cache,
                        None,
                        flash_params,
                    )?
                    .contiguous()?;
                if let Some(t) = self.lm_head.quantized_act_type() {
                    res = res.to_dtype(t)?;
                }
                extract_logits(
                    &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                    context_lens,
                )
            }
        } else {
            let mut res = self
                .inner_forward(
                    input_ids,
                    seqlen_offsets,
                    None,
                    false,
                    no_kv_cache,
                    None,
                    flash_params,
                )?
                .contiguous()?;
            if let Some(t) = self.lm_head.quantized_act_type() {
                res = res.to_dtype(t)?;
            }
            extract_logits(
                &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                context_lens,
            )
        }
    }
}

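// Exposes every quantizable projection (the attention and MLP linears plus the LM head) for
// in-situ quantization (ISQ).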
impl IsqModel for XLoraModel {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((Arc::get_mut(&mut self.lm_head).unwrap().quant_inner(), None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.q_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.k_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.v_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.o_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.mlp.gate_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.mlp.up_proj).unwrap().quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.mlp.down_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
        }
        (tensors, &*self.mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        panic!("Cannot generate UQFF for an adapter model.")
    }
}

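// For X-LoRA models the pipeline calls `xlora_forward`; the plain `forward` entry point is
// unreachable.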
impl NormalModel for XLoraModel {
    fn forward(
        &self,
        _input_ids: &Tensor,
        _seqlen_offsets: &[usize],
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        _flash_params: &FlashParams,
    ) -> Result<Tensor> {
        unreachable!()
    }
    fn xlora_forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<crate::xlora_models::NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        self.forward(
            input_ids,
            input_ids_full,
            seqlen_offsets,
            seqlen_offsets_full,
            no_kv_cache,
            non_granular_state,
            context_lens,
            flash_params,
            flash_params_full,
        )
    }
    fn cache(&self) -> &EitherCache {
        &self.cache
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        &mut self.cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn is_xlora(&self) -> bool {
        true
    }
    fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
}

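// `ScalingsMaker` lets the X-LoRA classifier machinery drive scaling passes through
// `inner_forward`.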
impl ScalingsMaker for XLoraModel {
    fn dtype(&self) -> DType {
        self.dtype
    }
    fn get_cache(&self) -> &EitherCache {
        &self.cache
    }
    fn get_classifier(&self) -> &XLoraClassifier {
        self.xlora_classifier.as_ref().unwrap()
    }
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Tensor,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        _context_lens: &[usize],
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.inner_forward(
            input_ids,
            seqlen_offsets,
            Some(scalings),
            is_full_pass,
            no_kv_cache,
            is_scaling_pass,
            flash_params,
        )
    }
}

impl AnyMoeBaseModelMixin for XLoraModel {}