#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use crate::{
    amoe::AnyMoeBaseModelMixin,
    attention::SdpaParams,
    layers::{self, RotaryEmbedding, Sdpa},
    lora::{linear_no_bias, LinearLayerLike, LoraConfig, Ordering},
    paged_attention::ModelConfigMetadata,
    pipeline::{
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel, NormalLoadingMetadata,
    },
    utils::progress::NiceProgressBar,
};
use candle_core::{DType, Device, Module, Result, Tensor};
use mistralrs_quant::{QuantMethod, ShardedVarBuilder};
use std::{collections::HashMap, sync::Arc};
use tqdm::Iter;
use tracing::info;

use crate::{
    device_map::DeviceMapper,
    layers::{Activation, CausalMasker, RmsNorm},
    models::mistral::Config,
    pipeline::{extract_logits, Cache, NormalModel},
};

use super::{classifier::XLoraClassifier, config::XLoraConfig, NonGranularState, ScalingsMaker};

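/// Gated feed-forward block: `down_proj(act_fn(gate_proj(x)) * up_proj(x))`.
/// Each projection is a `LinearLayerLike`, so LoRA adapters (and optional
/// X-LoRA scalings) are applied through `lora_forward`.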
#[derive(Clone)]
#[allow(clippy::upper_case_acronyms)]
struct MLP {
    gate_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    up_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    down_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    act_fn: Activation,
}

impl MLP {
    #[allow(clippy::too_many_arguments)]
    fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let intermediate_sz = cfg.intermediate_size;
        let gate_proj = linear_no_bias(
            hidden_sz,
            intermediate_sz,
            mapper.set_device(layer_idx, vb.pp("gate_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("gate_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let up_proj = linear_no_bias(
            hidden_sz,
            intermediate_sz,
            mapper.set_device(layer_idx, vb.pp("up_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("up_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let down_proj = linear_no_bias(
            intermediate_sz,
            hidden_sz,
            mapper.set_device(layer_idx, vb.pp("down_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("down_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        Ok(Self {
            gate_proj,
            up_proj,
            down_proj,
            act_fn: cfg.hidden_act,
        })
    }

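    /// Forward pass. If the projections are quantized, the activations are
    /// cast to the quantized activation dtype for the matmuls and the output
    /// is cast back to the original dtype afterwards.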
    fn forward(
        &self,
        xs: &Tensor,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
    ) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.gate_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let lhs = self
            .gate_proj
            .lora_forward(
                &xs,
                scalings.clone(),
                global_scaling_weight,
                is_scaling_pass,
            )?
            .apply(&self.act_fn)?;
        let rhs = self.up_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let mut res = self.down_proj.lora_forward(
            &(lhs * rhs)?,
            scalings,
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.gate_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

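/// Multi-head self-attention with rotary embeddings, grouped-query KV heads,
/// an optional sliding window, and LoRA/X-LoRA aware projections.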
struct Attention {
    q_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    k_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    v_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    o_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    rotary_emb: Arc<RotaryEmbedding>,
    sliding_window: Option<usize>,
    sdpa_params: SdpaParams,
}

impl Attention {
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<RotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        let head_dim = cfg.head_dim();
        let q_proj = linear_no_bias(
            hidden_sz,
            num_heads * head_dim,
            mapper.set_device(layer_idx, vb.pp("q_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("q_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let k_proj = linear_no_bias(
            hidden_sz,
            num_kv_heads * head_dim,
            mapper.set_device(layer_idx, vb.pp("k_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("k_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let v_proj = linear_no_bias(
            hidden_sz,
            num_kv_heads * head_dim,
            mapper.set_device(layer_idx, vb.pp("v_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("v_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let o_proj = linear_no_bias(
            num_heads * head_dim,
            hidden_sz,
            mapper.set_device(layer_idx, vb.pp("o_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("o_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            num_heads,
            num_kv_heads,
            head_dim,
            rotary_emb,
            sliding_window: cfg.sliding_window,
            sdpa_params: SdpaParams {
                n_kv_groups: num_heads / num_kv_heads,
                use_flash_attn: cfg.use_flash_attn,
                softcap: None,
                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                sliding_window: cfg.sliding_window,
            },
        })
    }

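    /// Compute Q/K/V with the LoRA-aware projections, apply RoPE, update the
    /// (sliding-window) KV cache, run scaled dot-product attention, and
    /// project the result back through `o_proj`.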
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut Option<(Tensor, Tensor)>,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.q_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut q = self.q_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let mut k = self.k_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let mut v = self.v_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.q_proj.quantized_act_type().is_some() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

        let (q, k, v) = if q_len != 1 {
            let q = q
                .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, k) = self.rotary_emb.forward(&q, &k, seqlen_offsets)?;

        let (k, v, attn_mask) = Cache::update_kv_cache_sliding_window(
            kv_cache,
            k,
            v,
            attention_mask,
            self.sliding_window,
            false,
        )?;

        let mut attn_output = Sdpa.run_attention(
            &q,
            &k,
            &v,
            attn_mask.as_ref(),
            Some(flash_params),
            &self.sdpa_params,
        )?;

        if let Some(t) = self.q_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        let mut res = self.o_proj.lora_forward(
            &attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.q_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

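/// A single transformer block: pre-norm attention followed by a pre-norm MLP,
/// each with a residual connection.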
struct DecoderLayer {
    self_attn: Attention,
    mlp: MLP,
    input_layernorm: RmsNorm,
    post_attention_layernorm: RmsNorm,
}

impl DecoderLayer {
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<RotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let self_attn = Attention::new(
            rotary_emb,
            cfg,
            vb.pp("self_attn"),
            lora_config,
            count,
            ord,
            mapper,
            layer_idx,
            loading_isq,
            preload_adapters,
        )?;
        let mlp = MLP::new(
            cfg,
            vb.pp("mlp"),
            lora_config,
            count,
            ord,
            mapper,
            layer_idx,
            loading_isq,
            preload_adapters,
        )?;
        let input_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let post_attention_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        Ok(Self {
            self_attn,
            mlp,
            input_layernorm,
            post_attention_layernorm,
        })
    }

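    /// `x = x + attn(norm(x))`, then `x = x + mlp(norm(x))`.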
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut Option<(Tensor, Tensor)>,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = self.input_layernorm.forward(xs)?;
        let xs = self.self_attn.forward(
            &xs,
            attention_mask,
            seqlen_offsets,
            kv_cache,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
            flash_params,
        )?;
        let xs = (xs + residual)?;
        let residual = &xs;
        let xs = self.mlp.forward(
            &xs.apply(&self.post_attention_layernorm)?,
            scalings,
            global_scaling_weight,
            is_scaling_pass,
        )?;
        residual + xs
    }
}

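/// Mistral-style decoder-only model with X-LoRA support. When an
/// `XLoraClassifier` is present, generation uses a two-pass scheme: a scaling
/// pass computes per-token adapter scalings, and the main pass applies them.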
pub struct XLoraModel {
    embed_tokens: candle_nn::Embedding,
    layers: Vec<DecoderLayer>,
    norm: RmsNorm,
    lm_head: Arc<dyn LinearLayerLike + Send + Sync>,
    sliding_window: Option<usize>,
    dtype: DType,
    device: Device,
    cache: EitherCache,
    max_seq_len: usize,
    xlora_classifier: Option<XLoraClassifier>,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    cfg: ModelConfigMetadata,
}

impl XLoraModel {
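    /// Build the model from a `ShardedVarBuilder`, attaching LoRA adapters to
    /// every linear layer. If neither an X-LoRA classifier nor preloaded
    /// adapters are given, the adapters are merged into the base weights
    /// after loading.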
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        if let Some(ref quant_cfg) = &cfg.quantization_config {
            tracing::info!(
                "Using {} quantization: {}.",
                quant_cfg.name(),
                quant_cfg.get_bits_name(&vb)
            );
        }
        let mapper = normal_loading_metadata.mapper;

        let vb_m = vb.pp("model");
        let embed_tokens = layers::embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
            &cfg.quantization_config,
        )?;
        let head_dim = cfg.head_dim();
        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        let vb_l = vb_m.pp("layers");
        let mut ropes = HashMap::new();
        for layer_idx in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(RotaryEmbedding::new(
                    cfg.rope_theta as f32,
                    head_dim,
                    cfg.max_position_embeddings,
                    device,
                    is_gptx,
                    vb_m.dtype(),
                )?),
            );
        }
        let mut count = 0;
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        ) {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            let layer = DecoderLayer::new(
                rotary_emb.clone(),
                cfg,
                vb_l.pp(layer_idx),
                lora_config,
                &mut count,
                &xlora_ordering,
                &*mapper,
                layer_idx,
                normal_loading_metadata.loading_isq,
                preload_adapters,
            )?;
            layers.push(layer)
        }
        if xlora_config.is_none() && preload_adapters.is_none() {
            info!("Merging LoRA adapters.");
            for layer in layers.iter_mut().tqdm() {
                Arc::get_mut(&mut layer.self_attn.k_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.self_attn.o_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.self_attn.q_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.self_attn.v_proj)
                    .unwrap()
                    .merge_weights()?;

                Arc::get_mut(&mut layer.mlp.down_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.mlp.gate_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.mlp.up_proj)
                    .unwrap()
                    .merge_weights()?;
            }
        }
        let norm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb_m.pp("norm"), false),
        )?;
        let lm_head = linear_no_bias(
            cfg.hidden_size,
            cfg.vocab_size,
            mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
            mapper.set_nm_device(vb.pp("lm_head"), false),
            lora_config,
            &mut count,
            &xlora_ordering,
            preload_adapters,
        )?;
        if xlora_config.is_some() && lm_head.is_lora() {
            candle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA.");
        }
        Ok(Self {
            embed_tokens,
            layers,
            norm,
            lm_head,
            sliding_window: cfg.sliding_window,
            device: normal_loading_metadata.real_device,
            dtype: vb.dtype(),
            cache: EitherCache::Full(Cache::new(cfg.num_hidden_layers, true)),
            max_seq_len: cfg.max_position_embeddings,
            xlora_classifier: xlora_config.map(|xlora_config| {
                XLoraClassifier::new(xlora_config, count, lora_config.len(), vb, false).unwrap()
            }),
            mapper,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_kv_heads: cfg.num_key_value_heads,
                num_attn_heads: cfg.num_attention_heads,
                sliding_window: cfg.sliding_window,
                k_head_dim: cfg.head_dim(),
                v_head_dim: cfg.head_dim(),
            },
        })
    }

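    /// Run the embedding, decoder stack, and final norm (no LM head). When
    /// `is_full_pass` is set, the X-LoRA cache is used (and reset first if
    /// `no_kv_cache`); otherwise the regular KV cache is used.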
    #[allow(clippy::too_many_arguments)]
    fn inner_forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Option<Tensor>,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut cache = if is_full_pass {
            if no_kv_cache {
                let mut new_cache = Vec::new();
                for _ in 0..self.cache.full().xlora_lock().len() {
                    new_cache.push(None);
                }

                self.cache.full().xlora_lock().clone_from(&new_cache);
            }
            self.cache.full().xlora_lock()
        } else {
            self.cache.full().lock()
        };
        let mut xs = self.embed_tokens.forward(input_ids)?;
        let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
            input_ids,
            &*cache,
            self.sliding_window,
            xs.dtype(),
            self.cfg.num_attn_heads,
        )?;
        for (i, layer) in self.layers.iter().enumerate() {
            xs = self.mapper.map(xs, i)?;
            xs = layer.forward(
                &xs,
                attention_mask
                    .as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                &mut cache[i],
                scalings.clone(),
                self.xlora_classifier
                    .as_ref()
                    .map(|classifier| classifier.get_global_scaling_weight())
                    .unwrap_or(1.0),
                is_scaling_pass,
                flash_params,
            )?
        }
        let xs = xs.to_device(&self.device)?;
        xs.apply(&self.norm)
    }

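    /// X-LoRA-aware forward pass. With a classifier, first obtain per-token
    /// adapter scalings via `get_scalings`, then rerun the decoder with those
    /// scalings before applying the LM head. Without a classifier, a single
    /// pass with no scalings is run.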
    #[allow(clippy::too_many_arguments)]
    pub fn forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        if self.xlora_classifier.is_some() {
            let scalings = self.get_scalings(
                input_ids,
                input_ids_full,
                seqlen_offsets,
                seqlen_offsets_full,
                no_kv_cache,
                non_granular_state,
                &vec![usize::MAX; context_lens.len()],
                flash_params,
                flash_params_full,
            )?;

            if no_kv_cache {
                let mut res = self
                    .inner_forward(
                        input_ids_full,
                        seqlen_offsets_full,
                        Some(scalings),
                        true,
                        no_kv_cache,
                        None,
                        flash_params_full,
                    )?
                    .contiguous()?;
                if let Some(t) = self.lm_head.quantized_act_type() {
                    res = res.to_dtype(t)?;
                }
                extract_logits(
                    &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                    context_lens,
                )
            } else {
                let mut res = self
                    .inner_forward(
                        input_ids,
                        seqlen_offsets,
                        Some(scalings),
                        true,
                        no_kv_cache,
                        None,
                        flash_params,
                    )?
                    .contiguous()?;
                if let Some(t) = self.lm_head.quantized_act_type() {
                    res = res.to_dtype(t)?;
                }
                extract_logits(
                    &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                    context_lens,
                )
            }
        } else {
            let mut res = self
                .inner_forward(
                    input_ids,
                    seqlen_offsets,
                    None,
                    false,
                    no_kv_cache,
                    None,
                    flash_params,
                )?
                .contiguous()?;
            if let Some(t) = self.lm_head.quantized_act_type() {
                res = res.to_dtype(t)?;
            }
            extract_logits(
                &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                context_lens,
            )
        }
    }
}

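// ISQ support: expose the LM head and every decoder-layer linear for in-place
// quantization. The `Option<usize>` is the owning layer index (`None` for the
// LM head).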
impl IsqModel for XLoraModel {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((Arc::get_mut(&mut self.lm_head).unwrap().quant_inner(), None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.q_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.k_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.v_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.o_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.mlp.gate_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.mlp.up_proj).unwrap().quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.mlp.down_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
        }
        (tensors, &*self.mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        panic!("Cannot generate UQFF for an adapter model.")
    }
}

impl NormalModel for XLoraModel {
    fn forward(
        &self,
        _input_ids: &Tensor,
        _seqlen_offsets: &[usize],
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        _flash_params: &FlashParams,
    ) -> Result<Tensor> {
        unreachable!()
    }
    fn xlora_forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<crate::xlora_models::NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        self.forward(
            input_ids,
            input_ids_full,
            seqlen_offsets,
            seqlen_offsets_full,
            no_kv_cache,
            non_granular_state,
            context_lens,
            flash_params,
            flash_params_full,
        )
    }
    fn cache(&self) -> &EitherCache {
        &self.cache
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        &mut self.cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn is_xlora(&self) -> bool {
        true
    }
    fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
}

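// Hook for the X-LoRA scaling pass: `get_scalings` (used in `forward` above)
// drives the model through `inner_forward` via this implementation so the
// classifier can produce per-token adapter scalings.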
impl ScalingsMaker for XLoraModel {
    fn dtype(&self) -> DType {
        self.dtype
    }
    fn get_cache(&self) -> &EitherCache {
        &self.cache
    }
    fn get_classifier(&self) -> &XLoraClassifier {
        self.xlora_classifier.as_ref().unwrap()
    }
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Tensor,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        _context_lens: &[usize],
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.inner_forward(
            input_ids,
            seqlen_offsets,
            Some(scalings),
            is_full_pass,
            no_kv_cache,
            is_scaling_pass,
            flash_params,
        )
    }
}

impl AnyMoeBaseModelMixin for XLoraModel {}