#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use crate::{
    amoe::AnyMoeBaseModelMixin,
    attention::SdpaParams,
    layers::{self, Activation, Sdpa},
    lora::{linear_no_bias, LinearLayerLike, LoraConfig, Ordering},
    paged_attention::ModelConfigMetadata,
    pipeline::{
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel, NormalLoadingMetadata,
    },
    utils::progress::NiceProgressBar,
};
use candle_core::{DType, Device, Module, Result, Tensor, D};
use mistralrs_quant::{QuantMethod, ShardedVarBuilder};
use std::{collections::HashMap, sync::Arc};
use tqdm::Iter;
use tracing::info;

use crate::{
    device_map::DeviceMapper,
    layers::{CausalMasker, PhiRotaryEmbedding, RmsNorm},
    models::phi3::Config,
    pipeline::{extract_logits, NormalModel},
};

use crate::pipeline::Cache;

use super::{classifier::XLoraClassifier, NonGranularState, ScalingsMaker, XLoraConfig};

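/// Multi-head attention with a fused QKV projection, rotary position embeddings,
/// and optional sliding-window masking. Both projections are LoRA-wrapped
/// (`LinearLayerLike`) so per-token X-LoRA scalings can be applied at runtime.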
struct Attention {
    qkv_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    o_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    rotary_emb: Arc<PhiRotaryEmbedding>,
    sliding_window: Option<usize>,
    sdpa_params: SdpaParams,
}

impl Attention {
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<PhiRotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        let head_dim = cfg.head_dim();
        let op_size = num_heads * head_dim + 2 * num_kv_heads * head_dim;
        let qkv_proj = linear_no_bias(
            cfg.hidden_size,
            op_size,
            mapper.set_device(layer_idx, vb.pp("qkv_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("qkv_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let o_proj = linear_no_bias(
            num_heads * head_dim,
            cfg.hidden_size,
            mapper.set_device(layer_idx, vb.pp("o_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("o_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        Ok(Self {
            qkv_proj,
            o_proj,
            rotary_emb,
            num_heads,
            num_kv_heads,
            head_dim,
            sliding_window: cfg.sliding_window,
            sdpa_params: SdpaParams {
                n_kv_groups: num_heads / num_kv_heads,
                use_flash_attn: cfg.use_flash_attn,
                softcap: None,
                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                sliding_window: cfg.sliding_window,
            },
        })
    }

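    /// Project to QKV, apply RoPE, update the (sliding-window) KV cache, run
    /// scaled-dot-product attention, then project back through `o_proj`.
    /// Activations are cast to the quantized activation dtype when the
    /// projections are quantized, and restored to the original dtype afterwards.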
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
        kv_cache: &mut Option<(Tensor, Tensor)>,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.qkv_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut qkv = self.qkv_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.qkv_proj.quantized_act_type().is_some() {
            qkv = qkv.to_dtype(original_dtype)?;
        }
        let query_pos = self.num_heads * self.head_dim;
        let q = qkv.narrow(D::Minus1, 0, query_pos)?;
        let k = qkv.narrow(D::Minus1, query_pos, self.num_kv_heads * self.head_dim)?;
        let v = qkv.narrow(
            D::Minus1,
            query_pos + self.num_kv_heads * self.head_dim,
            self.num_kv_heads * self.head_dim,
        )?;

        let (q, k, v) = if q_len != 1 {
            let q = q
                .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, k) = self
            .rotary_emb
            .forward(&q, &k, seqlen_offsets, position_ids)?;

        let (k, v, attn_mask) = Cache::update_kv_cache_sliding_window(
            kv_cache,
            k,
            v,
            attention_mask,
            self.sliding_window,
            true,
        )?;

        let mut attn_output = Sdpa.run_attention(
            &q,
            &k,
            &v,
            attn_mask.as_ref(),
            Some(flash_params),
            &self.sdpa_params,
        )?;

        if let Some(t) = self.qkv_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        let mut res = self.o_proj.lora_forward(
            &attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.qkv_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

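/// Gated feed-forward block. `gate_up_proj` fuses the gate and up projections
/// into a single matmul whose output is split in half before the activation.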
#[derive(Clone)]
struct Mlp {
    gate_up_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    down_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    act_fn: Activation,
    i_size: usize,
}

impl Mlp {
    #[allow(clippy::too_many_arguments)]
    fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let hidden_size = cfg.hidden_size;
        let i_size = cfg.intermediate_size;
        let gate_up_proj = linear_no_bias(
            hidden_size,
            2 * i_size,
            mapper.set_device(layer_idx, vb.pp("gate_up_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("gate_up_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let down_proj = linear_no_bias(
            i_size,
            hidden_size,
            mapper.set_device(layer_idx, vb.pp("down_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("down_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        Ok(Self {
            gate_up_proj,
            down_proj,
            act_fn: cfg.hidden_act,
            i_size,
        })
    }

    fn forward(
        &self,
        xs: &Tensor,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
    ) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.gate_up_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let up_states = self.gate_up_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let gate = up_states.narrow(D::Minus1, 0, self.i_size)?;
        let up_states = up_states.narrow(D::Minus1, self.i_size, self.i_size)?;
        let up_states = (up_states * gate.apply(&self.act_fn))?;
        let mut res = self.down_proj.lora_forward(
            &up_states,
            scalings,
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.gate_up_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

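/// A single Phi-3 transformer block: pre-norm attention and MLP sublayers,
/// each wrapped in a residual connection.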
struct DecoderLayer {
    self_attn: Attention,
    mlp: Mlp,
    input_layernorm: RmsNorm,
    post_attention_layernorm: RmsNorm,
}

impl DecoderLayer {
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<PhiRotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let self_attn = Attention::new(
            rotary_emb,
            cfg,
            vb.pp("self_attn"),
            lora_config,
            count,
            ord,
            mapper,
            layer_idx,
            loading_isq,
            preload_adapters,
        )?;
        let mlp = Mlp::new(
            cfg,
            vb.pp("mlp"),
            lora_config,
            count,
            ord,
            mapper,
            layer_idx,
            loading_isq,
            preload_adapters,
        )?;
        let input_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let post_attention_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        Ok(Self {
            self_attn,
            mlp,
            input_layernorm,
            post_attention_layernorm,
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
        kv_cache: &mut Option<(Tensor, Tensor)>,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = self.input_layernorm.forward(xs)?;
        let xs = self.self_attn.forward(
            &xs,
            attention_mask,
            seqlen_offsets,
            position_ids,
            kv_cache,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
            flash_params,
        )?;
        let xs = (xs + residual)?;
        let residual = &xs;
        let xs = self.mlp.forward(
            &xs.apply(&self.post_attention_layernorm)?,
            scalings,
            global_scaling_weight,
            is_scaling_pass,
        );
        residual + xs
    }
}

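/// X-LoRA Phi-3 model: token embedding, a stack of decoder layers, a final
/// RMSNorm, and a LoRA-wrapped `lm_head`, plus the optional X-LoRA classifier
/// that produces per-layer adapter scalings.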
pub struct Model {
    embed_tokens: candle_nn::Embedding,
    layers: Vec<DecoderLayer>,
    norm: RmsNorm,
    lm_head: Arc<dyn LinearLayerLike + Send + Sync>,
    dtype: DType,
    device: Device,
    cache: EitherCache,
    max_seq_len: usize,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    xlora_classifier: Option<XLoraClassifier>,
    sliding_window: Option<usize>,
    cfg: ModelConfigMetadata,
}

impl Model {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        _is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        if let Some(ref quant_cfg) = &cfg.quantization_config {
            tracing::info!(
                "Using {} quantization: {}.",
                quant_cfg.quant_method.to_string(),
                quant_cfg.get_bits_name(&vb)
            );
        }
        let mapper = normal_loading_metadata.mapper;
        let vb_m = vb.pp("model");

        let embed_tokens = layers::embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
        )?;
        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        let vb_l = vb_m.pp("layers");
        let mut ropes = HashMap::new();
        for layer_idx in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(PhiRotaryEmbedding::new(vb.dtype(), cfg.clone(), device)?),
            );
        }
        let mut count = 0;
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        ) {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            let layer = DecoderLayer::new(
                rotary_emb.clone(),
                cfg,
                vb_l.pp(layer_idx),
                lora_config,
                &mut count,
                &xlora_ordering,
                &*mapper,
                layer_idx,
                normal_loading_metadata.loading_isq,
                preload_adapters,
            )?;
            layers.push(layer)
        }
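        // Plain LoRA (no X-LoRA classifier and no preloaded adapters): fold
        // the adapter weights into the base weights once, up front.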
        if xlora_config.is_none() && preload_adapters.is_none() {
            info!("Merging LoRA adapters.");
            for layer in layers.iter_mut().tqdm() {
                Arc::get_mut(&mut layer.self_attn.qkv_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.self_attn.o_proj)
                    .unwrap()
                    .merge_weights()?;

                Arc::get_mut(&mut layer.mlp.down_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.mlp.gate_up_proj)
                    .unwrap()
                    .merge_weights()?;
            }
        }
        let norm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb_m.pp("norm"), false),
        )?;
        let lm_head = linear_no_bias(
            cfg.hidden_size,
            cfg.vocab_size,
            mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
            mapper.set_nm_device(vb.pp("lm_head"), false),
            lora_config,
            &mut count,
            &xlora_ordering,
            preload_adapters,
        )?;
        if xlora_config.is_some() && lm_head.is_lora() {
            candle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA.");
        }
        Ok(Self {
            embed_tokens,
            layers,
            norm,
            lm_head,
            device: normal_loading_metadata.real_device,
            dtype: vb.dtype(),
            cache: EitherCache::Full(Cache::new(cfg.num_hidden_layers, true)),
            max_seq_len: cfg.max_position_embeddings,
            mapper,
            xlora_classifier: xlora_config.map(|xlora_config| {
                XLoraClassifier::new(xlora_config, count, lora_config.len(), vb, false).unwrap()
            }),
            sliding_window: cfg.sliding_window,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_kv_heads: cfg.num_key_value_heads,
                num_attn_heads: cfg.num_attention_heads,
                sliding_window: cfg.sliding_window,
                k_head_dim: cfg.head_dim(),
                v_head_dim: cfg.head_dim(),
            },
        })
    }

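    /// Shared forward pass over the decoder stack: embed the tokens, select
    /// the X-LoRA or regular KV cache, build the sliding-window causal mask,
    /// run each layer on its mapped device, and apply the final RMSNorm.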
    #[allow(clippy::too_many_arguments)]
    fn inner_forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
        scalings: Option<Tensor>,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut xs = self.embed_tokens.forward(input_ids)?;
        let mut cache = if is_full_pass {
            if no_kv_cache {
                let mut new_cache = Vec::new();
                for _ in 0..self.cache.full().xlora_lock().len() {
                    new_cache.push(None);
                }

                self.cache.full().xlora_lock().clone_from(&new_cache);
            }
            self.cache.full().xlora_lock()
        } else {
            self.cache.full().lock()
        };
        let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
            input_ids,
            &*cache,
            self.sliding_window,
            xs.dtype(),
            self.cfg.num_attn_heads,
        )?;

        for (i, layer) in self.layers.iter().enumerate() {
            xs = self.mapper.map(xs, i)?;
            xs = layer.forward(
                &xs,
                attention_mask
                    .as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                position_ids,
                &mut cache[i],
                scalings.clone(),
                self.xlora_classifier
                    .as_ref()
                    .map(|classifier| classifier.get_global_scaling_weight())
                    .unwrap_or(1.0),
                is_scaling_pass,
                flash_params,
            )?
        }
        let xs = xs.to_device(&self.device)?;
        xs.apply(&self.norm)
    }

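    /// Public X-LoRA entry point. With a classifier present this runs the
    /// two-pass scheme: first compute adapter scalings via `get_scalings`,
    /// then rerun the model with those scalings before projecting through
    /// `lm_head`. Without a classifier it falls back to a single plain pass.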
    #[allow(clippy::too_many_arguments)]
    pub fn forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        if self.xlora_classifier.is_some() {
            let scalings = self.get_scalings(
                input_ids,
                input_ids_full,
                seqlen_offsets,
                seqlen_offsets_full,
                no_kv_cache,
                non_granular_state,
                &position_ids,
                flash_params,
                flash_params_full,
            )?;

            if no_kv_cache {
                let mut res = self
                    .inner_forward(
                        input_ids_full,
                        seqlen_offsets_full,
                        &position_ids,
                        Some(scalings),
                        true,
                        no_kv_cache,
                        None,
                        flash_params_full,
                    )?
                    .contiguous()?;
                if let Some(t) = self.lm_head.quantized_act_type() {
                    res = res.to_dtype(t)?;
                }
                extract_logits(
                    &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                    context_lens,
                )
            } else {
                let mut res = self
                    .inner_forward(
                        input_ids,
                        seqlen_offsets,
                        &position_ids,
                        Some(scalings),
                        true,
                        no_kv_cache,
                        None,
                        flash_params,
                    )?
                    .contiguous()?;
                if let Some(t) = self.lm_head.quantized_act_type() {
                    res = res.to_dtype(t)?;
                }
                extract_logits(
                    &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                    context_lens,
                )
            }
        } else {
            let mut res = self
                .inner_forward(
                    input_ids,
                    seqlen_offsets,
                    &position_ids,
                    None,
                    false,
                    no_kv_cache,
                    None,
                    flash_params,
                )?
                .contiguous()?;
            if let Some(t) = self.lm_head.quantized_act_type() {
                res = res.to_dtype(t)?;
            }
            extract_logits(
                &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                context_lens,
            )
        }
    }
}

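// In-situ quantization support: expose every quantizable linear layer
// (lm_head plus each layer's attention and MLP projections) together with
// the device mapper.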
impl IsqModel for Model {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((Arc::get_mut(&mut self.lm_head).unwrap().quant_inner(), None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.qkv_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.o_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.mlp.gate_up_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.mlp.down_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
        }
        (tensors, &*self.mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        panic!("Cannot generate UQFF for an adapter model.")
    }
}

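// `NormalModel::forward` is intentionally `unreachable!()`: X-LoRA models are
// driven through `xlora_forward`, which delegates to the inherent two-pass
// `forward` above.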
impl NormalModel for Model {
    fn forward(
        &self,
        _input_ids: &Tensor,
        _seqlen_offsets: &[usize],
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        _flash_params: &FlashParams,
    ) -> Result<Tensor> {
        unreachable!()
    }
    fn xlora_forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<crate::xlora_models::NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        self.forward(
            input_ids,
            input_ids_full,
            seqlen_offsets,
            seqlen_offsets_full,
            no_kv_cache,
            non_granular_state,
            context_lens,
            position_ids,
            flash_params,
            flash_params_full,
        )
    }
    fn cache(&self) -> &EitherCache {
        &self.cache
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        &mut self.cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn is_xlora(&self) -> bool {
        true
    }
    fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }
    fn activate_adapters(&mut self, adapter_names: Vec<String>) -> Result<usize> {
        if self.xlora_classifier.is_some() {
            candle_core::bail!("Adapter activation is not supported for X-LoRA models as the adapter set must remain the same.");
        }
        let mut sum = 0;
        for layer in self.layers.iter_mut() {
            sum += Arc::get_mut(&mut layer.self_attn.qkv_proj)
                .unwrap()
                .activate(&adapter_names)?;
            sum += Arc::get_mut(&mut layer.self_attn.o_proj)
                .unwrap()
                .activate(&adapter_names)?;

            sum += Arc::get_mut(&mut layer.mlp.down_proj)
                .unwrap()
                .activate(&adapter_names)?;
            sum += Arc::get_mut(&mut layer.mlp.gate_up_proj)
                .unwrap()
                .activate(&adapter_names)?;
        }
        Ok(sum)
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
}

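// Drives the X-LoRA scaling (classifier) pass. Note that `context_lens` is
// passed through in the `position_ids` slot of `inner_forward` for this pass.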
impl ScalingsMaker for Model {
    fn dtype(&self) -> DType {
        self.dtype
    }
    fn get_cache(&self) -> &EitherCache {
        &self.cache
    }
    fn get_classifier(&self) -> &XLoraClassifier {
        self.xlora_classifier.as_ref().unwrap()
    }
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Tensor,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        context_lens: &[usize],
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.inner_forward(
            input_ids,
            seqlen_offsets,
            context_lens,
            Some(scalings),
            is_full_pass,
            no_kv_cache,
            is_scaling_pass,
            flash_params,
        )
    }
}

impl AnyMoeBaseModelMixin for Model {}