1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use std::{collections::HashMap, sync::Arc};
4
5use crate::{
6 amoe::AnyMoeBaseModelMixin,
7 attention::SdpaParams,
8 layers::{Activation, RotaryEmbedding, Sdpa},
9 lora::{linear, LinearLayerLike, LoraConfig, Ordering},
10 paged_attention::ModelConfigMetadata,
11 pipeline::{
12 text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
13 EitherCache, IsqModel, NormalLoadingMetadata,
14 },
15 utils::progress::NiceProgressBar,
16};
17use candle_core::{DType, Device, Result, Tensor};
23use candle_nn::{Embedding, LayerNorm};
24use mistralrs_quant::{QuantMethod, ShardedVarBuilder};
25use tqdm::Iter;
26use tracing::info;
27
28use crate::{
29 device_map::DeviceMapper,
30 layers::{embedding, layer_norm, CausalMasker},
31 models::phi2::Config,
32 pipeline::{extract_logits, NormalModel},
33};
34
35use super::{classifier::XLoraClassifier, Cache, NonGranularState, ScalingsMaker, XLoraConfig};
36
/// Feed-forward block of a decoder layer: two LoRA-capable linear layers with
/// an activation applied between them (see `MLP::forward`).
#[derive(Clone)]
#[allow(clippy::upper_case_acronyms)]
struct MLP {
    /// First projection, constructed as `hidden_size -> intermediate_size`.
    fc1: Arc<dyn LinearLayerLike + Send + Sync>,
    /// Second projection, constructed as `intermediate_size -> hidden_size`.
    fc2: Arc<dyn LinearLayerLike + Send + Sync>,
    /// Activation applied to the output of `fc1`; copied from `cfg.hidden_act`.
    act: Activation,
}
44
45impl MLP {
46 #[allow(clippy::too_many_arguments)]
47 fn new(
48 cfg: &Config,
49 vb: ShardedVarBuilder,
50 lora_config: &[((String, String), LoraConfig)],
51 count: &mut usize,
52 ord: &Ordering,
53 mapper: &dyn DeviceMapper,
54 layer_idx: usize,
55 loading_isq: bool,
56 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
57 ) -> Result<Self> {
58 let fc1 = linear(
59 cfg.hidden_size,
60 cfg.intermediate_size,
61 mapper.set_device(layer_idx, vb.pp("fc1"), loading_isq),
62 mapper.set_device(layer_idx, vb.pp("fc1"), false),
63 lora_config,
64 count,
65 ord,
66 preload_adapters,
67 )?;
68 let fc2 = linear(
69 cfg.intermediate_size,
70 cfg.hidden_size,
71 mapper.set_device(layer_idx, vb.pp("fc2"), loading_isq),
72 mapper.set_device(layer_idx, vb.pp("fc2"), false),
73 lora_config,
74 count,
75 ord,
76 preload_adapters,
77 )?;
78 Ok(Self {
79 fc1,
80 fc2,
81 act: cfg.hidden_act,
84 })
85 }
86
87 fn forward(
88 &self,
89 xs: &Tensor,
90 scalings: Option<Tensor>,
91 global_scaling_weight: f64,
92 is_scaling_pass: Option<f64>,
93 ) -> Result<Tensor> {
94 let original_dtype = xs.dtype();
95 let mut xs = xs.clone();
96 if let Some(t) = self.fc1.quantized_act_type() {
97 xs = xs.to_dtype(t)?;
98 }
99 let mut res = self.fc2.lora_forward(
100 &self
101 .fc1
102 .lora_forward(
103 &xs,
104 scalings.clone(),
105 global_scaling_weight,
106 is_scaling_pass,
107 )?
108 .apply(&self.act)?,
109 scalings,
110 global_scaling_weight,
111 is_scaling_pass,
112 )?;
113 if self.fc1.quantized_act_type().is_some() {
114 res = res.to_dtype(original_dtype)?;
115 }
116 Ok(res)
117 }
118}
119
/// Self-attention block of a decoder layer, with LoRA-capable projections,
/// optional per-head layer norms on Q/K and a rotary embedding.
struct Attention {
    /// Query projection, constructed as `hidden_size -> num_heads * head_dim`.
    q_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    /// Key projection, constructed as `hidden_size -> num_kv_heads * head_dim`.
    k_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    /// Value projection, constructed as `hidden_size -> num_kv_heads * head_dim`.
    v_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    /// Output projection, constructed as `num_heads * head_dim -> hidden_size`.
    dense: Arc<dyn LinearLayerLike + Send + Sync>,
    /// Layer norm applied to the query projection; `Some` only when
    /// `cfg.qk_layernorm` is set.
    q_layernorm: Option<LayerNorm>,
    /// Layer norm applied to the key projection; `Some` only when
    /// `cfg.qk_layernorm` is set.
    k_layernorm: Option<LayerNorm>,
    /// Rotary embedding handed in by the caller of `Attention::new`.
    rotary_emb: Arc<RotaryEmbedding>,
    /// Number of query heads (`cfg.num_attention_heads`).
    num_heads: usize,
    /// Number of key/value heads (`cfg.num_key_value_heads()`).
    num_kv_heads: usize,
    /// Per-head dimension (`cfg.head_dim()`).
    head_dim: usize,
    /// Parameters passed to `Sdpa.run_attention` on every forward call.
    sdpa_params: SdpaParams,
}
133
134impl Attention {
135 #[allow(clippy::too_many_arguments)]
136 fn new(
137 cfg: &Config,
138 vb: ShardedVarBuilder,
139 lora_config: &[((String, String), LoraConfig)],
140 count: &mut usize,
141 ord: &Ordering,
142 mapper: &dyn DeviceMapper,
143 layer_idx: usize,
144 loading_isq: bool,
145 rope: Arc<RotaryEmbedding>,
146 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
147 ) -> Result<Self> {
148 let num_heads = cfg.num_attention_heads;
149 let num_kv_heads = cfg.num_key_value_heads();
150 let head_dim = cfg.head_dim();
151 let q_proj = linear(
152 cfg.hidden_size,
153 num_heads * head_dim,
154 mapper.set_device(layer_idx, vb.pp("q_proj"), loading_isq),
155 mapper.set_device(layer_idx, vb.pp("q_proj"), false),
156 lora_config,
157 count,
158 ord,
159 preload_adapters,
160 )?;
161 let k_proj = linear(
162 cfg.hidden_size,
163 num_kv_heads * head_dim,
164 mapper.set_device(layer_idx, vb.pp("k_proj"), loading_isq),
165 mapper.set_device(layer_idx, vb.pp("k_proj"), false),
166 lora_config,
167 count,
168 ord,
169 preload_adapters,
170 )?;
171 let v_proj = linear(
172 cfg.hidden_size,
173 num_kv_heads * head_dim,
174 mapper.set_device(layer_idx, vb.pp("v_proj"), loading_isq),
175 mapper.set_device(layer_idx, vb.pp("v_proj"), false),
176 lora_config,
177 count,
178 ord,
179 preload_adapters,
180 )?;
181 let dense = linear(
182 num_heads * head_dim,
183 cfg.hidden_size,
184 mapper.set_device(layer_idx, vb.pp("dense"), loading_isq),
185 mapper.set_device(layer_idx, vb.pp("dense"), false),
186 lora_config,
187 count,
188 ord,
189 preload_adapters,
190 )?;
191 let (q_layernorm, k_layernorm) = if cfg.qk_layernorm {
192 let q_layernorm = layer_norm(head_dim, cfg.layer_norm_eps, vb.pp("q_layernorm"))?;
193 let k_layernorm = layer_norm(head_dim, cfg.layer_norm_eps, vb.pp("k_layernorm"))?;
194 (Some(q_layernorm), Some(k_layernorm))
195 } else {
196 (None, None)
197 };
198 Ok(Self {
199 q_proj,
200 k_proj,
201 v_proj,
202 dense,
203 q_layernorm,
204 k_layernorm,
205 rotary_emb: rope,
206 num_heads,
207 num_kv_heads,
208 head_dim,
209 sdpa_params: SdpaParams {
210 n_kv_groups: num_heads / num_kv_heads,
211 use_flash_attn: cfg.use_flash_attn,
212 softcap: None,
213 softmax_scale: 1.0 / (head_dim as f32).sqrt(),
214 sliding_window: None,
215 },
216 })
217 }
218
219 #[allow(clippy::too_many_arguments)]
220 fn forward(
221 &self,
222 xs: &Tensor,
223 mask: Option<&Tensor>,
224 seqlen_offsets: &[usize],
225 kv_cache: &mut Option<(Tensor, Tensor)>,
226 scalings: Option<Tensor>,
227 global_scaling_weight: f64,
228 is_scaling_pass: Option<f64>,
229 flash_params: &FlashParams,
230 ) -> Result<Tensor> {
231 let (b_size, seq_len, _n_embd) = xs.dims3()?;
232 let original_dtype = xs.dtype();
233 let mut xs = xs.clone();
234 if let Some(t) = self.q_proj.quantized_act_type() {
235 xs = xs.to_dtype(t)?;
236 }
237 let mut q = self.q_proj.lora_forward(
238 &xs,
239 scalings.clone(),
240 global_scaling_weight,
241 is_scaling_pass,
242 )?;
243 let mut k = self.k_proj.lora_forward(
244 &xs,
245 scalings.clone(),
246 global_scaling_weight,
247 is_scaling_pass,
248 )?;
249 let mut v = self.v_proj.lora_forward(
250 &xs,
251 scalings.clone(),
252 global_scaling_weight,
253 is_scaling_pass,
254 )?;
255 if self.q_proj.quantized_act_type().is_some() {
256 q = q.to_dtype(original_dtype)?;
257 k = k.to_dtype(original_dtype)?;
258 v = v.to_dtype(original_dtype)?;
259 }
260
261 let q = match &self.q_layernorm {
262 None => q,
263 Some(ln) => q.apply(ln)?,
264 };
265 let k = match &self.k_layernorm {
266 None => k,
267 Some(ln) => k.apply(ln)?,
268 };
269
270 let (q, k, v) = if seq_len != 1 {
271 let q = q
272 .reshape((b_size, seq_len, self.num_heads, self.head_dim))?
273 .transpose(1, 2)?;
274 let k = k
275 .reshape((b_size, seq_len, self.num_kv_heads, self.head_dim))?
276 .transpose(1, 2)?;
277 let v = v
278 .reshape((b_size, seq_len, self.num_kv_heads, self.head_dim))?
279 .transpose(1, 2)?;
280 (q, k, v)
281 } else {
282 let q = q.reshape((b_size, self.num_heads, seq_len, self.head_dim))?;
283 let k = k.reshape((b_size, self.num_kv_heads, seq_len, self.head_dim))?;
284 let v = v.reshape((b_size, self.num_kv_heads, seq_len, self.head_dim))?;
285 (q, k, v)
286 };
287
288 let (q, k) = self.rotary_emb.forward(&q, &k, seqlen_offsets)?;
289
290 let (k, v) = Cache::update_kv_cache(kv_cache, k, v, false)?;
291
292 let attn_output =
293 Sdpa.run_attention(&q, &k, &v, mask, Some(flash_params), &self.sdpa_params)?;
294
295 let mut attn_output = attn_output
296 .transpose(1, 2)?
297 .reshape((b_size, seq_len, ()))?;
298 if let Some(t) = self.q_proj.quantized_act_type() {
299 attn_output = attn_output.to_dtype(t)?;
300 }
301 let mut res = self.dense.lora_forward(
302 &attn_output,
303 scalings,
304 global_scaling_weight,
305 is_scaling_pass,
306 )?;
307 if self.q_proj.quantized_act_type().is_some() {
308 res = res.to_dtype(original_dtype)?;
309 }
310 Ok(res)
311 }
312}
313
/// One decoder layer. `DecoderLayer::forward` feeds the layer-normed input to
/// both the attention block and the MLP and sums their outputs with the
/// un-normalized input.
struct DecoderLayer {
    /// Self-attention block.
    self_attn: Attention,
    /// Feed-forward block.
    mlp: MLP,
    /// Layer norm applied to the layer input before attention and MLP.
    input_layernorm: LayerNorm,
}
319
320impl DecoderLayer {
321 #[allow(clippy::too_many_arguments)]
322 fn new(
323 cfg: &Config,
324 vb: ShardedVarBuilder,
325 lora_config: &[((String, String), LoraConfig)],
326 count: &mut usize,
327 ord: &Ordering,
328 mapper: &dyn DeviceMapper,
329 layer_idx: usize,
330 loading_isq: bool,
331 rope: Arc<RotaryEmbedding>,
332 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
333 ) -> Result<Self> {
334 let self_attn = Attention::new(
335 cfg,
336 vb.pp("self_attn"),
337 lora_config,
338 count,
339 ord,
340 mapper,
341 layer_idx,
342 loading_isq,
343 rope,
344 preload_adapters,
345 )?;
346 let mlp = MLP::new(
347 cfg,
348 vb.pp("mlp"),
349 lora_config,
350 count,
351 ord,
352 mapper,
353 layer_idx,
354 loading_isq,
355 preload_adapters,
356 )?;
357 let input_layernorm = layer_norm(
358 cfg.hidden_size,
359 cfg.layer_norm_eps,
360 mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
361 )?;
362 Ok(Self {
363 self_attn,
364 mlp,
365 input_layernorm,
366 })
367 }
368
369 #[allow(clippy::too_many_arguments)]
370 fn forward(
371 &self,
372 xs: &Tensor,
373 mask: Option<&Tensor>,
374 seqlen_offsets: &[usize],
375 kv_cache: &mut Option<(Tensor, Tensor)>,
376 scalings: Option<Tensor>,
377 global_scaling_weight: f64,
378 is_scaling_pass: Option<f64>,
379 flash_params: &FlashParams,
380 ) -> Result<Tensor> {
381 let residual = xs;
382 let xs = xs.apply(&self.input_layernorm)?;
383 let attn_outputs = self.self_attn.forward(
384 &xs,
385 mask,
386 seqlen_offsets,
387 kv_cache,
388 scalings.clone(),
389 global_scaling_weight,
390 is_scaling_pass,
391 flash_params,
392 )?;
393 let feed_forward_hidden_states =
394 self.mlp
395 .forward(&xs, scalings, global_scaling_weight, is_scaling_pass)?;
396 attn_outputs + feed_forward_hidden_states + residual
397 }
398}
399
/// Decoder-only model with LoRA-capable linear layers and an optional X-LoRA
/// classifier.
pub struct Model {
    /// Token embedding table (`vocab_size x hidden_size`).
    embed_tokens: Embedding,
    /// Decoder layers, in execution order.
    layers: Vec<DecoderLayer>,
    /// Layer norm applied after the last decoder layer.
    final_layernorm: LayerNorm,
    /// Output projection, constructed as `hidden_size -> vocab_size`.
    lm_head: Arc<dyn LinearLayerLike + Send + Sync>,
    /// KV cache; constructed as `EitherCache::Full` with one slot per layer.
    cache: EitherCache,
    /// Device the hidden states are moved to before the final layer norm.
    device: Device,
    /// Copied from `cfg.max_position_embeddings`.
    max_seq_len: usize,
    /// Present only when an X-LoRA config was supplied to `Model::new`.
    xlora_classifier: Option<XLoraClassifier>,
    /// Dtype reported by the var builder the model was loaded from.
    dtype: DType,
    /// Maps layers (and the activations flowing through them) to devices.
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    /// Static model metadata exposed through `NormalModel::config`.
    cfg: ModelConfigMetadata,
}
413
414impl Model {
415 #[allow(clippy::too_many_arguments)]
416 pub fn new(
417 cfg: &Config,
418 vb: ShardedVarBuilder,
419 lora_config: &[((String, String), LoraConfig)],
420 xlora_config: Option<XLoraConfig>,
421 xlora_ordering: Ordering,
422 is_gptx: bool,
423 normal_loading_metadata: NormalLoadingMetadata,
424 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
425 ) -> Result<Self> {
426 if let Some(ref quant_cfg) = &cfg.quantization_config {
427 tracing::info!(
428 "Using {} quantization: {}.",
429 quant_cfg.quant_method.to_string(),
430 quant_cfg.get_bits_name(&vb)
431 );
432 }
433 let mapper = normal_loading_metadata.mapper;
434 let vb_m = vb.pp("model");
435
436 let embed_tokens = embedding(
437 cfg.vocab_size,
438 cfg.hidden_size,
439 mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
440 )?;
441 let final_layernorm = layer_norm(
442 cfg.hidden_size,
443 cfg.layer_norm_eps,
444 mapper.set_nm_device(vb_m.pp("final_layernorm"), false),
445 )?;
446 let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
447 let vb_m = vb_m.pp("layers");
448 let mut ropes = HashMap::new();
449 for layer_idx in 0..cfg.num_hidden_layers {
450 let device = mapper
451 .device_for(layer_idx, false)
452 .unwrap_or(&normal_loading_metadata.real_device);
453 ropes.insert(
455 device.location(),
456 Arc::new(RotaryEmbedding::new_partial(
457 cfg.rope_theta,
458 (cfg.partial_rotary_factor * cfg.head_dim() as f64) as usize,
459 cfg.max_position_embeddings,
460 device,
461 is_gptx,
462 vb.dtype(),
463 )?),
464 );
465 }
466 let mut count = 0;
467 for layer_idx in NiceProgressBar::<_, 'b'>(
468 0..cfg.num_hidden_layers,
469 "Loading repeating layers",
470 &normal_loading_metadata.multi_progress,
471 ) {
472 let device = mapper
473 .device_for(layer_idx, false)
474 .unwrap_or(&normal_loading_metadata.real_device);
475 let rotary_emb = ropes
476 .get(&device.location())
477 .expect("No RoPE for device location!")
478 .clone();
479 let layer = DecoderLayer::new(
480 cfg,
481 vb_m.pp(layer_idx),
482 lora_config,
483 &mut count,
484 &xlora_ordering,
485 &*mapper,
486 layer_idx,
487 normal_loading_metadata.loading_isq,
488 rotary_emb,
489 preload_adapters,
490 )?;
491 layers.push(layer)
492 }
493 if xlora_config.is_none() && preload_adapters.is_none() {
494 info!("Merging LoRA adapters.");
496 for layer in layers.iter_mut().tqdm() {
497 Arc::get_mut(&mut layer.self_attn.k_proj)
498 .unwrap()
499 .merge_weights()?;
500 Arc::get_mut(&mut layer.self_attn.dense)
501 .unwrap()
502 .merge_weights()?;
503 Arc::get_mut(&mut layer.self_attn.q_proj)
504 .unwrap()
505 .merge_weights()?;
506 Arc::get_mut(&mut layer.self_attn.v_proj)
507 .unwrap()
508 .merge_weights()?;
509
510 Arc::get_mut(&mut layer.mlp.fc1).unwrap().merge_weights()?;
511 Arc::get_mut(&mut layer.mlp.fc2).unwrap().merge_weights()?;
512 }
513 }
514 let lm_head = linear(
515 cfg.hidden_size,
516 cfg.vocab_size,
517 mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
518 mapper.set_nm_device(vb.pp("lm_head"), false),
519 lora_config,
520 &mut count,
521 &xlora_ordering,
522 preload_adapters,
523 )?;
524 if xlora_config.is_some() && lm_head.is_lora() {
525 candle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA.");
527 }
528 Ok(Self {
529 embed_tokens,
530 layers,
531 final_layernorm,
532 lm_head,
533 cache: EitherCache::Full(Cache::new(cfg.num_hidden_layers, true)),
534 device: normal_loading_metadata.real_device,
535 max_seq_len: cfg.max_position_embeddings,
536 dtype: vb.dtype(),
537 xlora_classifier: xlora_config.map(|xlora_config| {
538 XLoraClassifier::new(xlora_config, count, lora_config.len(), vb, false).unwrap()
539 }),
540 mapper,
541 cfg: ModelConfigMetadata {
542 max_seq_len: cfg.max_position_embeddings,
543 num_layers: cfg.num_hidden_layers,
544 hidden_size: cfg.hidden_size,
545 num_kv_heads: cfg.num_key_value_heads(),
546 num_attn_heads: cfg.num_attention_heads,
547 sliding_window: None,
548 k_head_dim: cfg.head_dim(),
549 v_head_dim: cfg.head_dim(),
550 },
551 })
552 }
553
554 #[allow(clippy::too_many_arguments)]
555 fn inner_forward(
556 &self,
557 input_ids: &Tensor,
558 seqlen_offsets: &[usize],
559 scalings: Option<Tensor>,
560 is_full_pass: bool,
561 no_kv_cache: bool,
562 is_scaling_pass: Option<f64>,
563 flash_params: &FlashParams,
564 ) -> Result<Tensor> {
565 let mut xs = input_ids.apply(&self.embed_tokens)?;
566 let mut cache = if is_full_pass {
567 if no_kv_cache {
568 let mut new_cache = Vec::new();
569 for _ in 0..self.cache.full().xlora_lock().len() {
570 new_cache.push(None);
571 }
572
573 self.cache.full().xlora_lock().clone_from(&new_cache);
574 }
575 self.cache.full().xlora_lock()
576 } else {
577 self.cache.full().lock()
578 };
579 let mask = CausalMasker.make_causal_mask_matrix(
580 input_ids,
581 &*cache,
582 xs.dtype(),
583 self.cfg.num_attn_heads,
584 )?;
585 for (i, layer) in self.layers.iter().enumerate() {
586 xs = self.mapper.map(xs, i)?;
587 xs = layer.forward(
588 &xs,
589 mask.as_ref()
590 .map(|m| m.to_device(xs.device()).unwrap())
591 .as_ref(),
592 seqlen_offsets,
593 &mut cache[i],
594 scalings.clone(),
595 self.xlora_classifier
596 .as_ref()
597 .map(|classifier| classifier.get_global_scaling_weight())
598 .unwrap_or(1.0),
599 is_scaling_pass,
600 flash_params,
601 )?;
602 }
603 let xs = xs.to_device(&self.device)?;
604 xs.apply(&self.final_layernorm)
605 }
606
607 #[allow(clippy::too_many_arguments)]
608 pub fn forward(
609 &self,
610 input_ids: &Tensor,
611 input_ids_full: &Tensor,
612 seqlen_offsets: &[usize],
613 seqlen_offsets_full: &[usize],
614 no_kv_cache: bool,
615 non_granular_state: &Option<NonGranularState>,
616 context_lens: Vec<(usize, usize)>,
617 flash_params: &FlashParams,
618 flash_params_full: &FlashParams,
619 ) -> Result<Tensor> {
620 if self.xlora_classifier.is_some() {
621 let scalings = self.get_scalings(
622 input_ids,
623 input_ids_full,
624 seqlen_offsets,
625 seqlen_offsets_full,
626 no_kv_cache,
627 non_granular_state,
628 &vec![usize::MAX; context_lens.len()],
629 flash_params,
630 flash_params_full,
631 )?;
632
633 if no_kv_cache {
634 let mut res = self
635 .inner_forward(
636 input_ids_full,
637 seqlen_offsets_full,
638 Some(scalings),
639 true,
640 no_kv_cache,
641 None,
642 flash_params_full,
643 )?
644 .contiguous()?;
645 if let Some(t) = self.lm_head.quantized_act_type() {
646 res = res.to_dtype(t)?;
647 }
648 extract_logits(
649 &self.lm_head.lora_forward(&res, None, 1.0, None)?,
650 context_lens,
651 )
652 } else {
653 let mut res = self
655 .inner_forward(
656 input_ids,
657 seqlen_offsets,
658 Some(scalings),
659 true,
660 no_kv_cache,
661 None,
662 flash_params,
663 )?
664 .contiguous()?;
665 if let Some(t) = self.lm_head.quantized_act_type() {
666 res = res.to_dtype(t)?;
667 }
668 extract_logits(
669 &self.lm_head.lora_forward(&res, None, 1.0, None)?,
670 context_lens,
671 )
672 }
673 } else {
674 let mut res = self
675 .inner_forward(
676 input_ids,
677 seqlen_offsets,
678 None,
679 false,
680 no_kv_cache,
681 None,
682 flash_params,
683 )?
684 .contiguous()?;
685 if let Some(t) = self.lm_head.quantized_act_type() {
686 res = res.to_dtype(t)?;
687 }
688 extract_logits(
689 &self.lm_head.lora_forward(&res, None, 1.0, None)?,
690 context_lens,
691 )
692 }
693 }
694}
695
696impl IsqModel for Model {
697 fn get_layers(
698 &mut self,
699 ) -> (
700 Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
701 &dyn DeviceMapper,
702 ) {
703 let mut tensors = Vec::new();
704 tensors.push((Arc::get_mut(&mut self.lm_head).unwrap().quant_inner(), None));
705 for (i, layer) in self.layers.iter_mut().enumerate() {
706 tensors.push((
707 Arc::get_mut(&mut layer.self_attn.q_proj)
708 .unwrap()
709 .quant_inner(),
710 Some(i),
711 ));
712 tensors.push((
713 Arc::get_mut(&mut layer.self_attn.k_proj)
714 .unwrap()
715 .quant_inner(),
716 Some(i),
717 ));
718 tensors.push((
719 Arc::get_mut(&mut layer.self_attn.v_proj)
720 .unwrap()
721 .quant_inner(),
722 Some(i),
723 ));
724 tensors.push((
725 Arc::get_mut(&mut layer.self_attn.dense)
726 .unwrap()
727 .quant_inner(),
728 Some(i),
729 ));
730 tensors.push((
731 Arc::get_mut(&mut layer.mlp.fc1).unwrap().quant_inner(),
732 Some(i),
733 ));
734 tensors.push((
735 Arc::get_mut(&mut layer.mlp.fc2).unwrap().quant_inner(),
736 Some(i),
737 ));
738 }
739 (tensors, &*self.mapper)
740 }
741
742 fn residual_tensors(&self) -> Vec<(String, Tensor)> {
743 panic!("Cannot generate UQFF for an adapter model.")
744 }
745}
746
747impl NormalModel for Model {
748 fn forward(
749 &self,
750 _input_ids: &Tensor,
751 _seqlen_offsets: &[usize],
752 _context_lens: Vec<(usize, usize)>,
753 _position_ids: Vec<usize>,
754 _metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
755 _flash_params: &FlashParams,
756 ) -> Result<Tensor> {
757 unreachable!()
758 }
759 fn xlora_forward(
760 &self,
761 input_ids: &Tensor,
762 input_ids_full: &Tensor,
763 seqlen_offsets: &[usize],
764 seqlen_offsets_full: &[usize],
765 no_kv_cache: bool,
766 non_granular_state: &Option<crate::xlora_models::NonGranularState>,
767 context_lens: Vec<(usize, usize)>,
768 _position_ids: Vec<usize>,
769 flash_params: &FlashParams,
770 flash_params_full: &FlashParams,
771 ) -> Result<Tensor> {
772 self.forward(
773 input_ids,
774 input_ids_full,
775 seqlen_offsets,
776 seqlen_offsets_full,
777 no_kv_cache,
778 non_granular_state,
779 context_lens,
780 flash_params,
781 flash_params_full,
782 )
783 }
784 fn cache(&self) -> &EitherCache {
785 &self.cache
786 }
787 fn cache_mut(&mut self) -> &mut EitherCache {
788 &mut self.cache
789 }
790 fn device(&self) -> &Device {
791 &self.device
792 }
793 fn is_xlora(&self) -> bool {
794 true
795 }
796 fn max_seq_len(&self) -> usize {
797 self.max_seq_len
798 }
799 fn activate_adapters(&mut self, adapter_names: Vec<String>) -> Result<usize> {
800 if self.xlora_classifier.is_some() {
801 candle_core::bail!("Adapter activation is not supported for X-LoRA models as the adapter set must remain the same.");
802 }
803 let mut sum = 0;
804 for layer in self.layers.iter_mut() {
805 sum += Arc::get_mut(&mut layer.self_attn.k_proj)
806 .unwrap()
807 .activate(&adapter_names)?;
808 sum += Arc::get_mut(&mut layer.self_attn.dense)
809 .unwrap()
810 .activate(&adapter_names)?;
811 sum += Arc::get_mut(&mut layer.self_attn.q_proj)
812 .unwrap()
813 .activate(&adapter_names)?;
814 sum += Arc::get_mut(&mut layer.self_attn.v_proj)
815 .unwrap()
816 .activate(&adapter_names)?;
817
818 sum += Arc::get_mut(&mut layer.mlp.fc1)
819 .unwrap()
820 .activate(&adapter_names)?;
821 sum += Arc::get_mut(&mut layer.mlp.fc2)
822 .unwrap()
823 .activate(&adapter_names)?;
824 }
825 Ok(sum)
826 }
827 fn config(&self) -> &ModelConfigMetadata {
828 &self.cfg
829 }
830}
831
832impl ScalingsMaker for Model {
833 fn dtype(&self) -> DType {
834 self.dtype
835 }
836 fn get_cache(&self) -> &EitherCache {
837 &self.cache
838 }
839 fn get_classifier(&self) -> &XLoraClassifier {
840 self.xlora_classifier.as_ref().unwrap()
841 }
842 fn forward(
843 &self,
844 input_ids: &Tensor,
845 seqlen_offsets: &[usize],
846 scalings: Tensor,
847 is_full_pass: bool,
848 no_kv_cache: bool,
849 is_scaling_pass: Option<f64>,
850 _context_lens: &[usize],
851 flash_params: &FlashParams,
852 ) -> Result<Tensor> {
853 self.inner_forward(
854 input_ids,
855 seqlen_offsets,
856 Some(scalings),
857 is_full_pass,
858 no_kv_cache,
859 is_scaling_pass,
860 flash_params,
861 )
862 }
863}
864
// Empty impl: no trait items are overridden here, so `Model` relies entirely
// on whatever defaults `AnyMoeBaseModelMixin` provides.
impl AnyMoeBaseModelMixin for Model {}