#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use crate::{
    amoe::AnyMoeBaseModelMixin,
    attention::SdpaParams,
    layers::{Llama3RotaryEmbedding, Sdpa},
    lora::{linear_no_bias as linear, LinearLayerLike, LoraConfig, Ordering},
    paged_attention::ModelConfigMetadata,
    pipeline::{
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel,
    },
    utils::progress::NiceProgressBar,
};
use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{Embedding, Module};
use mistralrs_quant::{QuantMethod, ShardedVarBuilder};
use std::{collections::HashMap, sync::Arc};
use tqdm::Iter;
use tracing::info;

use crate::{
    device_map::DeviceMapper,
    layers::{embedding, CausalMasker, RmsNorm},
    models::llama::Config,
    pipeline::{self, extract_logits, LayerCaches, NormalLoadingMetadata, NormalModel},
};

use super::{classifier::XLoraClassifier, NonGranularState, ScalingsMaker, XLoraConfig};

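/// Causal self-attention for one decoder layer. The Q/K/V/O projections are
/// `LinearLayerLike`, so they can carry (X-)LoRA adapters; the rotary embedding
/// is shared via `Arc` with the other layers on the same device.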
31struct CausalSelfAttention {
32 q_proj: Arc<dyn LinearLayerLike + Send + Sync>,
33 k_proj: Arc<dyn LinearLayerLike + Send + Sync>,
34 v_proj: Arc<dyn LinearLayerLike + Send + Sync>,
35 o_proj: Arc<dyn LinearLayerLike + Send + Sync>,
36 num_attention_heads: usize,
37 num_key_value_heads: usize,
38 head_dim: usize,
39 rotary_emb: Arc<Llama3RotaryEmbedding>,
40 max_seq_len: usize,
41 sdpa_params: SdpaParams,
42}
43
44impl CausalSelfAttention {
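    /// Attention forward pass: LoRA-aware Q/K/V projections, RoPE, KV-cache update for
    /// `block_idx`, SDPA, then the output projection. `scalings`, `global_scaling_weight`
    /// and `is_scaling_pass` carry X-LoRA classifier state into each `lora_forward` call.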
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        x: &Tensor,
        mask: &Option<Tensor>,
        seqlen_offsets: &[usize],
        block_idx: usize,
        kv_cache: &mut LayerCaches,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, seq_len, hidden_size) = x.dims3()?;

        let original_dtype = x.dtype();
        let mut x = x.clone();
        if let Some(t) = self.q_proj.quantized_act_type() {
            x = x.to_dtype(t)?;
        }
        let mut q = self.q_proj.lora_forward(
            &x,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let mut k = self.k_proj.lora_forward(
            &x,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let mut v = self.v_proj.lora_forward(
            &x,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.q_proj.quantized_act_type().is_some() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

        let (q, k, v) = if seq_len != 1 {
            let q = q
                .reshape((b_sz, seq_len, self.num_attention_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.num_attention_heads, seq_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.num_key_value_heads, seq_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.num_key_value_heads, seq_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, k) = self.rotary_emb.forward(&q, &k, seqlen_offsets)?;

        let (k, v) = crate::pipeline::Cache::update_kv_cache(&mut kv_cache[block_idx], k, v)?;

        let y = Sdpa.run_attention(
            &q,
            &k,
            &v,
            mask.clone().as_ref(),
            Some(flash_params),
            &self.sdpa_params,
        )?;

        let mut y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, hidden_size])?;
        if let Some(t) = self.q_proj.quantized_act_type() {
            y = y.to_dtype(t)?;
        }
        // `y` is already (b_sz, seq_len, hidden_size); pass it to the output projection as-is.
        let mut res = self.o_proj.lora_forward(
            &y,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.q_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }

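    /// Loads the four attention projections for layer `layer_idx`, wiring in adapters from
    /// `lora_config` (and any preloaded adapters) and placing weights via the device mapper.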
    #[allow(clippy::too_many_arguments)]
    fn load(
        vb: ShardedVarBuilder,
        cfg: &Config,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        rope: Arc<Llama3RotaryEmbedding>,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let size_in = cfg.hidden_size;
        let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
        let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
        let q_proj = linear(
            size_in,
            size_q,
            mapper.set_device(layer_idx, vb.pp("q_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("q_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let k_proj = linear(
            size_in,
            size_kv,
            mapper.set_device(layer_idx, vb.pp("k_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("k_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let v_proj = linear(
            size_in,
            size_kv,
            mapper.set_device(layer_idx, vb.pp("v_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("v_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let o_proj = linear(
            size_q,
            size_in,
            mapper.set_device(layer_idx, vb.pp("o_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("o_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            num_attention_heads: cfg.num_attention_heads,
            num_key_value_heads: cfg.num_key_value_heads,
            head_dim: cfg.hidden_size / cfg.num_attention_heads,
            rotary_emb: rope,
            max_seq_len: cfg.max_position_embeddings,
            sdpa_params: SdpaParams {
                n_kv_groups: cfg.num_attention_heads / cfg.num_key_value_heads,
                softcap: None,
                softmax_scale: 1.0 / ((cfg.hidden_size / cfg.num_attention_heads) as f32).sqrt(),
                sliding_window: None,
            },
        })
    }
}

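/// SwiGLU MLP (`gate_proj`/`up_proj`/`down_proj`), with each linear LoRA-aware.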
#[derive(Clone)]
struct Mlp {
    c_fc1: Arc<dyn LinearLayerLike + Send + Sync>,
    c_fc2: Arc<dyn LinearLayerLike + Send + Sync>,
    c_proj: Arc<dyn LinearLayerLike + Send + Sync>,
}

impl Mlp {
    fn forward(
        &self,
        x: &Tensor,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
    ) -> Result<Tensor> {
        let original_dtype = x.dtype();
        let mut x = x.clone();
        if let Some(t) = self.c_fc1.quantized_act_type() {
            x = x.to_dtype(t)?;
        }
        let x = (candle_nn::ops::silu(&self.c_fc1.lora_forward(
            &x,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?)? * self.c_fc2.lora_forward(
            &x,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?)?;
        let mut res = self.c_proj.lora_forward(
            &x,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.c_fc1.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }

    #[allow(clippy::too_many_arguments)]
    fn load(
        vb: ShardedVarBuilder,
        cfg: &Config,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let h_size = cfg.hidden_size;
        let i_size = cfg.intermediate_size;
        let c_fc1 = linear(
            h_size,
            i_size,
            mapper.set_device(layer_idx, vb.pp("gate_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("gate_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let c_fc2 = linear(
            h_size,
            i_size,
            mapper.set_device(layer_idx, vb.pp("up_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("up_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let c_proj = linear(
            i_size,
            h_size,
            mapper.set_device(layer_idx, vb.pp("down_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("down_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        Ok(Self {
            c_fc1,
            c_fc2,
            c_proj,
        })
    }
}

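/// One decoder layer: pre-norm attention and pre-norm MLP, each with a residual connection.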
struct Block {
    rms_1: RmsNorm,
    attn: CausalSelfAttention,
    rms_2: RmsNorm,
    mlp: Mlp,
}

impl Block {
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        x: &Tensor,
        mask: &Option<Tensor>,
        seqlen_offsets: &[usize],
        block_idx: usize,
        kv_cache: &mut LayerCaches,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = x;
        let x = self.rms_1.forward(x)?;
        let x = (self.attn.forward(
            &x,
            mask,
            seqlen_offsets,
            block_idx,
            kv_cache,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
            flash_params,
        )? + residual)?;
        let residual = &x;
        let x = (self.mlp.forward(
            &self.rms_2.forward(&x)?,
            scalings,
            global_scaling_weight,
            is_scaling_pass,
        )? + residual)?;
        Ok(x)
    }

    #[allow(clippy::too_many_arguments)]
    fn load(
        vb: ShardedVarBuilder,
        cfg: &Config,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        rope: Arc<Llama3RotaryEmbedding>,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let attn = CausalSelfAttention::load(
            vb.pp("self_attn"),
            cfg,
            lora_config,
            count,
            ord,
            mapper,
            layer_idx,
            loading_isq,
            rope,
            preload_adapters,
        )?;
        let mlp = Mlp::load(
            vb.pp("mlp"),
            cfg,
            lora_config,
            count,
            ord,
            mapper,
            layer_idx,
            loading_isq,
            preload_adapters,
        )?;
        let rms_1 = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let rms_2 = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        Ok(Self {
            rms_1,
            attn,
            rms_2,
            mlp,
        })
    }
}

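/// Llama model with X-LoRA support: token embeddings, decoder blocks, final RMSNorm,
/// LM head, and an optional `XLoraClassifier` that produces per-adapter scalings.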
pub struct XLoraLlama {
    wte: Embedding,
    blocks: Vec<Block>,
    ln_f: RmsNorm,
    lm_head: Arc<dyn LinearLayerLike + Send + Sync>,
    kv_cache: pipeline::EitherCache,
    device: Device,
    xlora_classifier: Option<XLoraClassifier>,
    dtype: DType,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    cfg: ModelConfigMetadata,
}

impl XLoraLlama {
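    /// Shared decoder pass used by both the scaling (classifier) pass and the real pass.
    /// Returns hidden states after the final RMSNorm; the LM head is applied by the callers.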
    #[allow(clippy::too_many_arguments)]
    fn inner_forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Option<Tensor>,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut x = self.wte.forward(input_ids)?;
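        // Full passes use the dedicated X-LoRA cache (reset to empty entries first when the
        // KV cache is disabled); other passes use the regular cache.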
        let mut cache = if is_full_pass {
            if no_kv_cache {
                let mut new_cache = Vec::new();
                for _ in 0..self.kv_cache.full().xlora_lock().len() {
                    new_cache.push(None);
                }

                self.kv_cache.full().xlora_lock().clone_from(&new_cache);
            }
            self.kv_cache.full().xlora_lock()
        } else {
            self.kv_cache.full().lock()
        };
        let mask = CausalMasker.make_causal_mask_matrix(
            input_ids,
            &*cache,
            x.dtype(),
            self.cfg.num_attn_heads,
        )?;
        for (block_idx, block) in self.blocks.iter().enumerate() {
            x = self.mapper.map(x, block_idx)?;
            x = block.forward(
                &x,
                &mask.clone().map(|m| m.to_device(x.device()).unwrap()),
                seqlen_offsets,
                block_idx,
                &mut cache,
                scalings.clone(),
                self.xlora_classifier
                    .as_ref()
                    .map(|classifier| classifier.get_global_scaling_weight())
                    .unwrap_or(1.0),
                is_scaling_pass,
                flash_params,
            )?;
        }
        let x = x.to_device(&self.device)?;
        self.ln_f.forward(&x)
    }

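    /// Top-level X-LoRA forward. With a classifier present, first computes per-adapter
    /// scalings via `get_scalings`, then runs the scaled pass and applies the LM head;
    /// without one, this is a plain (merged-)LoRA forward pass.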
    #[allow(clippy::too_many_arguments)]
    pub fn forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
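        // Three cases below: (1) classifier + no KV cache: rerun the full prompt with the
        // scalings; (2) classifier + KV cache: run only the new tokens with the scalings;
        // (3) no classifier: an ordinary forward pass without scalings.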
        if self.xlora_classifier.is_some() {
            let scalings = self.get_scalings(
                input_ids,
                input_ids_full,
                seqlen_offsets,
                seqlen_offsets_full,
                no_kv_cache,
                non_granular_state,
                &vec![usize::MAX; context_lens.len()],
                flash_params,
                flash_params_full,
            )?;

            if no_kv_cache {
                let mut res = self
                    .inner_forward(
                        input_ids_full,
                        seqlen_offsets_full,
                        Some(scalings),
                        true,
                        no_kv_cache,
                        None,
                        flash_params_full,
                    )?
                    .contiguous()?;
                if let Some(t) = self.lm_head.quantized_act_type() {
                    res = res.to_dtype(t)?;
                }
                extract_logits(
                    &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                    context_lens,
                )
            } else {
                let mut res = self
                    .inner_forward(
                        input_ids,
                        seqlen_offsets,
                        Some(scalings),
                        true,
                        no_kv_cache,
                        None,
                        flash_params,
                    )?
                    .contiguous()?;
                if let Some(t) = self.lm_head.quantized_act_type() {
                    res = res.to_dtype(t)?;
                }
                extract_logits(
                    &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                    context_lens,
                )
            }
        } else {
            let mut res = self
                .inner_forward(
                    input_ids,
                    seqlen_offsets,
                    None,
                    false,
                    no_kv_cache,
                    None,
                    flash_params,
                )?
                .contiguous()?;
            if let Some(t) = self.lm_head.quantized_act_type() {
                res = res.to_dtype(t)?;
            }
            extract_logits(
                &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                context_lens,
            )
        }
    }

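    /// Builds the model, attaching (X-)LoRA adapters as the weights load. When there is no
    /// X-LoRA classifier and no preloaded adapters, the LoRA weights are merged after loading.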
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        if let Some(ref quant_cfg) = &cfg.quantization_config {
            tracing::info!(
                "Using {} quantization: {}.",
                quant_cfg.name(),
                quant_cfg.get_bits_name(&vb)
            );
        }
        let mapper = normal_loading_metadata.mapper;
        let dtype = vb.dtype();
        let mut count = 0;

        let wte = embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb.pp("model.embed_tokens"), false),
            &cfg.quantization_config,
        )?;
        let lm_head = linear(
            cfg.hidden_size,
            cfg.vocab_size,
            mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
            mapper.set_nm_device(vb.pp("lm_head"), false),
            lora_config,
            &mut count,
            &xlora_ordering,
            preload_adapters,
        )?;
        if xlora_config.is_some() && lm_head.is_lora() {
            candle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA.");
        }
        let ln_f = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb.pp("model.norm"), false),
        )?;
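        // One rotary embedding per device location, so layers mapped to different devices
        // each reuse a local copy of the RoPE tables.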
        let mut ropes = HashMap::new();
        for i in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(i, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(Llama3RotaryEmbedding::new_llama3(
                    vb.dtype(),
                    cfg,
                    device,
                    is_gptx,
                )?),
            );
        }
        let mut blocks: Vec<_> = NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        )
        .into_iter()
        .map(|i| {
            let device = mapper
                .device_for(i, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            Block::load(
                vb.pp(format!("model.layers.{i}")),
                cfg,
                lora_config,
                &mut count,
                &xlora_ordering,
                &*mapper,
                i,
                normal_loading_metadata.loading_isq,
                rotary_emb,
                preload_adapters,
            )
            .expect("Failed to load block.")
        })
        .collect();
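        // Pure LoRA (no X-LoRA classifier, no preloaded adapters): fold the adapters into
        // the base weights so inference carries no per-token adapter overhead.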
        if xlora_config.is_none() && preload_adapters.is_none() {
            info!("Merging LoRA adapters.");
            for layer in blocks.iter_mut().tqdm() {
                Arc::get_mut(&mut layer.attn.k_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.attn.o_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.attn.q_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.attn.v_proj)
                    .unwrap()
                    .merge_weights()?;

                Arc::get_mut(&mut layer.mlp.c_fc1)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.mlp.c_fc2)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.mlp.c_proj)
                    .unwrap()
                    .merge_weights()?;
            }
        }

        Ok(Self {
            wte,
            blocks,
            ln_f,
            lm_head,
            kv_cache: EitherCache::Full(pipeline::Cache::new(cfg.num_hidden_layers, true)),
            device: normal_loading_metadata.real_device,
            xlora_classifier: xlora_config.map(|xlora_config| {
                XLoraClassifier::new(xlora_config, count, lora_config.len(), vb, false).unwrap()
            }),
            dtype,
            mapper,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_kv_heads: cfg.num_key_value_heads,
                num_attn_heads: cfg.num_attention_heads,
                sliding_window: None,
                k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
                v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            },
        })
    }
}

impl IsqModel for XLoraLlama {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((Arc::get_mut(&mut self.lm_head).unwrap().quant_inner(), None));
        for (i, layer) in self.blocks.iter_mut().enumerate() {
            tensors.push((
                Arc::get_mut(&mut layer.attn.q_proj).unwrap().quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.attn.k_proj).unwrap().quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.attn.v_proj).unwrap().quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.attn.o_proj).unwrap().quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.mlp.c_fc1).unwrap().quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.mlp.c_fc2).unwrap().quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.mlp.c_proj).unwrap().quant_inner(),
                Some(i),
            ));
        }
        (tensors, &*self.mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        panic!("Cannot generate UQFF for an adapter model.")
    }
}

impl NormalModel for XLoraLlama {
    fn forward(
        &self,
        _input_ids: &Tensor,
        _seqlen_offsets: &[usize],
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        _flash_params: &FlashParams,
    ) -> Result<Tensor> {
        unreachable!()
    }
    fn xlora_forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<crate::xlora_models::NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        self.forward(
            input_ids,
            input_ids_full,
            seqlen_offsets,
            seqlen_offsets_full,
            no_kv_cache,
            non_granular_state,
            context_lens,
            flash_params,
            flash_params_full,
        )
    }
    fn cache(&self) -> &super::EitherCache {
        &self.kv_cache
    }
    fn cache_mut(&mut self) -> &mut super::EitherCache {
        &mut self.kv_cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn is_xlora(&self) -> bool {
        true
    }
    fn max_seq_len(&self) -> usize {
        self.blocks[0].attn.max_seq_len
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
}

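// `ScalingsMaker` exposes what the X-LoRA scaling (classifier) pass needs: the cache, the
// classifier, and a forward that routes the candidate scalings through `inner_forward`.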
impl ScalingsMaker for XLoraLlama {
    fn dtype(&self) -> DType {
        self.dtype
    }
    fn get_cache(&self) -> &pipeline::EitherCache {
        &self.kv_cache
    }
    fn get_classifier(&self) -> &XLoraClassifier {
        self.xlora_classifier.as_ref().unwrap()
    }
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Tensor,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        _context_lens: &[usize],
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.inner_forward(
            input_ids,
            seqlen_offsets,
            Some(scalings),
            is_full_pass,
            no_kv_cache,
            is_scaling_pass,
            flash_params,
        )
    }
}

impl AnyMoeBaseModelMixin for XLoraLlama {}