1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use std::{collections::HashMap, sync::Arc};
4
5use candle_core::{DType, Device, Module, Result, Tensor};
6use mistralrs_quant::{QuantMethod, ShardedVarBuilder};
7use tqdm::Iter;
8use tracing::info;
9
10use crate::{
11 amoe::AnyMoeBaseModelMixin,
12 attention::SdpaParams,
13 device_map::DeviceMapper,
14 layers::{self, Activation, CausalMasker, RmsNorm, RotaryEmbedding, Sdpa},
15 lora::{linear_b, linear_no_bias, LinearLayerLike, LoraConfig},
16 models::gemma2::Config,
17 paged_attention::ModelConfigMetadata,
18 pipeline::{
19 extract_logits,
20 text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
21 Cache, EitherCache, IsqModel, NormalLoadingMetadata, NormalModel,
22 },
23 utils::progress::NiceProgressBar,
24 Ordering,
25};
26
27use super::{classifier::XLoraClassifier, NonGranularState, ScalingsMaker, XLoraConfig};
28
29#[derive(Clone)]
30#[allow(clippy::upper_case_acronyms)]
31struct MLP {
32 gate_proj: Arc<dyn LinearLayerLike + Send + Sync>,
33 up_proj: Arc<dyn LinearLayerLike + Send + Sync>,
34 down_proj: Arc<dyn LinearLayerLike + Send + Sync>,
35 act_fn: Activation,
36}
37
38impl MLP {
39 #[allow(clippy::too_many_arguments)]
40 fn new(
41 cfg: &Config,
42 vb: ShardedVarBuilder,
43 lora_config: &[((String, String), LoraConfig)],
44 count: &mut usize,
45 ord: &Ordering,
46 mapper: &dyn DeviceMapper,
47 layer_idx: usize,
48 loading_isq: bool,
49 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
50 ) -> Result<Self> {
51 let hidden_sz = cfg.hidden_size;
52 let intermediate_sz = cfg.intermediate_size;
53 let gate_proj = linear_b(
54 hidden_sz,
55 intermediate_sz,
56 false,
57 mapper.set_device(layer_idx, vb.pp("gate_proj"), loading_isq),
58 mapper.set_device(layer_idx, vb.pp("gate_proj"), false),
59 lora_config,
60 count,
61 ord,
62 preload_adapters,
63 )?;
64 let up_proj = linear_b(
65 hidden_sz,
66 intermediate_sz,
67 false,
68 mapper.set_device(layer_idx, vb.pp("up_proj"), loading_isq),
69 mapper.set_device(layer_idx, vb.pp("up_proj"), false),
70 lora_config,
71 count,
72 ord,
73 preload_adapters,
74 )?;
75 let down_proj = linear_b(
76 intermediate_sz,
77 hidden_sz,
78 false,
79 mapper.set_device(layer_idx, vb.pp("down_proj"), loading_isq),
80 mapper.set_device(layer_idx, vb.pp("down_proj"), false),
81 lora_config,
82 count,
83 ord,
84 preload_adapters,
85 )?;
86 Ok(Self {
87 gate_proj,
88 up_proj,
89 down_proj,
90 act_fn: cfg.hidden_act()?,
91 })
92 }
93
94 fn forward(
95 &self,
96 xs: &Tensor,
97 scalings: Option<Tensor>,
98 global_scaling_weight: f64,
99 is_scaling_pass: Option<f64>,
100 ) -> Result<Tensor> {
101 let original_dtype = xs.dtype();
102 let mut xs = xs.clone();
103 if let Some(t) = self.gate_proj.quantized_act_type() {
104 xs = xs.to_dtype(t)?;
105 }
106 let lhs = self
107 .gate_proj
108 .lora_forward(
109 &xs,
110 scalings.clone(),
111 global_scaling_weight,
112 is_scaling_pass,
113 )?
114 .apply(&self.act_fn)?;
115 let rhs = self.up_proj.lora_forward(
116 &xs,
117 scalings.clone(),
118 global_scaling_weight,
119 is_scaling_pass,
120 )?;
121 let mut res = self.down_proj.lora_forward(
122 &(lhs * rhs)?,
123 scalings,
124 global_scaling_weight,
125 is_scaling_pass,
126 )?;
127 if self.gate_proj.quantized_act_type().is_some() {
128 res = res.to_dtype(original_dtype)?;
129 }
130 Ok(res)
131 }
132}
133
134struct Attention {
135 q_proj: Arc<dyn LinearLayerLike + Send + Sync>,
136 k_proj: Arc<dyn LinearLayerLike + Send + Sync>,
137 v_proj: Arc<dyn LinearLayerLike + Send + Sync>,
138 o_proj: Arc<dyn LinearLayerLike + Send + Sync>,
139 num_heads: usize,
140 num_kv_heads: usize,
141 head_dim: usize,
142 rotary_emb: Arc<RotaryEmbedding>,
143 use_sliding_window: bool,
144 sliding_window: Option<usize>,
145 sdpa_params: SdpaParams,
146}
147
148impl Attention {
149 #[allow(clippy::too_many_arguments)]
150 fn new(
151 rotary_emb: Arc<RotaryEmbedding>,
152 cfg: &Config,
153 vb: ShardedVarBuilder,
154 lora_config: &[((String, String), LoraConfig)],
155 count: &mut usize,
156 ord: &Ordering,
157 mapper: &dyn DeviceMapper,
158 layer_idx: usize,
159 loading_isq: bool,
160 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
161 ) -> Result<Self> {
162 let hidden_sz = cfg.hidden_size;
163 let num_heads = cfg.num_attention_heads;
164 let num_kv_heads = cfg.num_key_value_heads;
165 let head_dim = cfg.head_dim;
166 let bias = cfg.attention_bias;
167 let q_proj = linear_b(
168 hidden_sz,
169 num_heads * head_dim,
170 bias,
171 mapper.set_device(layer_idx, vb.pp("q_proj"), loading_isq),
172 mapper.set_device(layer_idx, vb.pp("q_proj"), false),
173 lora_config,
174 count,
175 ord,
176 preload_adapters,
177 )?;
178 let k_proj = linear_b(
179 hidden_sz,
180 num_kv_heads * head_dim,
181 bias,
182 mapper.set_device(layer_idx, vb.pp("k_proj"), loading_isq),
183 mapper.set_device(layer_idx, vb.pp("k_proj"), false),
184 lora_config,
185 count,
186 ord,
187 preload_adapters,
188 )?;
189 let v_proj = linear_b(
190 hidden_sz,
191 num_kv_heads * head_dim,
192 bias,
193 mapper.set_device(layer_idx, vb.pp("v_proj"), loading_isq),
194 mapper.set_device(layer_idx, vb.pp("v_proj"), false),
195 lora_config,
196 count,
197 ord,
198 preload_adapters,
199 )?;
200 let o_proj = linear_b(
201 num_heads * head_dim,
202 hidden_sz,
203 bias,
204 mapper.set_device(layer_idx, vb.pp("o_proj"), loading_isq),
205 mapper.set_device(layer_idx, vb.pp("o_proj"), false),
206 lora_config,
207 count,
208 ord,
209 preload_adapters,
210 )?;
211 let sliding_window = if layer_idx % 2 == 0 {
212 Some(cfg.sliding_window)
214 } else {
215 None
216 };
217 Ok(Self {
218 q_proj,
219 k_proj,
220 v_proj,
221 o_proj,
222 num_heads,
223 num_kv_heads,
224 head_dim,
225 rotary_emb,
226 use_sliding_window: layer_idx % 2 == 0, sliding_window,
228 sdpa_params: SdpaParams {
229 n_kv_groups: num_heads / num_kv_heads,
230 use_flash_attn: cfg.use_flash_attn,
231 softcap: cfg.attn_logit_softcapping.map(|x| x as f32),
232 softmax_scale: 1.0 / (cfg.query_pre_attn_scalar as f32).sqrt(),
233 sliding_window,
234 },
235 })
236 }
237
238 #[allow(clippy::too_many_arguments)]
239 fn forward(
240 &self,
241 xs: &Tensor,
242 attention_mask: Option<&Tensor>,
243 sliding_attention_mask: Option<&Tensor>,
244 seqlen_offsets: &[usize],
245 kv_cache: &mut Option<(Tensor, Tensor)>,
246 scalings: Option<Tensor>,
247 global_scaling_weight: f64,
248 is_scaling_pass: Option<f64>,
249 flash_params: &FlashParams,
250 ) -> Result<Tensor> {
251 let (b_sz, q_len, _) = xs.dims3()?;
252
253 let original_dtype = xs.dtype();
254 let mut xs = xs.clone();
255 if let Some(t) = self.q_proj.quantized_act_type() {
256 xs = xs.to_dtype(t)?;
257 }
258 let mut q = self.q_proj.lora_forward(
259 &xs,
260 scalings.clone(),
261 global_scaling_weight,
262 is_scaling_pass,
263 )?;
264 let mut k = self.k_proj.lora_forward(
265 &xs,
266 scalings.clone(),
267 global_scaling_weight,
268 is_scaling_pass,
269 )?;
270 let mut v = self.v_proj.lora_forward(
271 &xs,
272 scalings.clone(),
273 global_scaling_weight,
274 is_scaling_pass,
275 )?;
276 if self.q_proj.quantized_act_type().is_some() {
277 q = q.to_dtype(original_dtype)?;
278 k = k.to_dtype(original_dtype)?;
279 v = v.to_dtype(original_dtype)?;
280 }
281
282 let (q, k, v) = if q_len != 1 {
283 let q = q
284 .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
285 .transpose(1, 2)?;
286 let k = k
287 .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
288 .transpose(1, 2)?;
289 let v = v
290 .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
291 .transpose(1, 2)?;
292 (q, k, v)
293 } else {
294 let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
295 let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
296 let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
297 (q, k, v)
298 };
299
300 let (q, k) = self.rotary_emb.forward(&q, &k, seqlen_offsets)?;
301
302 let mask = if self.use_sliding_window {
303 sliding_attention_mask
304 } else {
305 attention_mask
306 };
307
308 let (k, v, mask) = Cache::update_kv_cache_sliding_window(
310 kv_cache,
311 k,
312 v,
313 mask,
314 self.sliding_window,
315 false,
316 )?;
317
318 let mut attn_output = Sdpa.run_attention(
319 &q,
320 &k,
321 &v,
322 mask.as_ref(),
323 Some(flash_params),
324 &self.sdpa_params,
325 )?;
326
327 if let Some(t) = self.q_proj.quantized_act_type() {
328 attn_output = attn_output.to_dtype(t)?;
329 }
330 let mut res = self.o_proj.lora_forward(
331 &attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?,
332 scalings.clone(),
333 global_scaling_weight,
334 is_scaling_pass,
335 )?;
336 if self.q_proj.quantized_act_type().is_some() {
337 res = res.to_dtype(original_dtype)?;
338 }
339 Ok(res)
340 }
341}
342
343struct DecoderLayer {
344 self_attn: Attention,
345 mlp: MLP,
346 input_layernorm: RmsNorm,
347 post_attention_layernorm: RmsNorm,
348 pre_feedforward_layernorm: RmsNorm,
349 post_feedforward_layernorm: RmsNorm,
350}
351
352impl DecoderLayer {
353 #[allow(clippy::too_many_arguments)]
354 fn new(
355 rotary_emb: Arc<RotaryEmbedding>,
356 cfg: &Config,
357 vb: ShardedVarBuilder,
358 lora_config: &[((String, String), LoraConfig)],
359 count: &mut usize,
360 ord: &Ordering,
361 mapper: &dyn DeviceMapper,
362 layer_idx: usize,
363 loading_isq: bool,
364 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
365 ) -> Result<Self> {
366 let self_attn = Attention::new(
367 rotary_emb,
368 cfg,
369 vb.pp("self_attn"),
370 lora_config,
371 count,
372 ord,
373 mapper,
374 layer_idx,
375 loading_isq,
376 preload_adapters,
377 )?;
378 let mlp = MLP::new(
379 cfg,
380 vb.pp("mlp"),
381 lora_config,
382 count,
383 ord,
384 mapper,
385 layer_idx,
386 loading_isq,
387 preload_adapters,
388 )?;
389 let input_layernorm = RmsNorm::new_gemma(
390 cfg.hidden_size,
391 cfg.rms_norm_eps,
392 mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
393 )?;
394 let post_attention_layernorm = RmsNorm::new_gemma(
395 cfg.hidden_size,
396 cfg.rms_norm_eps,
397 mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
398 )?;
399 let pre_feedforward_layernorm = RmsNorm::new_gemma(
400 cfg.hidden_size,
401 cfg.rms_norm_eps,
402 mapper.set_device(layer_idx, vb.pp("pre_feedforward_layernorm"), false),
403 )?;
404 let post_feedforward_layernorm = RmsNorm::new_gemma(
405 cfg.hidden_size,
406 cfg.rms_norm_eps,
407 mapper.set_device(layer_idx, vb.pp("post_feedforward_layernorm"), false),
408 )?;
409 Ok(Self {
410 self_attn,
411 mlp,
412 input_layernorm,
413 post_attention_layernorm,
414 pre_feedforward_layernorm,
415 post_feedforward_layernorm,
416 })
417 }
418
419 #[allow(clippy::too_many_arguments)]
420 fn forward(
421 &self,
422 xs: &Tensor,
423 attention_mask: Option<&Tensor>,
424 sliding_attention_mask: Option<&Tensor>,
425 seqlen_offsets: &[usize],
426 kv_cache: &mut Option<(Tensor, Tensor)>,
427 scalings: Option<Tensor>,
428 global_scaling_weight: f64,
429 is_scaling_pass: Option<f64>,
430 flash_params: &FlashParams,
431 ) -> Result<Tensor> {
432 let residual = xs;
433 let xs = self.input_layernorm.forward(xs)?;
434 let xs = self
435 .self_attn
436 .forward(
437 &xs,
438 attention_mask,
439 sliding_attention_mask,
440 seqlen_offsets,
441 kv_cache,
442 scalings.clone(),
443 global_scaling_weight,
444 is_scaling_pass,
445 flash_params,
446 )?
447 .apply(&self.post_attention_layernorm)?;
448 let xs = (xs + residual)?;
449 let residual = &xs;
450 let xs = self
451 .mlp
452 .forward(
453 &xs.apply(&self.pre_feedforward_layernorm)?,
454 scalings,
455 global_scaling_weight,
456 is_scaling_pass,
457 )?
458 .apply(&self.post_feedforward_layernorm)?;
459 residual + xs
460 }
461}
462
463pub struct Model {
464 embed_tokens: candle_nn::Embedding,
465 layers: Vec<DecoderLayer>,
466 norm: RmsNorm,
467 lm_head: Arc<dyn LinearLayerLike + Send + Sync>,
468 hidden_size: usize,
469 device: Device,
470 cache: EitherCache,
471 max_seq_len: usize,
472 mapper: Box<dyn DeviceMapper + Send + Sync>,
473 sliding_window: usize,
474 final_logit_softcapping: Option<f64>,
475 xlora_classifier: Option<XLoraClassifier>,
476 dtype: DType,
477 cfg: ModelConfigMetadata,
478}
479
480impl Model {
481 #[allow(clippy::too_many_arguments)]
482 pub fn new(
483 cfg: &Config,
484 vb: ShardedVarBuilder,
485 lora_config: &[((String, String), LoraConfig)],
486 xlora_config: Option<XLoraConfig>,
487 xlora_ordering: Ordering,
488 is_gptx: bool,
489 normal_loading_metadata: NormalLoadingMetadata,
490 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
491 ) -> Result<Self> {
492 if let Some(ref quant_cfg) = &cfg.quantization_config {
493 tracing::info!(
494 "Using {} quantization: {}.",
495 quant_cfg.quant_method.to_string(),
496 quant_cfg.get_bits_name(&vb)
497 );
498 }
499 let mapper = normal_loading_metadata.mapper;
500
501 let vb_m = vb.pp("model");
502 let embed_tokens = layers::embedding(
503 cfg.vocab_size,
504 cfg.hidden_size,
505 mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
506 )?;
507 let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
508 let vb_l = vb_m.pp("layers");
509 let mut ropes = HashMap::new();
510 for layer_idx in 0..cfg.num_hidden_layers {
511 let device = mapper
512 .device_for(layer_idx, false)
513 .unwrap_or(&normal_loading_metadata.real_device);
514 ropes.insert(
515 device.location(),
516 Arc::new(RotaryEmbedding::new(
517 cfg.rope_theta as f32,
518 cfg.head_dim,
519 cfg.max_position_embeddings,
520 device,
521 is_gptx,
522 vb.dtype(),
523 )?),
524 );
525 }
526 let mut count = 0;
527 for layer_idx in NiceProgressBar::<_, 'b'>(
528 0..cfg.num_hidden_layers,
529 "Loading repeating layers",
530 &normal_loading_metadata.multi_progress,
531 ) {
532 let device = mapper
533 .device_for(layer_idx, false)
534 .unwrap_or(&normal_loading_metadata.real_device);
535 let rotary_emb = ropes
536 .get(&device.location())
537 .expect("No RoPE for device location!")
538 .clone();
539 let layer = DecoderLayer::new(
540 rotary_emb.clone(),
541 cfg,
542 vb_l.pp(layer_idx),
543 lora_config,
544 &mut count,
545 &xlora_ordering,
546 &*mapper,
547 layer_idx,
548 normal_loading_metadata.loading_isq,
549 preload_adapters,
550 )?;
551 layers.push(layer)
552 }
553 if xlora_config.is_none() && preload_adapters.is_none() {
554 info!("Merging LoRA adapters.");
556 for layer in layers.iter_mut().tqdm() {
557 Arc::get_mut(&mut layer.self_attn.k_proj)
558 .unwrap()
559 .merge_weights()?;
560 Arc::get_mut(&mut layer.self_attn.o_proj)
561 .unwrap()
562 .merge_weights()?;
563 Arc::get_mut(&mut layer.self_attn.q_proj)
564 .unwrap()
565 .merge_weights()?;
566 Arc::get_mut(&mut layer.self_attn.v_proj)
567 .unwrap()
568 .merge_weights()?;
569
570 Arc::get_mut(&mut layer.mlp.down_proj)
571 .unwrap()
572 .merge_weights()?;
573 Arc::get_mut(&mut layer.mlp.gate_proj)
574 .unwrap()
575 .merge_weights()?;
576 Arc::get_mut(&mut layer.mlp.up_proj)
577 .unwrap()
578 .merge_weights()?;
579 }
580 }
581 let norm = RmsNorm::new_gemma(
582 cfg.hidden_size,
583 cfg.rms_norm_eps,
584 mapper.set_nm_device(vb_m.pp("norm"), false),
585 )?;
586
587 let lm_head = linear_no_bias(
588 embed_tokens.embeddings().dim(1)?,
589 embed_tokens.embeddings().dim(0)?,
590 mapper.set_nm_device(vb_m.pp("embed_tokens"), normal_loading_metadata.loading_isq),
591 mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
592 lora_config,
593 &mut count,
594 &xlora_ordering,
595 preload_adapters,
596 )?;
597 if xlora_config.is_some() && lm_head.is_lora() {
598 candle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA.");
600 }
601
602 Ok(Self {
603 embed_tokens,
604 layers,
605 norm,
606 lm_head,
607 device: normal_loading_metadata.real_device,
608 hidden_size: cfg.hidden_size,
609 cache: EitherCache::Full(Cache::new(cfg.num_hidden_layers, true)),
610 max_seq_len: cfg.max_position_embeddings,
611 mapper,
612 sliding_window: cfg.sliding_window,
613 final_logit_softcapping: cfg.final_logit_softcapping,
614 dtype: vb.dtype(),
615 xlora_classifier: xlora_config.map(|xlora_config| {
616 XLoraClassifier::new(xlora_config, count, lora_config.len(), vb, false).unwrap()
617 }),
618 cfg: ModelConfigMetadata {
619 max_seq_len: cfg.max_position_embeddings,
620 num_layers: cfg.num_hidden_layers,
621 hidden_size: cfg.hidden_size,
622 num_kv_heads: cfg.num_key_value_heads,
623 num_attn_heads: cfg.num_attention_heads,
624 sliding_window: None,
625 k_head_dim: cfg.head_dim,
626 v_head_dim: cfg.head_dim,
627 },
628 })
629 }
630
631 #[allow(clippy::too_many_arguments)]
632 fn inner_forward(
633 &self,
634 input_ids: &Tensor,
635 seqlen_offsets: &[usize],
636 scalings: Option<Tensor>,
637 is_full_pass: bool,
638 no_kv_cache: bool,
639 is_scaling_pass: Option<f64>,
640 flash_params: &FlashParams,
641 ) -> Result<Tensor> {
642 let xs = self.embed_tokens.forward(input_ids)?;
643 let mut xs = (xs * (self.hidden_size as f64).sqrt())?;
644 let mut cache = if is_full_pass {
645 if no_kv_cache {
646 let mut new_cache = Vec::new();
647 for _ in 0..self.cache.full().xlora_lock().len() {
648 new_cache.push(None);
649 }
650
651 self.cache.full().xlora_lock().clone_from(&new_cache);
652 }
653 self.cache.full().xlora_lock()
654 } else {
655 self.cache.full().lock()
656 };
657 let attention_mask = CausalMasker.make_causal_mask_matrix(
658 input_ids,
659 &*cache,
660 xs.dtype(),
661 self.cfg.num_attn_heads,
662 )?;
663 let sliding_attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
664 input_ids,
665 &*cache,
666 Some(self.sliding_window),
667 xs.dtype(),
668 self.cfg.num_attn_heads,
669 )?;
670 for (i, layer) in self.layers.iter().enumerate() {
671 xs = self.mapper.map(xs, i)?;
672 xs = layer.forward(
673 &xs,
674 attention_mask
675 .as_ref()
676 .map(|m| m.to_device(xs.device()).unwrap())
677 .as_ref(),
678 sliding_attention_mask
679 .as_ref()
680 .map(|m| m.to_device(xs.device()).unwrap())
681 .as_ref(),
682 seqlen_offsets,
683 &mut cache[i],
684 scalings.clone(),
685 self.xlora_classifier
686 .as_ref()
687 .map(|classifier| classifier.get_global_scaling_weight())
688 .unwrap_or(1.0),
689 is_scaling_pass,
690 flash_params,
691 )?;
692 }
693 let xs = xs.to_device(&self.device)?;
694 let mut xs = xs.apply(&self.norm)?;
695 if let Some(t) = self.lm_head.quantized_act_type() {
696 xs = xs.to_dtype(t)?;
697 }
698
699 let mut xs = self.lm_head.lora_forward(&xs, None, 1.0, None)?;
700
701 if let Some(final_logit_softcapping) = self.final_logit_softcapping {
702 xs = (xs / final_logit_softcapping)?;
703 xs = xs.tanh()?;
704 xs = (xs * final_logit_softcapping)?;
705 }
706
707 Ok(xs)
708 }
709
710 #[allow(clippy::too_many_arguments)]
711 pub fn forward(
712 &self,
713 input_ids: &Tensor,
714 input_ids_full: &Tensor,
715 seqlen_offsets: &[usize],
716 seqlen_offsets_full: &[usize],
717 no_kv_cache: bool,
718 non_granular_state: &Option<NonGranularState>,
719 context_lens: Vec<(usize, usize)>,
720 flash_params: &FlashParams,
721 flash_params_full: &FlashParams,
722 ) -> Result<Tensor> {
723 if self.xlora_classifier.is_some() {
724 let scalings = self.get_scalings(
725 input_ids,
726 input_ids_full,
727 seqlen_offsets,
728 seqlen_offsets_full,
729 no_kv_cache,
730 non_granular_state,
731 &vec![usize::MAX; context_lens.len()],
732 flash_params,
733 flash_params_full,
734 )?;
735
736 if no_kv_cache {
737 let mut res = self
738 .inner_forward(
739 input_ids_full,
740 seqlen_offsets_full,
741 Some(scalings),
742 true,
743 no_kv_cache,
744 None,
745 flash_params_full,
746 )?
747 .contiguous()?;
748 if let Some(t) = self.lm_head.quantized_act_type() {
749 res = res.to_dtype(t)?;
750 }
751 extract_logits(
752 &self.lm_head.lora_forward(&res, None, 1.0, None)?,
753 context_lens,
754 )
755 } else {
756 let mut res = self
758 .inner_forward(
759 input_ids,
760 seqlen_offsets,
761 Some(scalings),
762 true,
763 no_kv_cache,
764 None,
765 flash_params,
766 )?
767 .contiguous()?;
768 if let Some(t) = self.lm_head.quantized_act_type() {
769 res = res.to_dtype(t)?;
770 }
771 extract_logits(
772 &self.lm_head.lora_forward(&res, None, 1.0, None)?,
773 context_lens,
774 )
775 }
776 } else {
777 let mut res = self
778 .inner_forward(
779 input_ids,
780 seqlen_offsets,
781 None,
782 false,
783 no_kv_cache,
784 None,
785 flash_params,
786 )?
787 .contiguous()?;
788 if let Some(t) = self.lm_head.quantized_act_type() {
789 res = res.to_dtype(t)?;
790 }
791 extract_logits(
792 &self.lm_head.lora_forward(&res, None, 1.0, None)?,
793 context_lens,
794 )
795 }
796 }
797}
798
799impl IsqModel for Model {
800 fn get_layers(
801 &mut self,
802 ) -> (
803 Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
804 &dyn DeviceMapper,
805 ) {
806 let mut tensors = Vec::new();
807 tensors.push((Arc::get_mut(&mut self.lm_head).unwrap().quant_inner(), None));
808 for (i, layer) in self.layers.iter_mut().enumerate() {
809 tensors.push((
810 Arc::get_mut(&mut layer.self_attn.q_proj)
811 .unwrap()
812 .quant_inner(),
813 Some(i),
814 ));
815 tensors.push((
816 Arc::get_mut(&mut layer.self_attn.k_proj)
817 .unwrap()
818 .quant_inner(),
819 Some(i),
820 ));
821 tensors.push((
822 Arc::get_mut(&mut layer.self_attn.v_proj)
823 .unwrap()
824 .quant_inner(),
825 Some(i),
826 ));
827 tensors.push((
828 Arc::get_mut(&mut layer.self_attn.o_proj)
829 .unwrap()
830 .quant_inner(),
831 Some(i),
832 ));
833 tensors.push((
834 Arc::get_mut(&mut layer.mlp.gate_proj)
835 .unwrap()
836 .quant_inner(),
837 Some(i),
838 ));
839 tensors.push((
840 Arc::get_mut(&mut layer.mlp.up_proj).unwrap().quant_inner(),
841 Some(i),
842 ));
843 tensors.push((
844 Arc::get_mut(&mut layer.mlp.down_proj)
845 .unwrap()
846 .quant_inner(),
847 Some(i),
848 ));
849 }
850 (tensors, &*self.mapper)
851 }
852
853 fn residual_tensors(&self) -> Vec<(String, Tensor)> {
854 panic!("Cannot generate UQFF for an adapter model.")
855 }
856}
857
858impl NormalModel for Model {
859 fn forward(
860 &self,
861 _input_ids: &Tensor,
862 _seqlen_offsets: &[usize],
863 _context_lens: Vec<(usize, usize)>,
864 _position_ids: Vec<usize>,
865 _metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
866 _flash_params: &FlashParams,
867 ) -> Result<Tensor> {
868 unreachable!()
869 }
870 fn xlora_forward(
871 &self,
872 input_ids: &Tensor,
873 input_ids_full: &Tensor,
874 seqlen_offsets: &[usize],
875 seqlen_offsets_full: &[usize],
876 no_kv_cache: bool,
877 non_granular_state: &Option<crate::xlora_models::NonGranularState>,
878 context_lens: Vec<(usize, usize)>,
879 _position_ids: Vec<usize>,
880 flash_params: &FlashParams,
881 flash_params_full: &FlashParams,
882 ) -> Result<Tensor> {
883 self.forward(
884 input_ids,
885 input_ids_full,
886 seqlen_offsets,
887 seqlen_offsets_full,
888 no_kv_cache,
889 non_granular_state,
890 context_lens,
891 flash_params,
892 flash_params_full,
893 )
894 }
895 fn cache(&self) -> &EitherCache {
896 &self.cache
897 }
898 fn cache_mut(&mut self) -> &mut EitherCache {
899 &mut self.cache
900 }
901 fn device(&self) -> &Device {
902 &self.device
903 }
904 fn is_xlora(&self) -> bool {
905 false
906 }
907 fn max_seq_len(&self) -> usize {
908 self.max_seq_len
909 }
910 fn activate_adapters(&mut self, adapter_names: Vec<String>) -> Result<usize> {
911 if self.xlora_classifier.is_some() {
912 candle_core::bail!("Adapter activation is not supported for X-LoRA models as the adapter set must remain the same.");
913 }
914 let mut sum = 0;
915 for layer in self.layers.iter_mut() {
916 sum += Arc::get_mut(&mut layer.self_attn.k_proj)
917 .unwrap()
918 .activate(&adapter_names)?;
919 sum += Arc::get_mut(&mut layer.self_attn.o_proj)
920 .unwrap()
921 .activate(&adapter_names)?;
922 sum += Arc::get_mut(&mut layer.self_attn.q_proj)
923 .unwrap()
924 .activate(&adapter_names)?;
925 sum += Arc::get_mut(&mut layer.self_attn.v_proj)
926 .unwrap()
927 .activate(&adapter_names)?;
928
929 sum += Arc::get_mut(&mut layer.mlp.down_proj)
930 .unwrap()
931 .activate(&adapter_names)?;
932 sum += Arc::get_mut(&mut layer.mlp.gate_proj)
933 .unwrap()
934 .activate(&adapter_names)?;
935 sum += Arc::get_mut(&mut layer.mlp.up_proj)
936 .unwrap()
937 .activate(&adapter_names)?;
938 }
939 Ok(sum)
940 }
941 fn config(&self) -> &ModelConfigMetadata {
942 &self.cfg
943 }
944}
945
946impl ScalingsMaker for Model {
947 fn dtype(&self) -> DType {
948 self.dtype
949 }
950 fn get_cache(&self) -> &EitherCache {
951 &self.cache
952 }
953 fn get_classifier(&self) -> &XLoraClassifier {
954 self.xlora_classifier.as_ref().unwrap()
955 }
956 fn forward(
957 &self,
958 input_ids: &Tensor,
959 seqlen_offsets: &[usize],
960 scalings: Tensor,
961 is_full_pass: bool,
962 no_kv_cache: bool,
963 is_scaling_pass: Option<f64>,
964 _context_lens: &[usize],
965 flash_params: &FlashParams,
966 ) -> Result<Tensor> {
967 self.inner_forward(
968 input_ids,
969 seqlen_offsets,
970 Some(scalings),
971 is_full_pass,
972 no_kv_cache,
973 is_scaling_pass,
974 flash_params,
975 )
976 }
977}
978
979impl AnyMoeBaseModelMixin for Model {}