#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::{collections::HashMap, sync::Arc};

use candle_core::{DType, Device, Module, Result, Tensor};
use mistralrs_quant::{QuantMethod, ShardedVarBuilder};
use tqdm::Iter;
use tracing::info;

use crate::{
    amoe::AnyMoeBaseModelMixin,
    attention::SdpaParams,
    device_map::DeviceMapper,
    layers::{self, Activation, CausalMasker, RmsNorm, RotaryEmbedding, Sdpa},
    lora::{linear_b, linear_no_bias, LinearLayerLike, LoraConfig},
    models::gemma2::Config,
    paged_attention::ModelConfigMetadata,
    pipeline::{
        extract_logits,
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        Cache, EitherCache, IsqModel, NormalLoadingMetadata, NormalModel,
    },
    utils::progress::NiceProgressBar,
    Ordering,
};

use super::{classifier::XLoraClassifier, NonGranularState, ScalingsMaker, XLoraConfig};

#[derive(Clone)]
#[allow(clippy::upper_case_acronyms)]
struct MLP {
    gate_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    up_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    down_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    act_fn: Activation,
}

impl MLP {
    #[allow(clippy::too_many_arguments)]
    fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let intermediate_sz = cfg.intermediate_size;
        let gate_proj = linear_b(
            hidden_sz,
            intermediate_sz,
            false,
            mapper.set_device(layer_idx, vb.pp("gate_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("gate_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let up_proj = linear_b(
            hidden_sz,
            intermediate_sz,
            false,
            mapper.set_device(layer_idx, vb.pp("up_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("up_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let down_proj = linear_b(
            intermediate_sz,
            hidden_sz,
            false,
            mapper.set_device(layer_idx, vb.pp("down_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("down_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        Ok(Self {
            gate_proj,
            up_proj,
            down_proj,
            act_fn: cfg.hidden_act()?,
        })
    }

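    // Gated MLP forward pass: down_proj(act_fn(gate_proj(x)) * up_proj(x)), with any active
    // LoRA adapters applied to each projection via `lora_forward`.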
    fn forward(
        &self,
        xs: &Tensor,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
    ) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.gate_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let lhs = self
            .gate_proj
            .lora_forward(
                &xs,
                scalings.clone(),
                global_scaling_weight,
                is_scaling_pass,
            )?
            .apply(&self.act_fn)?;
        let rhs = self.up_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let mut res = self.down_proj.lora_forward(
            &(lhs * rhs)?,
            scalings,
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.gate_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

struct Attention {
    q_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    k_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    v_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    o_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    rotary_emb: Arc<RotaryEmbedding>,
    use_sliding_window: bool,
    sliding_window: Option<usize>,
    sdpa_params: SdpaParams,
}

impl Attention {
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<RotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        let head_dim = cfg.head_dim;
        let bias = cfg.attention_bias;
        let q_proj = linear_b(
            hidden_sz,
            num_heads * head_dim,
            bias,
            mapper.set_device(layer_idx, vb.pp("q_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("q_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let k_proj = linear_b(
            hidden_sz,
            num_kv_heads * head_dim,
            bias,
            mapper.set_device(layer_idx, vb.pp("k_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("k_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let v_proj = linear_b(
            hidden_sz,
            num_kv_heads * head_dim,
            bias,
            mapper.set_device(layer_idx, vb.pp("v_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("v_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let o_proj = linear_b(
            num_heads * head_dim,
            hidden_sz,
            bias,
            mapper.set_device(layer_idx, vb.pp("o_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("o_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
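        // Gemma 2 alternates attention types: even-indexed layers use sliding-window (local)
        // attention, odd-indexed layers attend over the full context.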
        let sliding_window = if layer_idx % 2 == 0 {
            Some(cfg.sliding_window)
        } else {
            None
        };
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            num_heads,
            num_kv_heads,
            head_dim,
            rotary_emb,
            use_sliding_window: layer_idx % 2 == 0,
            sliding_window,
            sdpa_params: SdpaParams {
                n_kv_groups: num_heads / num_kv_heads,
                use_flash_attn: cfg.use_flash_attn,
                softcap: cfg.attn_logit_softcapping.map(|x| x as f32),
                softmax_scale: 1.0 / (cfg.query_pre_attn_scalar as f32).sqrt(),
                sliding_window,
            },
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        sliding_attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut Option<(Tensor, Tensor)>,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.q_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut q = self.q_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let mut k = self.k_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let mut v = self.v_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.q_proj.quantized_act_type().is_some() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

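        // Shape Q/K/V to (batch, heads, seq, head_dim); the single-token decode path can
        // reshape directly and skip the transpose.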
        let (q, k, v) = if q_len != 1 {
            let q = q
                .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, k) = self.rotary_emb.forward(&q, &k, seqlen_offsets)?;

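        // Local (sliding-window) layers use the sliding mask, all others the full causal mask.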
        let mask = if self.use_sliding_window {
            sliding_attention_mask
        } else {
            attention_mask
        };

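        // Update the KV cache; `update_kv_cache_sliding_window` accounts for this layer's
        // sliding window (if any).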
        let (k, v, mask) = Cache::update_kv_cache_sliding_window(
            kv_cache,
            k,
            v,
            mask,
            self.sliding_window,
            false,
        )?;

        let mut attn_output = Sdpa.run_attention(
            &q,
            &k,
            &v,
            mask.as_ref(),
            Some(flash_params),
            &self.sdpa_params,
        )?;

        if let Some(t) = self.q_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        let mut res = self.o_proj.lora_forward(
            &attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.q_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

struct DecoderLayer {
    self_attn: Attention,
    mlp: MLP,
    input_layernorm: RmsNorm,
    post_attention_layernorm: RmsNorm,
    pre_feedforward_layernorm: RmsNorm,
    post_feedforward_layernorm: RmsNorm,
}

impl DecoderLayer {
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<RotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let self_attn = Attention::new(
            rotary_emb,
            cfg,
            vb.pp("self_attn"),
            lora_config,
            count,
            ord,
            mapper,
            layer_idx,
            loading_isq,
            preload_adapters,
        )?;
        let mlp = MLP::new(
            cfg,
            vb.pp("mlp"),
            lora_config,
            count,
            ord,
            mapper,
            layer_idx,
            loading_isq,
            preload_adapters,
        )?;
        let input_layernorm = RmsNorm::new_gemma(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let post_attention_layernorm = RmsNorm::new_gemma(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        let pre_feedforward_layernorm = RmsNorm::new_gemma(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("pre_feedforward_layernorm"), false),
        )?;
        let post_feedforward_layernorm = RmsNorm::new_gemma(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_feedforward_layernorm"), false),
        )?;
        Ok(Self {
            self_attn,
            mlp,
            input_layernorm,
            post_attention_layernorm,
            pre_feedforward_layernorm,
            post_feedforward_layernorm,
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        sliding_attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut Option<(Tensor, Tensor)>,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = self.input_layernorm.forward(xs)?;
        let xs = self
            .self_attn
            .forward(
                &xs,
                attention_mask,
                sliding_attention_mask,
                seqlen_offsets,
                kv_cache,
                scalings.clone(),
                global_scaling_weight,
                is_scaling_pass,
                flash_params,
            )?
            .apply(&self.post_attention_layernorm)?;
        let xs = (xs + residual)?;
        let residual = &xs;
        let xs = self
            .mlp
            .forward(
                &xs.apply(&self.pre_feedforward_layernorm)?,
                scalings,
                global_scaling_weight,
                is_scaling_pass,
            )?
            .apply(&self.post_feedforward_layernorm)?;
        residual + xs
    }
}

pub struct Model {
    embed_tokens: candle_nn::Embedding,
    layers: Vec<DecoderLayer>,
    norm: RmsNorm,
    lm_head: Arc<dyn LinearLayerLike + Send + Sync>,
    hidden_size: usize,
    device: Device,
    cache: EitherCache,
    max_seq_len: usize,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    sliding_window: usize,
    final_logit_softcapping: Option<f64>,
    xlora_classifier: Option<XLoraClassifier>,
    dtype: DType,
    cfg: ModelConfigMetadata,
}

impl Model {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        if let Some(ref quant_cfg) = &cfg.quantization_config {
            tracing::info!(
                "Using {} quantization: {}.",
                quant_cfg.name(),
                quant_cfg.get_bits_name(&vb)
            );
        }
        let mapper = normal_loading_metadata.mapper;

        let vb_m = vb.pp("model");
        let embed_tokens = layers::embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
            &cfg.quantization_config,
        )?;
        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        let vb_l = vb_m.pp("layers");
        let mut ropes = HashMap::new();
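        // Build one rotary embedding per device location so layers mapped to the same device
        // can share it.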
        for layer_idx in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(RotaryEmbedding::new(
                    cfg.rope_theta as f32,
                    cfg.head_dim,
                    cfg.max_position_embeddings,
                    device,
                    is_gptx,
                    vb.dtype(),
                )?),
            );
        }
        let mut count = 0;
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        ) {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            let layer = DecoderLayer::new(
                rotary_emb.clone(),
                cfg,
                vb_l.pp(layer_idx),
                lora_config,
                &mut count,
                &xlora_ordering,
                &*mapper,
                layer_idx,
                normal_loading_metadata.loading_isq,
                preload_adapters,
            )?;
            layers.push(layer)
        }
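        // Plain LoRA (no X-LoRA classifier, no preloaded adapters): merge the adapter weights
        // into the base projections.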
        if xlora_config.is_none() && preload_adapters.is_none() {
            info!("Merging LoRA adapters.");
            for layer in layers.iter_mut().tqdm() {
                Arc::get_mut(&mut layer.self_attn.k_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.self_attn.o_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.self_attn.q_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.self_attn.v_proj)
                    .unwrap()
                    .merge_weights()?;

                Arc::get_mut(&mut layer.mlp.down_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.mlp.gate_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.mlp.up_proj)
                    .unwrap()
                    .merge_weights()?;
            }
        }
        let norm = RmsNorm::new_gemma(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb_m.pp("norm"), false),
        )?;

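        // The LM head is tied to the token embeddings, so its weights are loaded from
        // `model.embed_tokens`.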
        let lm_head = linear_no_bias(
            embed_tokens.embeddings().dim(1)?,
            embed_tokens.embeddings().dim(0)?,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), normal_loading_metadata.loading_isq),
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
            lora_config,
            &mut count,
            &xlora_ordering,
            preload_adapters,
        )?;
        if xlora_config.is_some() && lm_head.is_lora() {
            candle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA.");
        }

        Ok(Self {
            embed_tokens,
            layers,
            norm,
            lm_head,
            device: normal_loading_metadata.real_device,
            hidden_size: cfg.hidden_size,
            cache: EitherCache::Full(Cache::new(cfg.num_hidden_layers, true)),
            max_seq_len: cfg.max_position_embeddings,
            mapper,
            sliding_window: cfg.sliding_window,
            final_logit_softcapping: cfg.final_logit_softcapping,
            dtype: vb.dtype(),
            xlora_classifier: xlora_config.map(|xlora_config| {
                XLoraClassifier::new(xlora_config, count, lora_config.len(), vb, false).unwrap()
            }),
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_kv_heads: cfg.num_key_value_heads,
                num_attn_heads: cfg.num_attention_heads,
                sliding_window: None,
                k_head_dim: cfg.head_dim,
                v_head_dim: cfg.head_dim,
            },
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn inner_forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Option<Tensor>,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let xs = self.embed_tokens.forward(input_ids)?;
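        // Gemma scales the token embeddings by sqrt(hidden_size).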
        let mut xs = (xs * (self.hidden_size as f64).sqrt())?;
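        // Full passes use the dedicated X-LoRA cache (reset when the KV cache is disabled);
        // other passes use the regular cache.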
        let mut cache = if is_full_pass {
            if no_kv_cache {
                let mut new_cache = Vec::new();
                for _ in 0..self.cache.full().xlora_lock().len() {
                    new_cache.push(None);
                }

                self.cache.full().xlora_lock().clone_from(&new_cache);
            }
            self.cache.full().xlora_lock()
        } else {
            self.cache.full().lock()
        };
        let attention_mask = CausalMasker.make_causal_mask_matrix(
            input_ids,
            &*cache,
            xs.dtype(),
            self.cfg.num_attn_heads,
        )?;
        let sliding_attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
            input_ids,
            &*cache,
            Some(self.sliding_window),
            xs.dtype(),
            self.cfg.num_attn_heads,
        )?;
        for (i, layer) in self.layers.iter().enumerate() {
            xs = self.mapper.map(xs, i)?;
            xs = layer.forward(
                &xs,
                attention_mask
                    .as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                sliding_attention_mask
                    .as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                &mut cache[i],
                scalings.clone(),
                self.xlora_classifier
                    .as_ref()
                    .map(|classifier| classifier.get_global_scaling_weight())
                    .unwrap_or(1.0),
                is_scaling_pass,
                flash_params,
            )?;
        }
        let xs = xs.to_device(&self.device)?;
        let mut xs = xs.apply(&self.norm)?;
        if let Some(t) = self.lm_head.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }

        let mut xs = self.lm_head.lora_forward(&xs, None, 1.0, None)?;

        if let Some(final_logit_softcapping) = self.final_logit_softcapping {
            xs = (xs / final_logit_softcapping)?;
            xs = xs.tanh()?;
            xs = (xs * final_logit_softcapping)?;
        }

        Ok(xs)
    }

    #[allow(clippy::too_many_arguments)]
    pub fn forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        if self.xlora_classifier.is_some() {
            let scalings = self.get_scalings(
                input_ids,
                input_ids_full,
                seqlen_offsets,
                seqlen_offsets_full,
                no_kv_cache,
                non_granular_state,
                &vec![usize::MAX; context_lens.len()],
                flash_params,
                flash_params_full,
            )?;

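            // Real forward pass with the classifier-produced X-LoRA scalings, followed by the
            // LM head projection.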
            if no_kv_cache {
                let mut res = self
                    .inner_forward(
                        input_ids_full,
                        seqlen_offsets_full,
                        Some(scalings),
                        true,
                        no_kv_cache,
                        None,
                        flash_params_full,
                    )?
                    .contiguous()?;
                if let Some(t) = self.lm_head.quantized_act_type() {
                    res = res.to_dtype(t)?;
                }
                extract_logits(
                    &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                    context_lens,
                )
            } else {
                let mut res = self
                    .inner_forward(
                        input_ids,
                        seqlen_offsets,
                        Some(scalings),
                        true,
                        no_kv_cache,
                        None,
                        flash_params,
                    )?
                    .contiguous()?;
                if let Some(t) = self.lm_head.quantized_act_type() {
                    res = res.to_dtype(t)?;
                }
                extract_logits(
                    &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                    context_lens,
                )
            }
        } else {
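            // No X-LoRA classifier: a single plain (LoRA-only) forward pass without scalings.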
            let mut res = self
                .inner_forward(
                    input_ids,
                    seqlen_offsets,
                    None,
                    false,
                    no_kv_cache,
                    None,
                    flash_params,
                )?
                .contiguous()?;
            if let Some(t) = self.lm_head.quantized_act_type() {
                res = res.to_dtype(t)?;
            }
            extract_logits(
                &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                context_lens,
            )
        }
    }
}

impl IsqModel for Model {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
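        // Expose the LM head and every per-layer attention/MLP projection to the ISQ
        // machinery, each tagged with its layer index.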
        let mut tensors = Vec::new();
        tensors.push((Arc::get_mut(&mut self.lm_head).unwrap().quant_inner(), None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.q_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.k_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.v_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.o_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.mlp.gate_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.mlp.up_proj).unwrap().quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.mlp.down_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
        }
        (tensors, &*self.mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        panic!("Cannot generate UQFF for an adapter model.")
    }
}

impl NormalModel for Model {
    fn forward(
        &self,
        _input_ids: &Tensor,
        _seqlen_offsets: &[usize],
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        _flash_params: &FlashParams,
    ) -> Result<Tensor> {
        unreachable!()
    }
    fn xlora_forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<crate::xlora_models::NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        self.forward(
            input_ids,
            input_ids_full,
            seqlen_offsets,
            seqlen_offsets_full,
            no_kv_cache,
            non_granular_state,
            context_lens,
            flash_params,
            flash_params_full,
        )
    }
    fn cache(&self) -> &EitherCache {
        &self.cache
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        &mut self.cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn is_xlora(&self) -> bool {
        false
    }
    fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
}

impl ScalingsMaker for Model {
    fn dtype(&self) -> DType {
        self.dtype
    }
    fn get_cache(&self) -> &EitherCache {
        &self.cache
    }
    fn get_classifier(&self) -> &XLoraClassifier {
        self.xlora_classifier.as_ref().unwrap()
    }
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Tensor,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        _context_lens: &[usize],
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.inner_forward(
            input_ids,
            seqlen_offsets,
            Some(scalings),
            is_full_pass,
            no_kv_cache,
            is_scaling_pass,
            flash_params,
        )
    }
}

impl AnyMoeBaseModelMixin for Model {}