#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use candle_core::{DType, Device, Module, Result, Tensor};
use candle_nn::LayerNorm;
use mistralrs_quant::{QuantMethod, ShardedVarBuilder};
use std::{collections::HashMap, sync::Arc};
use tqdm::Iter;
use tracing::info;

use crate::{
    amoe::AnyMoeBaseModelMixin,
    attention::SdpaParams,
    device_map::DeviceMapper,
    layers::{self, layer_norm, Activation, CausalMasker, RotaryEmbedding, Sdpa},
    lora::{linear_b, linear_no_bias, LinearLayerLike, LoraConfig},
    models::starcoder2::Config,
    paged_attention::ModelConfigMetadata,
    pipeline::{
        extract_logits,
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        Cache, EitherCache, IsqModel, NormalLoadingMetadata, NormalModel,
    },
    utils::progress::NiceProgressBar,
    Ordering,
};

use super::{classifier::XLoraClassifier, NonGranularState, ScalingsMaker, XLoraConfig};

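/// Feed-forward block of a Starcoder2 decoder layer: `c_fc` -> activation -> `c_proj`,
/// built from LoRA-aware linear layers.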
#[derive(Clone)]
#[allow(clippy::upper_case_acronyms)]
struct MLP {
    c_fc: Arc<dyn LinearLayerLike + Send + Sync>,
    c_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    act: Activation,
}

impl MLP {
    #[allow(clippy::too_many_arguments)]
    fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let (h_size, i_size) = (cfg.hidden_size, cfg.intermediate_size);
        let c_fc = linear_b(
            h_size,
            i_size,
            cfg.use_bias,
            mapper.set_device(layer_idx, vb.pp("c_fc"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("c_fc"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let c_proj = linear_b(
            i_size,
            h_size,
            cfg.use_bias,
            mapper.set_device(layer_idx, vb.pp("c_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("c_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        Ok(Self {
            c_fc,
            c_proj,
            act: cfg.hidden_act,
        })
    }

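    /// Apply `c_fc`, the activation, and `c_proj` through the LoRA-aware forward path,
    /// casting to the quantized activation dtype (and back) when the layers are quantized.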
    fn forward(
        &self,
        xs: &Tensor,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
    ) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.c_fc.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut res = self.c_proj.lora_forward(
            &self
                .c_fc
                .lora_forward(
                    &xs,
                    scalings.clone(),
                    global_scaling_weight,
                    is_scaling_pass,
                )?
                .apply(&self.act)?,
            scalings,
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.c_fc.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

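/// Self-attention with grouped-query KV heads, rotary embeddings, an optional sliding
/// window, and LoRA-aware q/k/v/o projections.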
struct Attention {
    q_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    k_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    v_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    o_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    hidden_size: usize,
    rotary_emb: Arc<RotaryEmbedding>,
    sliding_window: Option<usize>,
    sdpa_params: SdpaParams,
}

impl Attention {
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<RotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        let head_dim = hidden_sz / num_heads;
        let b = cfg.use_bias;
        let q_proj = linear_b(
            hidden_sz,
            num_heads * head_dim,
            b,
            mapper.set_device(layer_idx, vb.pp("q_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("q_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let k_proj = linear_b(
            hidden_sz,
            num_kv_heads * head_dim,
            b,
            mapper.set_device(layer_idx, vb.pp("k_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("k_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let v_proj = linear_b(
            hidden_sz,
            num_kv_heads * head_dim,
            b,
            mapper.set_device(layer_idx, vb.pp("v_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("v_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let o_proj = linear_b(
            num_heads * head_dim,
            hidden_sz,
            b,
            mapper.set_device(layer_idx, vb.pp("o_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("o_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            num_heads,
            num_kv_heads,
            head_dim,
            hidden_size: hidden_sz,
            rotary_emb,
            sliding_window: cfg.sliding_window,
            sdpa_params: SdpaParams {
                n_kv_groups: num_heads / num_kv_heads,
                use_flash_attn: cfg.use_flash_attn,
                softcap: None,
                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                sliding_window: cfg.sliding_window,
            },
        })
    }

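    /// Project to q/k/v, apply RoPE, update the (sliding-window) KV cache, run scaled
    /// dot-product attention, and project the result back through `o_proj`.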
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut Option<(Tensor, Tensor)>,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.q_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut q = self.q_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let mut k = self.k_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let mut v = self.v_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.q_proj.quantized_act_type().is_some() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

        let (q, k, v) = if q_len != 1 {
            let q = q
                .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, k) = self.rotary_emb.forward(&q, &k, seqlen_offsets)?;

        let (k, v, attn_mask) = Cache::update_kv_cache_sliding_window(
            kv_cache,
            k,
            v,
            attention_mask,
            self.sliding_window,
            false,
        )?;

        let mut attn_output = Sdpa.run_attention(
            &q,
            &k,
            &v,
            attn_mask.as_ref(),
            Some(flash_params),
            &self.sdpa_params,
        )?;

        if let Some(t) = self.q_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        let mut res = self.o_proj.lora_forward(
            &attn_output
                .transpose(1, 2)?
                .reshape((b_sz, q_len, self.hidden_size))?,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.q_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

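/// A single decoder layer: pre-norm self-attention and MLP, each wrapped in a residual
/// connection.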
struct DecoderLayer {
    self_attn: Attention,
    mlp: MLP,
    input_layernorm: LayerNorm,
    post_attention_layernorm: LayerNorm,
}

impl DecoderLayer {
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<RotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let self_attn = Attention::new(
            rotary_emb,
            cfg,
            vb.pp("self_attn"),
            lora_config,
            count,
            ord,
            mapper,
            layer_idx,
            loading_isq,
            preload_adapters,
        )?;
        let mlp = MLP::new(
            cfg,
            vb.pp("mlp"),
            lora_config,
            count,
            ord,
            mapper,
            layer_idx,
            loading_isq,
            preload_adapters,
        )?;
        let input_layernorm = layer_norm(
            cfg.hidden_size,
            cfg.norm_epsilon,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let post_attention_layernorm = layer_norm(
            cfg.hidden_size,
            cfg.norm_epsilon,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        Ok(Self {
            self_attn,
            mlp,
            input_layernorm,
            post_attention_layernorm,
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut Option<(Tensor, Tensor)>,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = self.input_layernorm.forward(xs)?;
        let xs = self.self_attn.forward(
            &xs,
            attention_mask,
            seqlen_offsets,
            kv_cache,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
            flash_params,
        )?;
        let xs = (xs + residual)?;
        let residual = &xs;
        let xs = self.mlp.forward(
            &xs.apply(&self.post_attention_layernorm)?,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        residual + xs
    }
}

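/// Starcoder2 model with X-LoRA/LoRA adapter support: token embeddings, a stack of decoder
/// layers, a final layer norm, a LoRA-aware LM head, and an optional X-LoRA classifier.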
pub struct Model {
    embed_tokens: candle_nn::Embedding,
    layers: Vec<DecoderLayer>,
    norm: LayerNorm,
    lm_head: Arc<dyn LinearLayerLike + Send + Sync>,
    sliding_window: Option<usize>,
    device: Device,
    cache: EitherCache,
    max_seq_len: usize,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    xlora_classifier: Option<XLoraClassifier>,
    dtype: DType,
    cfg: ModelConfigMetadata,
}

impl Model {
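    /// Load the model from the given var builder, wiring LoRA adapters into every linear
    /// layer. If neither an X-LoRA config nor preloaded adapters are given, the LoRA weights
    /// are merged into the base weights after loading.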
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        if let Some(ref quant_cfg) = &cfg.quantization_config {
            tracing::info!(
                "Using {} quantization: {}.",
                quant_cfg.name(),
                quant_cfg.get_bits_name(&vb)
            );
        }
        let mapper = normal_loading_metadata.mapper;
        let vb_m = vb.pp("model");

        let embed_tokens = layers::embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
            &cfg.quantization_config,
        )?;
        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        let vb_l = vb_m.pp("layers");
        let head_dim = cfg.hidden_size / cfg.num_attention_heads;
        let mut ropes = HashMap::new();
        for layer_idx in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(RotaryEmbedding::new(
                    cfg.rope_theta as f32,
                    head_dim,
                    cfg.max_position_embeddings,
                    device,
                    is_gptx,
                    vb_m.dtype(),
                )?),
            );
        }
        let mut count = 0;
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        ) {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            layers.push(DecoderLayer::new(
                rotary_emb.clone(),
                cfg,
                vb_l.pp(layer_idx),
                lora_config,
                &mut count,
                &xlora_ordering,
                &*mapper,
                layer_idx,
                normal_loading_metadata.loading_isq,
                preload_adapters,
            )?)
        }
        if xlora_config.is_none() && preload_adapters.is_none() {
            info!("Merging LoRA adapters.");
            for layer in layers.iter_mut().tqdm() {
                Arc::get_mut(&mut layer.self_attn.k_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.self_attn.o_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.self_attn.q_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.self_attn.v_proj)
                    .unwrap()
                    .merge_weights()?;

                Arc::get_mut(&mut layer.mlp.c_fc).unwrap().merge_weights()?;
                Arc::get_mut(&mut layer.mlp.c_proj)
                    .unwrap()
                    .merge_weights()?;
            }
        }
        let norm = layer_norm(
            cfg.hidden_size,
            cfg.norm_epsilon,
            mapper.set_nm_device(vb_m.pp("norm"), false),
        )?;
        let lm_head = linear_no_bias(
            embed_tokens.embeddings().dim(1)?,
            embed_tokens.embeddings().dim(0)?,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), normal_loading_metadata.loading_isq),
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
            lora_config,
            &mut count,
            &xlora_ordering,
            preload_adapters,
        )?;
        if xlora_config.is_some() && lm_head.is_lora() {
            candle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA.");
        }
        Ok(Self {
            embed_tokens,
            layers,
            norm,
            lm_head,
            sliding_window: cfg.sliding_window,
            device: normal_loading_metadata.real_device,
            cache: EitherCache::Full(Cache::new(cfg.num_hidden_layers, true)),
            max_seq_len: cfg.max_position_embeddings,
            mapper,
            dtype: vb.dtype(),
            xlora_classifier: xlora_config.map(|xlora_config| {
                XLoraClassifier::new(xlora_config, count, lora_config.len(), vb, false).unwrap()
            }),
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_kv_heads: cfg.num_key_value_heads,
                num_attn_heads: cfg.num_attention_heads,
                sliding_window: cfg.sliding_window,
                k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
                v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            },
        })
    }

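    /// Run the decoder stack: embed the tokens, build the (sliding-window) causal mask, and
    /// apply each layer with the provided X-LoRA scalings, returning the normed hidden states.
    /// `is_full_pass` selects the X-LoRA KV cache, which is cleared when `no_kv_cache` is set.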
    #[allow(clippy::too_many_arguments)]
    fn inner_forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Option<Tensor>,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut xs = self.embed_tokens.forward(input_ids)?;

        let mut cache = if is_full_pass {
            if no_kv_cache {
                let mut new_cache = Vec::new();
                for _ in 0..self.cache.full().xlora_lock().len() {
                    new_cache.push(None);
                }

                self.cache.full().xlora_lock().clone_from(&new_cache);
            }
            self.cache.full().xlora_lock()
        } else {
            self.cache.full().lock()
        };
        let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
            input_ids,
            &*cache,
            self.sliding_window,
            xs.dtype(),
            self.cfg.num_attn_heads,
        )?;

        for (i, layer) in self.layers.iter().enumerate() {
            xs = self.mapper.map(xs, i)?;
            xs = layer.forward(
                &xs,
                attention_mask
                    .as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                &mut cache[i],
                scalings.clone(),
                self.xlora_classifier
                    .as_ref()
                    .map(|classifier| classifier.get_global_scaling_weight())
                    .unwrap_or(1.0),
                is_scaling_pass,
                flash_params,
            )?
        }
        let xs = xs.to_device(&self.device)?;
        xs.apply(&self.norm)
    }

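    /// X-LoRA forward pass: when a classifier is present, first compute adapter scalings via
    /// `get_scalings`, then run a full pass with those scalings and project the hidden states
    /// through the LM head. Without a classifier this is a plain LoRA forward.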
    #[allow(clippy::too_many_arguments)]
    pub fn forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        if self.xlora_classifier.is_some() {
            let scalings = self.get_scalings(
                input_ids,
                input_ids_full,
                seqlen_offsets,
                seqlen_offsets_full,
                no_kv_cache,
                non_granular_state,
                &vec![usize::MAX; context_lens.len()],
                flash_params,
                flash_params_full,
            )?;

            if no_kv_cache {
                let mut res = self
                    .inner_forward(
                        input_ids_full,
                        seqlen_offsets_full,
                        Some(scalings),
                        true,
                        no_kv_cache,
                        None,
                        flash_params_full,
                    )?
                    .contiguous()?;
                if let Some(t) = self.lm_head.quantized_act_type() {
                    res = res.to_dtype(t)?;
                }
                extract_logits(
                    &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                    context_lens,
                )
            } else {
                let mut res = self
                    .inner_forward(
                        input_ids,
                        seqlen_offsets,
                        Some(scalings),
                        true,
                        no_kv_cache,
                        None,
                        flash_params,
                    )?
                    .contiguous()?;
                if let Some(t) = self.lm_head.quantized_act_type() {
                    res = res.to_dtype(t)?;
                }
                extract_logits(
                    &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                    context_lens,
                )
            }
        } else {
            let mut res = self
                .inner_forward(
                    input_ids,
                    seqlen_offsets,
                    None,
                    false,
                    no_kv_cache,
                    None,
                    flash_params,
                )?
                .contiguous()?;
            if let Some(t) = self.lm_head.quantized_act_type() {
                res = res.to_dtype(t)?;
            }
            extract_logits(
                &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                context_lens,
            )
        }
    }
}

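// ISQ support: expose the quantizable linear layers (the LM head with no layer index, plus
// each layer's attention and MLP projections). UQFF generation is unsupported for adapter
// models.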
impl IsqModel for Model {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((Arc::get_mut(&mut self.lm_head).unwrap().quant_inner(), None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.q_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.k_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.v_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.o_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.mlp.c_fc).unwrap().quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.mlp.c_proj).unwrap().quant_inner(),
                Some(i),
            ));
        }
        (tensors, &*self.mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        panic!("Cannot generate UQFF for an adapter model.")
    }
}

impl NormalModel for Model {
    fn forward(
        &self,
        _input_ids: &Tensor,
        _seqlen_offsets: &[usize],
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        _flash_params: &FlashParams,
    ) -> Result<Tensor> {
        unimplemented!()
    }
    fn xlora_forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<crate::xlora_models::NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        self.forward(
            input_ids,
            input_ids_full,
            seqlen_offsets,
            seqlen_offsets_full,
            no_kv_cache,
            non_granular_state,
            context_lens,
            flash_params,
            flash_params_full,
        )
    }
    fn cache(&self) -> &EitherCache {
        &self.cache
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        &mut self.cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn is_xlora(&self) -> bool {
        false
    }
    fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
}

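// Lets the X-LoRA classifier drive its scaling passes through `inner_forward`.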
impl ScalingsMaker for Model {
    fn dtype(&self) -> DType {
        self.dtype
    }
    fn get_cache(&self) -> &EitherCache {
        &self.cache
    }
    fn get_classifier(&self) -> &XLoraClassifier {
        self.xlora_classifier.as_ref().unwrap()
    }
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Tensor,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        _context_lens: &[usize],
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.inner_forward(
            input_ids,
            seqlen_offsets,
            Some(scalings),
            is_full_pass,
            no_kv_cache,
            is_scaling_pass,
            flash_params,
        )
    }
}

impl AnyMoeBaseModelMixin for Model {}