#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::{collections::HashMap, sync::Arc};

use candle_core::{DType, Device, Module, Result, Tensor};
use mistralrs_quant::{QuantMethod, ShardedVarBuilder};
use tqdm::Iter;
use tracing::info;

use crate::{
    amoe::AnyMoeBaseModelMixin,
    attention::SdpaParams,
    device_map::DeviceMapper,
    layers::{self, Activation, CausalMasker, RmsNorm, RotaryEmbedding, Sdpa},
    lora::{linear_b, linear_no_bias, LinearLayerLike, LoraConfig},
    models::gemma2::Config,
    paged_attention::ModelConfigMetadata,
    pipeline::{
        extract_logits,
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        Cache, EitherCache, IsqModel, NormalLoadingMetadata, NormalModel,
    },
    utils::progress::NiceProgressBar,
    Ordering,
};

use super::{classifier::XLoraClassifier, NonGranularState, ScalingsMaker, XLoraConfig};

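/// Gated feed-forward block (gate/up/down projections) in which every projection is a
/// LoRA-aware linear layer, so X-LoRA scalings can be applied per adapter.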
#[derive(Clone)]
#[allow(clippy::upper_case_acronyms)]
struct MLP {
    gate_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    up_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    down_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    act_fn: Activation,
}

impl MLP {
    #[allow(clippy::too_many_arguments)]
    fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let intermediate_sz = cfg.intermediate_size;
        let gate_proj = linear_b(
            hidden_sz,
            intermediate_sz,
            false,
            mapper.set_device(layer_idx, vb.pp("gate_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("gate_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let up_proj = linear_b(
            hidden_sz,
            intermediate_sz,
            false,
            mapper.set_device(layer_idx, vb.pp("up_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("up_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let down_proj = linear_b(
            intermediate_sz,
            hidden_sz,
            false,
            mapper.set_device(layer_idx, vb.pp("down_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("down_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        Ok(Self {
            gate_proj,
            up_proj,
            down_proj,
            act_fn: cfg.hidden_act()?,
        })
    }

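    // Gated forward pass: down(act(gate(x)) * up(x)), with a temporary cast to the
    // quantized activation dtype when the projections are quantized.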
    fn forward(
        &self,
        xs: &Tensor,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
    ) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.gate_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let lhs = self
            .gate_proj
            .lora_forward(
                &xs,
                scalings.clone(),
                global_scaling_weight,
                is_scaling_pass,
            )?
            .apply(&self.act_fn)?;
        let rhs = self.up_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let mut res = self.down_proj.lora_forward(
            &(lhs * rhs)?,
            scalings,
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.gate_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

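/// Self-attention with LoRA-aware projections. Gemma 2 alternates between
/// sliding-window attention and global attention on every other layer.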
struct Attention {
    q_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    k_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    v_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    o_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    rotary_emb: Arc<RotaryEmbedding>,
    use_sliding_window: bool,
    sliding_window: Option<usize>,
    sdpa_params: SdpaParams,
}

impl Attention {
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<RotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        let head_dim = cfg.head_dim;
        let bias = cfg.attention_bias;
        let q_proj = linear_b(
            hidden_sz,
            num_heads * head_dim,
            bias,
            mapper.set_device(layer_idx, vb.pp("q_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("q_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let k_proj = linear_b(
            hidden_sz,
            num_kv_heads * head_dim,
            bias,
            mapper.set_device(layer_idx, vb.pp("k_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("k_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let v_proj = linear_b(
            hidden_sz,
            num_kv_heads * head_dim,
            bias,
            mapper.set_device(layer_idx, vb.pp("v_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("v_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let o_proj = linear_b(
            num_heads * head_dim,
            hidden_sz,
            bias,
            mapper.set_device(layer_idx, vb.pp("o_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("o_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let sliding_window = if layer_idx % 2 == 0 {
            Some(cfg.sliding_window)
        } else {
            None
        };
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            num_heads,
            num_kv_heads,
            head_dim,
            rotary_emb,
            use_sliding_window: layer_idx % 2 == 0,
            sliding_window,
            sdpa_params: SdpaParams {
                n_kv_groups: num_heads / num_kv_heads,
                softcap: cfg.attn_logit_softcapping.map(|x| x as f32),
                softmax_scale: 1.0 / (cfg.query_pre_attn_scalar as f32).sqrt(),
                sliding_window,
            },
        })
    }

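    // Attention forward: project to q/k/v, apply RoPE, update the (optionally
    // sliding-window) KV cache, then run SDPA and project the output. The causal
    // mask is chosen according to whether this layer uses sliding-window attention.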
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        sliding_attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut Option<(Tensor, Tensor)>,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.q_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut q = self.q_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let mut k = self.k_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let mut v = self.v_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.q_proj.quantized_act_type().is_some() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

        let (q, k, v) = if q_len != 1 {
            let q = q
                .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, k) = self.rotary_emb.forward(&q, &k, seqlen_offsets)?;

        let mask = if self.use_sliding_window {
            sliding_attention_mask
        } else {
            attention_mask
        };

        let (k, v, mask) =
            Cache::update_kv_cache_sliding_window(kv_cache, k, v, mask, self.sliding_window)?;

        let mut attn_output = Sdpa.run_attention(
            &q,
            &k,
            &v,
            mask.as_ref(),
            Some(flash_params),
            &self.sdpa_params,
        )?;

        if let Some(t) = self.q_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        let mut res = self.o_proj.lora_forward(
            &attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.q_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

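/// A single transformer block. Gemma 2 uses "sandwich" normalization: RMSNorm both
/// before and after the attention block, and before and after the feed-forward block.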
struct DecoderLayer {
    self_attn: Attention,
    mlp: MLP,
    input_layernorm: RmsNorm,
    post_attention_layernorm: RmsNorm,
    pre_feedforward_layernorm: RmsNorm,
    post_feedforward_layernorm: RmsNorm,
}

impl DecoderLayer {
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<RotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let self_attn = Attention::new(
            rotary_emb,
            cfg,
            vb.pp("self_attn"),
            lora_config,
            count,
            ord,
            mapper,
            layer_idx,
            loading_isq,
            preload_adapters,
        )?;
        let mlp = MLP::new(
            cfg,
            vb.pp("mlp"),
            lora_config,
            count,
            ord,
            mapper,
            layer_idx,
            loading_isq,
            preload_adapters,
        )?;
        let input_layernorm = RmsNorm::new_gemma(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let post_attention_layernorm = RmsNorm::new_gemma(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        let pre_feedforward_layernorm = RmsNorm::new_gemma(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("pre_feedforward_layernorm"), false),
        )?;
        let post_feedforward_layernorm = RmsNorm::new_gemma(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_feedforward_layernorm"), false),
        )?;
        Ok(Self {
            self_attn,
            mlp,
            input_layernorm,
            post_attention_layernorm,
            pre_feedforward_layernorm,
            post_feedforward_layernorm,
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        sliding_attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut Option<(Tensor, Tensor)>,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = self.input_layernorm.forward(xs)?;
        let xs = self
            .self_attn
            .forward(
                &xs,
                attention_mask,
                sliding_attention_mask,
                seqlen_offsets,
                kv_cache,
                scalings.clone(),
                global_scaling_weight,
                is_scaling_pass,
                flash_params,
            )?
            .apply(&self.post_attention_layernorm)?;
        let xs = (xs + residual)?;
        let residual = &xs;
        let xs = self
            .mlp
            .forward(
                &xs.apply(&self.pre_feedforward_layernorm)?,
                scalings,
                global_scaling_weight,
                is_scaling_pass,
            )?
            .apply(&self.post_feedforward_layernorm)?;
        residual + xs
    }
}

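/// X-LoRA variant of the Gemma 2 model: token embeddings, a stack of decoder layers,
/// a final RMSNorm, and an lm_head tied to the embedding weights, plus an optional
/// X-LoRA classifier that produces per-token adapter scalings.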
pub struct Model {
    embed_tokens: candle_nn::Embedding,
    layers: Vec<DecoderLayer>,
    norm: RmsNorm,
    lm_head: Arc<dyn LinearLayerLike + Send + Sync>,
    hidden_size: usize,
    device: Device,
    cache: EitherCache,
    max_seq_len: usize,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    sliding_window: usize,
    final_logit_softcapping: Option<f64>,
    xlora_classifier: Option<XLoraClassifier>,
    dtype: DType,
    cfg: ModelConfigMetadata,
}

impl Model {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        if let Some(ref quant_cfg) = &cfg.quantization_config {
            tracing::info!(
                "Using {} quantization: {}.",
                quant_cfg.name(),
                quant_cfg.get_bits_name(&vb)
            );
        }
        let mapper = normal_loading_metadata.mapper;

        let vb_m = vb.pp("model");
        let embed_tokens = layers::embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
            &cfg.quantization_config,
        )?;
        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        let vb_l = vb_m.pp("layers");
        let mut ropes = HashMap::new();
        for layer_idx in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(RotaryEmbedding::new(
                    cfg.rope_theta as f32,
                    cfg.head_dim,
                    cfg.max_position_embeddings,
                    device,
                    is_gptx,
                    vb.dtype(),
                )?),
            );
        }
        let mut count = 0;
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        ) {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            let layer = DecoderLayer::new(
                rotary_emb.clone(),
                cfg,
                vb_l.pp(layer_idx),
                lora_config,
                &mut count,
                &xlora_ordering,
                &*mapper,
                layer_idx,
                normal_loading_metadata.loading_isq,
                preload_adapters,
            )?;
            layers.push(layer)
        }
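        // When running as a plain LoRA model (no X-LoRA classifier and no preloaded
        // adapters), fold the adapter weights into the base weights up front.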
        if xlora_config.is_none() && preload_adapters.is_none() {
            info!("Merging LoRA adapters.");
            for layer in layers.iter_mut().tqdm() {
                Arc::get_mut(&mut layer.self_attn.k_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.self_attn.o_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.self_attn.q_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.self_attn.v_proj)
                    .unwrap()
                    .merge_weights()?;

                Arc::get_mut(&mut layer.mlp.down_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.mlp.gate_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.mlp.up_proj)
                    .unwrap()
                    .merge_weights()?;
            }
        }
        let norm = RmsNorm::new_gemma(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb_m.pp("norm"), false),
        )?;

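        // Gemma 2 ties the output head to the token embeddings, so lm_head is built
        // from the `embed_tokens` weights rather than a separate tensor.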
        let lm_head = linear_no_bias(
            embed_tokens.embeddings().dim(1)?,
            embed_tokens.embeddings().dim(0)?,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), normal_loading_metadata.loading_isq),
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
            lora_config,
            &mut count,
            &xlora_ordering,
            preload_adapters,
        )?;
        if xlora_config.is_some() && lm_head.is_lora() {
            candle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA.");
        }

        Ok(Self {
            embed_tokens,
            layers,
            norm,
            lm_head,
            device: normal_loading_metadata.real_device,
            hidden_size: cfg.hidden_size,
            cache: EitherCache::Full(Cache::new(cfg.num_hidden_layers, true)),
            max_seq_len: cfg.max_position_embeddings,
            mapper,
            sliding_window: cfg.sliding_window,
            final_logit_softcapping: cfg.final_logit_softcapping,
            dtype: vb.dtype(),
            xlora_classifier: xlora_config.map(|xlora_config| {
                XLoraClassifier::new(xlora_config, count, lora_config.len(), vb, false).unwrap()
            }),
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_kv_heads: cfg.num_key_value_heads,
                num_attn_heads: cfg.num_attention_heads,
                sliding_window: None,
                k_head_dim: cfg.head_dim,
                v_head_dim: cfg.head_dim,
            },
        })
    }

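    // Core decoder pass shared by the scaling pass and the full pass: scales the
    // embeddings by sqrt(hidden_size), builds the global and sliding-window causal
    // masks, runs every layer, then applies the final norm, lm_head, and optional
    // final logit softcapping.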
    #[allow(clippy::too_many_arguments)]
    fn inner_forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Option<Tensor>,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let xs = self.embed_tokens.forward(input_ids)?;
        let mut xs = (xs * (self.hidden_size as f64).sqrt())?;
        let mut cache = if is_full_pass {
            if no_kv_cache {
                let mut new_cache = Vec::new();
                for _ in 0..self.cache.full().xlora_lock().len() {
                    new_cache.push(None);
                }

                self.cache.full().xlora_lock().clone_from(&new_cache);
            }
            self.cache.full().xlora_lock()
        } else {
            self.cache.full().lock()
        };
        let attention_mask = CausalMasker.make_causal_mask_matrix(
            input_ids,
            &*cache,
            xs.dtype(),
            self.cfg.num_attn_heads,
        )?;
        let sliding_attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
            input_ids,
            &*cache,
            Some(self.sliding_window),
            xs.dtype(),
            self.cfg.num_attn_heads,
        )?;
        for (i, layer) in self.layers.iter().enumerate() {
            xs = self.mapper.map(xs, i)?;
            xs = layer.forward(
                &xs,
                attention_mask
                    .as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                sliding_attention_mask
                    .as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                &mut cache[i],
                scalings.clone(),
                self.xlora_classifier
                    .as_ref()
                    .map(|classifier| classifier.get_global_scaling_weight())
                    .unwrap_or(1.0),
                is_scaling_pass,
                flash_params,
            )?;
        }
        let xs = xs.to_device(&self.device)?;
        let mut xs = xs.apply(&self.norm)?;
        if let Some(t) = self.lm_head.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }

        let mut xs = self.lm_head.lora_forward(&xs, None, 1.0, None)?;

        if let Some(final_logit_softcapping) = self.final_logit_softcapping {
            xs = (xs / final_logit_softcapping)?;
            xs = xs.tanh()?;
            xs = (xs * final_logit_softcapping)?;
        }

        Ok(xs)
    }

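    // With an X-LoRA classifier present, inference is two passes: `get_scalings` runs a
    // scaling pass to produce per-token adapter scalings, and the decoder is then rerun
    // with those scalings before the logits are extracted. Without a classifier this is
    // a single ordinary pass.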
    #[allow(clippy::too_many_arguments)]
    pub fn forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        if self.xlora_classifier.is_some() {
            let scalings = self.get_scalings(
                input_ids,
                input_ids_full,
                seqlen_offsets,
                seqlen_offsets_full,
                no_kv_cache,
                non_granular_state,
                &vec![usize::MAX; context_lens.len()],
                flash_params,
                flash_params_full,
            )?;

            if no_kv_cache {
                let mut res = self
                    .inner_forward(
                        input_ids_full,
                        seqlen_offsets_full,
                        Some(scalings),
                        true,
                        no_kv_cache,
                        None,
                        flash_params_full,
                    )?
                    .contiguous()?;
                if let Some(t) = self.lm_head.quantized_act_type() {
                    res = res.to_dtype(t)?;
                }
                extract_logits(
                    &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                    context_lens,
                )
            } else {
                let mut res = self
                    .inner_forward(
                        input_ids,
                        seqlen_offsets,
                        Some(scalings),
                        true,
                        no_kv_cache,
                        None,
                        flash_params,
                    )?
                    .contiguous()?;
                if let Some(t) = self.lm_head.quantized_act_type() {
                    res = res.to_dtype(t)?;
                }
                extract_logits(
                    &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                    context_lens,
                )
            }
        } else {
            let mut res = self
                .inner_forward(
                    input_ids,
                    seqlen_offsets,
                    None,
                    false,
                    no_kv_cache,
                    None,
                    flash_params,
                )?
                .contiguous()?;
            if let Some(t) = self.lm_head.quantized_act_type() {
                res = res.to_dtype(t)?;
            }
            extract_logits(
                &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                context_lens,
            )
        }
    }
}

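// ISQ support: expose every quantizable linear (lm_head plus the per-layer attention
// and MLP projections) together with its layer index so in-situ quantization can be
// applied per device.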
impl IsqModel for Model {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((Arc::get_mut(&mut self.lm_head).unwrap().quant_inner(), None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.q_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.k_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.v_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.o_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.mlp.gate_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.mlp.up_proj).unwrap().quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.mlp.down_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
        }
        (tensors, &*self.mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        panic!("Cannot generate UQFF for an adapter model.")
    }
}

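// The pipeline drives adapter models through `xlora_forward`; the plain `forward`
// entry point is never called for this model type.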
impl NormalModel for Model {
    fn forward(
        &self,
        _input_ids: &Tensor,
        _seqlen_offsets: &[usize],
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        _flash_params: &FlashParams,
    ) -> Result<Tensor> {
        unreachable!()
    }
    fn xlora_forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<crate::xlora_models::NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        self.forward(
            input_ids,
            input_ids_full,
            seqlen_offsets,
            seqlen_offsets_full,
            no_kv_cache,
            non_granular_state,
            context_lens,
            flash_params,
            flash_params_full,
        )
    }
    fn cache(&self) -> &EitherCache {
        &self.cache
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        &mut self.cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn is_xlora(&self) -> bool {
        false
    }
    fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
}

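// ScalingsMaker lets the shared X-LoRA driver run the classifier's scaling pass
// through this model's `inner_forward`.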
impl ScalingsMaker for Model {
    fn dtype(&self) -> DType {
        self.dtype
    }
    fn get_cache(&self) -> &EitherCache {
        &self.cache
    }
    fn get_classifier(&self) -> &XLoraClassifier {
        self.xlora_classifier.as_ref().unwrap()
    }
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Tensor,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        _context_lens: &[usize],
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.inner_forward(
            input_ids,
            seqlen_offsets,
            Some(scalings),
            is_full_pass,
            no_kv_cache,
            is_scaling_pass,
            flash_params,
        )
    }
}

impl AnyMoeBaseModelMixin for Model {}