#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use crate::{
    amoe::AnyMoeBaseModelMixin,
    attention::SdpaParams,
    layers::{Llama3RotaryEmbedding, Sdpa},
    lora::{linear_no_bias as linear, LinearLayerLike, LoraConfig, Ordering},
    paged_attention::ModelConfigMetadata,
    pipeline::{
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel,
    },
    utils::progress::NiceProgressBar,
};
use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{Embedding, Module};
use mistralrs_quant::{QuantMethod, ShardedVarBuilder};
use std::{collections::HashMap, sync::Arc};
use tqdm::Iter;
use tracing::info;

use crate::{
    device_map::DeviceMapper,
    layers::{embedding, CausalMasker, RmsNorm},
    models::llama::Config,
    pipeline::{self, extract_logits, LayerCaches, NormalLoadingMetadata, NormalModel},
};

use super::{classifier::XLoraClassifier, NonGranularState, ScalingsMaker, XLoraConfig};

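/// Causal self-attention for one decoder layer. Each projection is a
/// LoRA-wrapped linear (`LinearLayerLike`) so X-LoRA scalings can be applied
/// per adapter at runtime.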
struct CausalSelfAttention {
    q_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    k_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    v_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    o_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    num_attention_heads: usize,
    num_key_value_heads: usize,
    head_dim: usize,
    rotary_emb: Arc<Llama3RotaryEmbedding>,
    max_seq_len: usize,
    sdpa_params: SdpaParams,
}

impl CausalSelfAttention {
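    /// Projects Q/K/V through the LoRA-wrapped linears, applies rotary
    /// embeddings, updates the KV cache for `block_idx`, runs SDPA, and maps
    /// the result back through `o_proj`.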
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        x: &Tensor,
        mask: &Option<Tensor>,
        seqlen_offsets: &[usize],
        block_idx: usize,
        kv_cache: &mut LayerCaches,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, seq_len, hidden_size) = x.dims3()?;

        let original_dtype = x.dtype();
        let mut x = x.clone();
        if let Some(t) = self.q_proj.quantized_act_type() {
            x = x.to_dtype(t)?;
        }
        let mut q = self.q_proj.lora_forward(
            &x,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let mut k = self.k_proj.lora_forward(
            &x,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let mut v = self.v_proj.lora_forward(
            &x,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.q_proj.quantized_act_type().is_some() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

        let (q, k, v) = if seq_len != 1 {
            let q = q
                .reshape((b_sz, seq_len, self.num_attention_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            // Decode step: seq_len == 1, so the reshape already yields (b, heads, 1, head_dim).
            let q = q.reshape((b_sz, self.num_attention_heads, seq_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.num_key_value_heads, seq_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.num_key_value_heads, seq_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, k) = self.rotary_emb.forward(&q, &k, seqlen_offsets)?;

        let (k, v) =
            crate::pipeline::Cache::update_kv_cache(&mut kv_cache[block_idx], k, v, false)?;

        let y = Sdpa.run_attention(
            &q,
            &k,
            &v,
            mask.clone().as_ref(),
            Some(flash_params),
            &self.sdpa_params,
        )?;

        // `y` is already back in (b_sz, seq_len, hidden_size) layout here, so it can be
        // fed straight into the output projection.
        let mut y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, hidden_size])?;
        if let Some(t) = self.q_proj.quantized_act_type() {
            y = y.to_dtype(t)?;
        }
        let mut res = self.o_proj.lora_forward(
            &y,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.q_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }

    #[allow(clippy::too_many_arguments)]
    fn load(
        vb: ShardedVarBuilder,
        cfg: &Config,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        rope: Arc<Llama3RotaryEmbedding>,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let size_in = cfg.hidden_size;
        let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
        let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
        let q_proj = linear(
            size_in,
            size_q,
            mapper.set_device(layer_idx, vb.pp("q_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("q_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let k_proj = linear(
            size_in,
            size_kv,
            mapper.set_device(layer_idx, vb.pp("k_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("k_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let v_proj = linear(
            size_in,
            size_kv,
            mapper.set_device(layer_idx, vb.pp("v_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("v_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let o_proj = linear(
            size_q,
            size_in,
            mapper.set_device(layer_idx, vb.pp("o_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("o_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            num_attention_heads: cfg.num_attention_heads,
            num_key_value_heads: cfg.num_key_value_heads,
            head_dim: cfg.hidden_size / cfg.num_attention_heads,
            rotary_emb: rope,
            max_seq_len: cfg.max_position_embeddings,
            sdpa_params: SdpaParams {
                n_kv_groups: cfg.num_attention_heads / cfg.num_key_value_heads,
                softcap: None,
                softmax_scale: 1.0 / ((cfg.hidden_size / cfg.num_attention_heads) as f32).sqrt(),
                sliding_window: None,
            },
        })
    }
}

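/// SwiGLU feed-forward block: `silu(c_fc1(x)) * c_fc2(x)` fed into `c_proj`,
/// with all three projections LoRA-wrapped.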
#[derive(Clone)]
struct Mlp {
    c_fc1: Arc<dyn LinearLayerLike + Send + Sync>,
    c_fc2: Arc<dyn LinearLayerLike + Send + Sync>,
    c_proj: Arc<dyn LinearLayerLike + Send + Sync>,
}

impl Mlp {
    fn forward(
        &self,
        x: &Tensor,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
    ) -> Result<Tensor> {
        let original_dtype = x.dtype();
        let mut x = x.clone();
        if let Some(t) = self.c_fc1.quantized_act_type() {
            x = x.to_dtype(t)?;
        }
        let x = (candle_nn::ops::silu(&self.c_fc1.lora_forward(
            &x,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?)? * self.c_fc2.lora_forward(
            &x,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?)?;
        let mut res = self.c_proj.lora_forward(
            &x,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.c_fc1.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }

    #[allow(clippy::too_many_arguments)]
    fn load(
        vb: ShardedVarBuilder,
        cfg: &Config,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let h_size = cfg.hidden_size;
        let i_size = cfg.intermediate_size;
        let c_fc1 = linear(
            h_size,
            i_size,
            mapper.set_device(layer_idx, vb.pp("gate_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("gate_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let c_fc2 = linear(
            h_size,
            i_size,
            mapper.set_device(layer_idx, vb.pp("up_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("up_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let c_proj = linear(
            i_size,
            h_size,
            mapper.set_device(layer_idx, vb.pp("down_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("down_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        Ok(Self {
            c_fc1,
            c_fc2,
            c_proj,
        })
    }
}

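/// A single decoder layer: pre-norm attention with a residual connection,
/// then pre-norm MLP with a residual connection.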
struct Block {
    rms_1: RmsNorm,
    attn: CausalSelfAttention,
    rms_2: RmsNorm,
    mlp: Mlp,
}

impl Block {
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        x: &Tensor,
        mask: &Option<Tensor>,
        seqlen_offsets: &[usize],
        block_idx: usize,
        kv_cache: &mut LayerCaches,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = x;
        let x = self.rms_1.forward(x)?;
        let x = (self.attn.forward(
            &x,
            mask,
            seqlen_offsets,
            block_idx,
            kv_cache,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
            flash_params,
        )? + residual)?;
        let residual = &x;
        let x = (self.mlp.forward(
            &self.rms_2.forward(&x)?,
            scalings,
            global_scaling_weight,
            is_scaling_pass,
        )? + residual)?;
        Ok(x)
    }

    #[allow(clippy::too_many_arguments)]
    fn load(
        vb: ShardedVarBuilder,
        cfg: &Config,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        rope: Arc<Llama3RotaryEmbedding>,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let attn = CausalSelfAttention::load(
            vb.pp("self_attn"),
            cfg,
            lora_config,
            count,
            ord,
            mapper,
            layer_idx,
            loading_isq,
            rope,
            preload_adapters,
        )?;
        let mlp = Mlp::load(
            vb.pp("mlp"),
            cfg,
            lora_config,
            count,
            ord,
            mapper,
            layer_idx,
            loading_isq,
            preload_adapters,
        )?;
        let rms_1 = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let rms_2 = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        Ok(Self {
            rms_1,
            attn,
            rms_2,
            mlp,
        })
    }
}

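/// Llama decoder with X-LoRA support: LoRA-wrapped projections in every layer
/// plus an optional `XLoraClassifier` that produces the per-adapter scalings.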
pub struct XLoraLlama {
    wte: Embedding,
    blocks: Vec<Block>,
    ln_f: RmsNorm,
    lm_head: Arc<dyn LinearLayerLike + Send + Sync>,
    kv_cache: pipeline::EitherCache,
    device: Device,
    xlora_classifier: Option<XLoraClassifier>,
    dtype: DType,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    cfg: ModelConfigMetadata,
}

impl XLoraLlama {
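    /// Shared forward pass over the transformer stack. Full passes use the
    /// dedicated X-LoRA KV cache; otherwise the regular cache is used.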
    #[allow(clippy::too_many_arguments)]
    fn inner_forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Option<Tensor>,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut x = self.wte.forward(input_ids)?;
        let mut cache = if is_full_pass {
            if no_kv_cache {
                // Reset the dedicated X-LoRA cache before a full pass that runs without KV caching.
                let mut new_cache = Vec::new();
                for _ in 0..self.kv_cache.full().xlora_lock().len() {
                    new_cache.push(None);
                }

                self.kv_cache.full().xlora_lock().clone_from(&new_cache);
            }
            self.kv_cache.full().xlora_lock()
        } else {
            self.kv_cache.full().lock()
        };
        let mask = CausalMasker.make_causal_mask_matrix(
            input_ids,
            &*cache,
            x.dtype(),
            self.cfg.num_attn_heads,
        )?;
        for (block_idx, block) in self.blocks.iter().enumerate() {
            x = self.mapper.map(x, block_idx)?;
            x = block.forward(
                &x,
                &mask.clone().map(|m| m.to_device(x.device()).unwrap()),
                seqlen_offsets,
                block_idx,
                &mut cache,
                scalings.clone(),
                self.xlora_classifier
                    .as_ref()
                    .map(|classifier| classifier.get_global_scaling_weight())
                    .unwrap_or(1.0),
                is_scaling_pass,
                flash_params,
            )?;
        }
        let x = x.to_device(&self.device)?;
        self.ln_f.forward(&x)
    }

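    /// Top-level X-LoRA forward. With a classifier, first runs the scaling
    /// pass to obtain per-adapter scalings, then the main pass with those
    /// scalings applied; without one, this is a plain forward pass.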
    #[allow(clippy::too_many_arguments)]
    pub fn forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        if self.xlora_classifier.is_some() {
            let scalings = self.get_scalings(
                input_ids,
                input_ids_full,
                seqlen_offsets,
                seqlen_offsets_full,
                no_kv_cache,
                non_granular_state,
                &vec![usize::MAX; context_lens.len()],
                flash_params,
                flash_params_full,
            )?;

            if no_kv_cache {
                // Without a KV cache, run the full input sequence with the computed scalings.
                let mut res = self
                    .inner_forward(
                        input_ids_full,
                        seqlen_offsets_full,
                        Some(scalings),
                        true,
                        no_kv_cache,
                        None,
                        flash_params_full,
                    )?
                    .contiguous()?;
                if let Some(t) = self.lm_head.quantized_act_type() {
                    res = res.to_dtype(t)?;
                }
                extract_logits(
                    &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                    context_lens,
                )
            } else {
                // With a KV cache, only the new tokens need to be processed.
                let mut res = self
                    .inner_forward(
                        input_ids,
                        seqlen_offsets,
                        Some(scalings),
                        true,
                        no_kv_cache,
                        None,
                        flash_params,
                    )?
                    .contiguous()?;
                if let Some(t) = self.lm_head.quantized_act_type() {
                    res = res.to_dtype(t)?;
                }
                extract_logits(
                    &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                    context_lens,
                )
            }
        } else {
            // No X-LoRA classifier: plain (merged) LoRA forward pass.
            let mut res = self
                .inner_forward(
                    input_ids,
                    seqlen_offsets,
                    None,
                    false,
                    no_kv_cache,
                    None,
                    flash_params,
                )?
                .contiguous()?;
            if let Some(t) = self.lm_head.quantized_act_type() {
                res = res.to_dtype(t)?;
            }
            extract_logits(
                &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                context_lens,
            )
        }
    }

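    /// Loads the embeddings, LoRA-wrapped layers, and (if configured) the
    /// X-LoRA classifier. When there is no X-LoRA config and no preloaded
    /// adapters, the LoRA weights are merged into the base layers up front.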
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        if let Some(ref quant_cfg) = &cfg.quantization_config {
            tracing::info!(
                "Using {} quantization: {}.",
                quant_cfg.name(),
                quant_cfg.get_bits_name(&vb)
            );
        }
        let mapper = normal_loading_metadata.mapper;
        let dtype = vb.dtype();
        let mut count = 0;

        let wte = embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb.pp("model.embed_tokens"), false),
            &cfg.quantization_config,
        )?;
        let lm_head = linear(
            cfg.hidden_size,
            cfg.vocab_size,
            mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
            mapper.set_nm_device(vb.pp("lm_head"), false),
            lora_config,
            &mut count,
            &xlora_ordering,
            preload_adapters,
        )?;
        if xlora_config.is_some() && lm_head.is_lora() {
            candle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA.");
        }
        let ln_f = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb.pp("model.norm"), false),
        )?;
        let mut ropes = HashMap::new();
        for i in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(i, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(Llama3RotaryEmbedding::new_llama3(
                    vb.dtype(),
                    cfg,
                    device,
                    is_gptx,
                )?),
            );
        }
        let mut blocks: Vec<_> = NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        )
        .into_iter()
        .map(|i| {
            let device = mapper
                .device_for(i, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            Block::load(
                vb.pp(format!("model.layers.{i}")),
                cfg,
                lora_config,
                &mut count,
                &xlora_ordering,
                &*mapper,
                i,
                normal_loading_metadata.loading_isq,
                rotary_emb,
                preload_adapters,
            )
            .expect("Failed to load block.")
        })
        .collect();
        if xlora_config.is_none() && preload_adapters.is_none() {
            // Merge the LoRA weights into the base layers so inference runs on plain linears.
            info!("Merging LoRA adapters.");
            for layer in blocks.iter_mut().tqdm() {
                Arc::get_mut(&mut layer.attn.k_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.attn.o_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.attn.q_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.attn.v_proj)
                    .unwrap()
                    .merge_weights()?;

                Arc::get_mut(&mut layer.mlp.c_fc1)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.mlp.c_fc2)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.mlp.c_proj)
                    .unwrap()
                    .merge_weights()?;
            }
        }

        Ok(Self {
            wte,
            blocks,
            ln_f,
            lm_head,
            kv_cache: EitherCache::Full(pipeline::Cache::new(cfg.num_hidden_layers, true)),
            device: normal_loading_metadata.real_device,
            xlora_classifier: xlora_config.map(|xlora_config| {
                XLoraClassifier::new(xlora_config, count, lora_config.len(), vb, false).unwrap()
            }),
            dtype,
            mapper,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_kv_heads: cfg.num_key_value_heads,
                num_attn_heads: cfg.num_attention_heads,
                sliding_window: None,
                k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
                v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            },
        })
    }
}

impl IsqModel for XLoraLlama {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((Arc::get_mut(&mut self.lm_head).unwrap().quant_inner(), None));
        for (i, layer) in self.blocks.iter_mut().enumerate() {
            tensors.push((
                Arc::get_mut(&mut layer.attn.q_proj).unwrap().quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.attn.k_proj).unwrap().quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.attn.v_proj).unwrap().quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.attn.o_proj).unwrap().quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.mlp.c_fc1).unwrap().quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.mlp.c_fc2).unwrap().quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.mlp.c_proj).unwrap().quant_inner(),
                Some(i),
            ));
        }
        (tensors, &*self.mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        panic!("Cannot generate UQFF for an adapter model.")
    }
}

impl NormalModel for XLoraLlama {
    fn forward(
        &self,
        _input_ids: &Tensor,
        _seqlen_offsets: &[usize],
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        _flash_params: &FlashParams,
    ) -> Result<Tensor> {
        unreachable!()
    }
    fn xlora_forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<crate::xlora_models::NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        self.forward(
            input_ids,
            input_ids_full,
            seqlen_offsets,
            seqlen_offsets_full,
            no_kv_cache,
            non_granular_state,
            context_lens,
            flash_params,
            flash_params_full,
        )
    }
    fn cache(&self) -> &super::EitherCache {
        &self.kv_cache
    }
    fn cache_mut(&mut self) -> &mut super::EitherCache {
        &mut self.kv_cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn is_xlora(&self) -> bool {
        true
    }
    fn max_seq_len(&self) -> usize {
        self.blocks[0].attn.max_seq_len
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
}

impl ScalingsMaker for XLoraLlama {
    fn dtype(&self) -> DType {
        self.dtype
    }
    fn get_cache(&self) -> &pipeline::EitherCache {
        &self.kv_cache
    }
    fn get_classifier(&self) -> &XLoraClassifier {
        self.xlora_classifier.as_ref().unwrap()
    }
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Tensor,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        _context_lens: &[usize],
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.inner_forward(
            input_ids,
            seqlen_offsets,
            Some(scalings),
            is_full_pass,
            no_kv_cache,
            is_scaling_pass,
            flash_params,
        )
    }
}

impl AnyMoeBaseModelMixin for XLoraLlama {}