#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::{collections::HashMap, sync::Arc};

use crate::{
    amoe::AnyMoeBaseModelMixin,
    attention::SdpaParams,
    layers::{self, Activation, RmsNorm, RotaryEmbedding, Sdpa},
    lora::{linear_b as linear, LinearLayerLike, LoraConfig, Ordering},
    paged_attention::ModelConfigMetadata,
    pipeline::{
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel, NormalLoadingMetadata,
    },
    utils::progress::NiceProgressBar,
};
use candle_core::{DType, Device, Module, Result, Tensor};
use mistralrs_quant::{QuantMethod, ShardedVarBuilder};
use tqdm::Iter;
use tracing::info;

use crate::{
    device_map::DeviceMapper,
    layers::CausalMasker,
    models::gemma::Config,
    pipeline::{extract_logits, Cache, NormalModel},
};

use super::{classifier::XLoraClassifier, NonGranularState, ScalingsMaker, XLoraConfig};

fn default_max_position_embeddings() -> usize {
    4096
}

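/// Gated MLP block for the Gemma decoder. The gate/up/down projections are LoRA-aware
/// (`LinearLayerLike`), so X-LoRA scalings can be threaded through each matmul.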
#[derive(Clone)]
#[allow(clippy::upper_case_acronyms)]
struct MLP {
    gate_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    up_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    down_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    act_fn: Activation,
}

impl MLP {
    #[allow(clippy::too_many_arguments)]
    fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let intermediate_sz = cfg.intermediate_size;
        let gate_proj = linear(
            hidden_sz,
            intermediate_sz,
            false,
            mapper.set_device(layer_idx, vb.pp("gate_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("gate_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let up_proj = linear(
            hidden_sz,
            intermediate_sz,
            false,
            mapper.set_device(layer_idx, vb.pp("up_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("up_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let down_proj = linear(
            intermediate_sz,
            hidden_sz,
            false,
            mapper.set_device(layer_idx, vb.pp("down_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("down_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        Ok(Self {
            gate_proj,
            up_proj,
            down_proj,
            act_fn: cfg.hidden_act()?,
        })
    }

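    /// Computes `down_proj(act_fn(gate_proj(x)) * up_proj(x))`, casting through the quantized
    /// activation dtype when the underlying projections are quantized.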
    fn forward(
        &self,
        xs: &Tensor,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
    ) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.gate_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let lhs = self
            .gate_proj
            .lora_forward(
                &xs,
                scalings.clone(),
                global_scaling_weight,
                is_scaling_pass,
            )?
            .apply(&self.act_fn)?;
        let rhs = self.up_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let mut res = self.down_proj.lora_forward(
            &(lhs * rhs)?,
            scalings,
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.gate_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

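/// Multi-head self-attention with rotary embeddings and LoRA-aware q/k/v/o projections.
/// `num_kv_heads` may be smaller than `num_heads`; the grouped-query ratio is stored in
/// `sdpa_params.n_kv_groups`.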
struct Attention {
    q_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    k_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    v_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    o_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    rotary_emb: Arc<RotaryEmbedding>,
    sdpa_params: SdpaParams,
}

impl Attention {
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<RotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        let head_dim = cfg.head_dim;
        let bias = cfg.attention_bias;
        let q_proj = linear(
            hidden_sz,
            num_heads * head_dim,
            bias,
            mapper.set_device(layer_idx, vb.pp("q_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("q_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let k_proj = linear(
            hidden_sz,
            num_kv_heads * head_dim,
            bias,
            mapper.set_device(layer_idx, vb.pp("k_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("k_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let v_proj = linear(
            hidden_sz,
            num_kv_heads * head_dim,
            bias,
            mapper.set_device(layer_idx, vb.pp("v_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("v_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let o_proj = linear(
            num_heads * head_dim,
            hidden_sz,
            bias,
            mapper.set_device(layer_idx, vb.pp("o_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("o_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            num_heads,
            num_kv_heads,
            head_dim,
            rotary_emb,
            sdpa_params: SdpaParams {
                n_kv_groups: num_heads / num_kv_heads,
                softcap: None,
                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                sliding_window: None,
            },
        })
    }

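    /// Projects q/k/v (with LoRA scalings), applies RoPE, updates the KV cache, runs scaled
    /// dot-product attention, and applies the output projection.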
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut Option<(Tensor, Tensor)>,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.q_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut q = self.q_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let mut k = self.k_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let mut v = self.v_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.q_proj.quantized_act_type().is_some() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

        let (q, k, v) = if q_len != 1 {
            let q = q
                .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, k) = self.rotary_emb.forward(&q, &k, seqlen_offsets)?;

        let (k, v) = Cache::update_kv_cache(kv_cache, k, v, false)?;

        let mut attn_output = Sdpa.run_attention(
            &q,
            &k,
            &v,
            attention_mask,
            Some(flash_params),
            &self.sdpa_params,
        )?;

        if let Some(t) = self.q_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        let mut res = self.o_proj.lora_forward(
            &attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.q_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

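/// One decoder layer: pre-norm self-attention and a pre-norm MLP, each wrapped in a residual
/// connection, using Gemma-style RMS norms.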
struct DecoderLayer {
    self_attn: Attention,
    mlp: MLP,
    input_layernorm: RmsNorm,
    post_attention_layernorm: RmsNorm,
}

impl DecoderLayer {
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<RotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let self_attn = Attention::new(
            rotary_emb,
            cfg,
            vb.pp("self_attn"),
            lora_config,
            count,
            ord,
            mapper,
            layer_idx,
            loading_isq,
            preload_adapters,
        )?;
        let mlp = MLP::new(
            cfg,
            vb.pp("mlp"),
            lora_config,
            count,
            ord,
            mapper,
            layer_idx,
            loading_isq,
            preload_adapters,
        )?;
        let input_layernorm = RmsNorm::new_gemma(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let post_attention_layernorm = RmsNorm::new_gemma(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        Ok(Self {
            self_attn,
            mlp,
            input_layernorm,
            post_attention_layernorm,
        })
    }

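    /// `x = x + attn(input_layernorm(x))`, then `x = x + mlp(post_attention_layernorm(x))`.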
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut Option<(Tensor, Tensor)>,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = self.input_layernorm.forward(xs)?;
        let xs = self.self_attn.forward(
            &xs,
            attention_mask,
            seqlen_offsets,
            kv_cache,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
            flash_params,
        )?;
        let xs = (xs + residual)?;
        let residual = &xs;
        let xs = self.mlp.forward(
            &xs.apply(&self.post_attention_layernorm)?,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        residual + xs
    }
}

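/// Gemma decoder stack with X-LoRA support: token embeddings, the decoder layers, a final
/// RMS norm, and an `lm_head` loaded from the (tied) embedding weights. When an
/// `XLoraClassifier` is present, adapter scalings from a preliminary scaling pass are passed
/// to every LoRA-aware layer during the real forward pass.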
pub struct XLoraModel {
    embed_tokens: candle_nn::Embedding,
    layers: Vec<DecoderLayer>,
    norm: RmsNorm,
    lm_head: Arc<dyn LinearLayerLike + Send + Sync>,
    dtype: DType,
    hidden_size: usize,
    device: Device,
    cache: EitherCache,
    max_seq_len: usize,
    xlora_classifier: Option<XLoraClassifier>,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    cfg: ModelConfigMetadata,
}

impl XLoraModel {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        if let Some(ref quant_cfg) = &cfg.quantization_config {
            tracing::info!(
                "Using {} quantization: {}.",
                quant_cfg.name(),
                quant_cfg.get_bits_name(&vb)
            );
        }
        let mapper = normal_loading_metadata.mapper;
        let vb_m = vb.pp("model");

        let embed_tokens = layers::embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
            &cfg.quantization_config,
        )?;
        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        let vb_l = vb_m.pp("layers");
        let mut ropes = HashMap::new();
        for layer_idx in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(RotaryEmbedding::new(
                    cfg.rope_theta as f32,
                    cfg.head_dim,
                    cfg.max_position_embeddings,
                    device,
                    is_gptx,
                    vb.dtype(),
                )?),
            );
        }

        let mut count = 0;
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        ) {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            let layer = DecoderLayer::new(
                rotary_emb.clone(),
                cfg,
                vb_l.pp(layer_idx),
                lora_config,
                &mut count,
                &xlora_ordering,
                &*mapper,
                layer_idx,
                normal_loading_metadata.loading_isq,
                preload_adapters,
            )?;
            layers.push(layer)
        }
        if xlora_config.is_none() && preload_adapters.is_none() {
            info!("Merging LoRA adapters.");
            for layer in layers.iter_mut().tqdm() {
                Arc::get_mut(&mut layer.self_attn.k_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.self_attn.o_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.self_attn.q_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.self_attn.v_proj)
                    .unwrap()
                    .merge_weights()?;

                Arc::get_mut(&mut layer.mlp.down_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.mlp.gate_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.mlp.up_proj)
                    .unwrap()
                    .merge_weights()?;
            }
        }
        let norm = RmsNorm::new_gemma(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb_m.pp("norm"), false),
        )?;
        let lm_head = linear(
            embed_tokens.embeddings().dim(1)?,
            embed_tokens.embeddings().dim(0)?,
            false,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), normal_loading_metadata.loading_isq),
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
            lora_config,
            &mut count,
            &xlora_ordering,
            preload_adapters,
        )?;
        if xlora_config.is_some() && lm_head.is_lora() {
            candle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA.");
        }

        Ok(Self {
            embed_tokens,
            layers,
            norm,
            lm_head,
            device: normal_loading_metadata.real_device,
            dtype: vb.dtype(),
            hidden_size: cfg.hidden_size,
            cache: EitherCache::Full(Cache::new(cfg.num_hidden_layers, true)),
            max_seq_len: default_max_position_embeddings(),
            xlora_classifier: xlora_config.map(|xlora_config| {
                XLoraClassifier::new(xlora_config, count, lora_config.len(), vb, false).unwrap()
            }),
            mapper,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_kv_heads: cfg.num_key_value_heads,
                num_attn_heads: cfg.num_attention_heads,
                sliding_window: None,
                k_head_dim: cfg.head_dim,
                v_head_dim: cfg.head_dim,
            },
        })
    }

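    /// Runs the decoder stack. `is_full_pass` selects the X-LoRA cache (`xlora_lock`) over the
    /// regular cache, and with `no_kv_cache` that cache is reset first. Embeddings are scaled
    /// by `sqrt(hidden_size)` before the first layer, as in the Gemma reference.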
    #[allow(clippy::too_many_arguments)]
    fn inner_forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Option<Tensor>,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut cache = if is_full_pass {
            if no_kv_cache {
                let mut new_cache = Vec::new();
                for _ in 0..self.cache.full().xlora_lock().len() {
                    new_cache.push(None);
                }

                self.cache.full().xlora_lock().clone_from(&new_cache);
            }
            self.cache.full().xlora_lock()
        } else {
            self.cache.full().lock()
        };
        let xs = self.embed_tokens.forward(input_ids)?;
        let attention_mask = CausalMasker.make_causal_mask_matrix(
            input_ids,
            &*cache,
            xs.dtype(),
            self.cfg.num_attn_heads,
        )?;
        let mut xs = (xs * (self.hidden_size as f64).sqrt())?;
        for (i, layer) in self.layers.iter().enumerate() {
            xs = self.mapper.map(xs, i)?;
            xs = layer.forward(
                &xs,
                attention_mask
                    .as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                &mut cache[i],
                scalings.clone(),
                self.xlora_classifier
                    .as_ref()
                    .map(|classifier| classifier.get_global_scaling_weight())
                    .unwrap_or(1.0),
                is_scaling_pass,
                flash_params,
            )?
        }
        let xs = xs.to_device(&self.device)?;
        xs.apply(&self.norm)
    }

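    /// X-LoRA forward: if a classifier is present, compute the adapter scalings first via
    /// `get_scalings`, then run the full pass with those scalings; otherwise run a single
    /// plain pass. In every branch the hidden states go through `lm_head.lora_forward` with
    /// dummy scaling arguments and are trimmed with `extract_logits`.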
    #[allow(clippy::too_many_arguments)]
    pub fn forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        if self.xlora_classifier.is_some() {
            let scalings = self.get_scalings(
                input_ids,
                input_ids_full,
                seqlen_offsets,
                seqlen_offsets_full,
                no_kv_cache,
                non_granular_state,
                &vec![usize::MAX; context_lens.len()],
                flash_params,
                flash_params_full,
            )?;

            if no_kv_cache {
                let mut res = self
                    .inner_forward(
                        input_ids_full,
                        seqlen_offsets_full,
                        Some(scalings),
                        true,
                        no_kv_cache,
                        None,
                        flash_params_full,
                    )?
                    .contiguous()?;
                if let Some(t) = self.lm_head.quantized_act_type() {
                    res = res.to_dtype(t)?;
                }
                extract_logits(
                    &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                    context_lens,
                )
            } else {
                let mut res = self
                    .inner_forward(
                        input_ids,
                        seqlen_offsets,
                        Some(scalings),
                        true,
                        no_kv_cache,
                        None,
                        flash_params,
                    )?
                    .contiguous()?;
                if let Some(t) = self.lm_head.quantized_act_type() {
                    res = res.to_dtype(t)?;
                }
                extract_logits(
                    &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                    context_lens,
                )
            }
        } else {
            let mut res = self
                .inner_forward(
                    input_ids,
                    seqlen_offsets,
                    None,
                    false,
                    no_kv_cache,
                    None,
                    flash_params,
                )?
                .contiguous()?;
            if let Some(t) = self.lm_head.quantized_act_type() {
                res = res.to_dtype(t)?;
            }
            extract_logits(
                &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                context_lens,
            )
        }
    }
}

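// ISQ support: expose `lm_head` and each layer's q/k/v/o and MLP projections (with their
// layer indices) as quantizable tensors, alongside the device mapper.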
impl IsqModel for XLoraModel {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((Arc::get_mut(&mut self.lm_head).unwrap().quant_inner(), None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.q_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.k_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.v_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.o_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.mlp.gate_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.mlp.up_proj).unwrap().quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.mlp.down_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
        }
        (tensors, &*self.mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        panic!("Cannot generate UQFF for an adapter model.")
    }
}

impl NormalModel for XLoraModel {
    fn forward(
        &self,
        _input_ids: &Tensor,
        _seqlen_offsets: &[usize],
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        _flash_params: &FlashParams,
    ) -> Result<Tensor> {
        unreachable!()
    }
    fn xlora_forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<crate::xlora_models::NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        self.forward(
            input_ids,
            input_ids_full,
            seqlen_offsets,
            seqlen_offsets_full,
            no_kv_cache,
            non_granular_state,
            context_lens,
            flash_params,
            flash_params_full,
        )
    }
    fn cache(&self) -> &EitherCache {
        &self.cache
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        &mut self.cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn is_xlora(&self) -> bool {
        true
    }
    fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
}

impl ScalingsMaker for XLoraModel {
    fn dtype(&self) -> DType {
        self.dtype
    }
    fn get_cache(&self) -> &EitherCache {
        &self.cache
    }
    fn get_classifier(&self) -> &XLoraClassifier {
        self.xlora_classifier.as_ref().unwrap()
    }
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Tensor,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        _context_lens: &[usize],
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.inner_forward(
            input_ids,
            seqlen_offsets,
            Some(scalings),
            is_full_pass,
            no_kv_cache,
            is_scaling_pass,
            flash_params,
        )
    }
}

impl AnyMoeBaseModelMixin for XLoraModel {}