#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use candle_core::{DType, Device, Module, Result, Tensor};
use candle_nn::LayerNorm;
use mistralrs_quant::{QuantMethod, ShardedVarBuilder};
use std::{collections::HashMap, sync::Arc};
use tqdm::Iter;
use tracing::info;

use crate::{
    amoe::AnyMoeBaseModelMixin,
    attention::SdpaParams,
    device_map::DeviceMapper,
    layers::{self, layer_norm, Activation, CausalMasker, RotaryEmbedding, Sdpa},
    lora::{linear_b, linear_no_bias, LinearLayerLike, LoraConfig},
    models::starcoder2::Config,
    paged_attention::ModelConfigMetadata,
    pipeline::{
        extract_logits,
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        Cache, EitherCache, IsqModel, NormalLoadingMetadata, NormalModel,
    },
    utils::progress::NiceProgressBar,
    Ordering,
};

use super::{classifier::XLoraClassifier, NonGranularState, ScalingsMaker, XLoraConfig};

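/// Feed-forward block (`c_fc` -> activation -> `c_proj`); both projections are
/// LoRA-wrapped so X-LoRA scalings can be applied during the forward pass.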
#[derive(Clone)]
#[allow(clippy::upper_case_acronyms)]
struct MLP {
    c_fc: Arc<dyn LinearLayerLike + Send + Sync>,
    c_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    act: Activation,
}

impl MLP {
    #[allow(clippy::too_many_arguments)]
    fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let (h_size, i_size) = (cfg.hidden_size, cfg.intermediate_size);
        let c_fc = linear_b(
            h_size,
            i_size,
            cfg.use_bias,
            mapper.set_device(layer_idx, vb.pp("c_fc"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("c_fc"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let c_proj = linear_b(
            i_size,
            h_size,
            cfg.use_bias,
            mapper.set_device(layer_idx, vb.pp("c_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("c_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        Ok(Self {
            c_fc,
            c_proj,
            act: cfg.hidden_act,
        })
    }

    fn forward(
        &self,
        xs: &Tensor,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
    ) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.c_fc.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut res = self.c_proj.lora_forward(
            &self
                .c_fc
                .lora_forward(
                    &xs,
                    scalings.clone(),
                    global_scaling_weight,
                    is_scaling_pass,
                )?
                .apply(&self.act)?,
            scalings,
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.c_fc.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

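/// Multi-head self-attention with grouped KV heads, rotary embeddings, and an optional
/// sliding window; all four projections are LoRA-wrapped for X-LoRA.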
struct Attention {
    q_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    k_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    v_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    o_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    hidden_size: usize,
    rotary_emb: Arc<RotaryEmbedding>,
    sliding_window: Option<usize>,
    sdpa_params: SdpaParams,
}

impl Attention {
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<RotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        let head_dim = hidden_sz / num_heads;
        let b = cfg.use_bias;
        let q_proj = linear_b(
            hidden_sz,
            num_heads * head_dim,
            b,
            mapper.set_device(layer_idx, vb.pp("q_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("q_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let k_proj = linear_b(
            hidden_sz,
            num_kv_heads * head_dim,
            b,
            mapper.set_device(layer_idx, vb.pp("k_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("k_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let v_proj = linear_b(
            hidden_sz,
            num_kv_heads * head_dim,
            b,
            mapper.set_device(layer_idx, vb.pp("v_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("v_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        // Note: load the output projection from the `o_proj` weights (the original
        // erroneously reused the `v_proj` var-builder path here).
        let o_proj = linear_b(
            num_heads * head_dim,
            hidden_sz,
            b,
            mapper.set_device(layer_idx, vb.pp("o_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("o_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            num_heads,
            num_kv_heads,
            head_dim,
            hidden_size: hidden_sz,
            rotary_emb,
            sliding_window: cfg.sliding_window,
            sdpa_params: SdpaParams {
                n_kv_groups: num_heads / num_kv_heads,
                softcap: None,
                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                sliding_window: cfg.sliding_window,
            },
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut Option<(Tensor, Tensor)>,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.q_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut q = self.q_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let mut k = self.k_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let mut v = self.v_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.q_proj.quantized_act_type().is_some() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

        let (q, k, v) = if q_len != 1 {
            let q = q
                .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, k) = self.rotary_emb.forward(&q, &k, seqlen_offsets)?;

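        // Append the new k/v to the cache with the sliding window applied and take the
        // (possibly adjusted) attention mask it returns.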
        let (k, v, attn_mask) = Cache::update_kv_cache_sliding_window(
            kv_cache,
            k,
            v,
            attention_mask,
            self.sliding_window,
            false,
        )?;

        let mut attn_output = Sdpa.run_attention(
            &q,
            &k,
            &v,
            attn_mask.as_ref(),
            Some(flash_params),
            &self.sdpa_params,
        )?;

        if let Some(t) = self.q_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        let mut res = self.o_proj.lora_forward(
            &attn_output
                .transpose(1, 2)?
                .reshape((b_sz, q_len, self.hidden_size))?,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.q_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

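/// A single transformer block: pre-norm self-attention followed by a pre-norm MLP,
/// each wrapped with a residual connection.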
struct DecoderLayer {
    self_attn: Attention,
    mlp: MLP,
    input_layernorm: LayerNorm,
    post_attention_layernorm: LayerNorm,
}

impl DecoderLayer {
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<RotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let self_attn = Attention::new(
            rotary_emb,
            cfg,
            vb.pp("self_attn"),
            lora_config,
            count,
            ord,
            mapper,
            layer_idx,
            loading_isq,
            preload_adapters,
        )?;
        let mlp = MLP::new(
            cfg,
            vb.pp("mlp"),
            lora_config,
            count,
            ord,
            mapper,
            layer_idx,
            loading_isq,
            preload_adapters,
        )?;
        let input_layernorm = layer_norm(
            cfg.hidden_size,
            cfg.norm_epsilon,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let post_attention_layernorm = layer_norm(
            cfg.hidden_size,
            cfg.norm_epsilon,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        Ok(Self {
            self_attn,
            mlp,
            input_layernorm,
            post_attention_layernorm,
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut Option<(Tensor, Tensor)>,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = self.input_layernorm.forward(xs)?;
        let xs = self.self_attn.forward(
            &xs,
            attention_mask,
            seqlen_offsets,
            kv_cache,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
            flash_params,
        )?;
        let xs = (xs + residual)?;
        let residual = &xs;
        let xs = self.mlp.forward(
            &xs.apply(&self.post_attention_layernorm)?,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        residual + xs
    }
}

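/// X-LoRA variant of the StarCoder2 model: token embeddings, a stack of decoder layers,
/// a final layer norm, and an LM head built from the embedding weights, plus an optional
/// X-LoRA classifier that produces per-pass adapter scalings.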
pub struct Model {
    embed_tokens: candle_nn::Embedding,
    layers: Vec<DecoderLayer>,
    norm: LayerNorm,
    lm_head: Arc<dyn LinearLayerLike + Send + Sync>,
    sliding_window: Option<usize>,
    device: Device,
    cache: EitherCache,
    max_seq_len: usize,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    xlora_classifier: Option<XLoraClassifier>,
    dtype: DType,
    cfg: ModelConfigMetadata,
}

impl Model {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        if let Some(ref quant_cfg) = &cfg.quantization_config {
            tracing::info!(
                "Using {} quantization: {}.",
                quant_cfg.name(),
                quant_cfg.get_bits_name(&vb)
            );
        }
        let mapper = normal_loading_metadata.mapper;
        let vb_m = vb.pp("model");

        let embed_tokens = layers::embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
            &cfg.quantization_config,
        )?;
        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        let vb_l = vb_m.pp("layers");
        let head_dim = cfg.hidden_size / cfg.num_attention_heads;
        let mut ropes = HashMap::new();
        for layer_idx in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(RotaryEmbedding::new(
                    cfg.rope_theta as f32,
                    head_dim,
                    cfg.max_position_embeddings,
                    device,
                    is_gptx,
                    vb_m.dtype(),
                )?),
            );
        }
        let mut count = 0;
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        ) {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            layers.push(DecoderLayer::new(
                rotary_emb.clone(),
                cfg,
                vb_l.pp(layer_idx),
                lora_config,
                &mut count,
                &xlora_ordering,
                &*mapper,
                layer_idx,
                normal_loading_metadata.loading_isq,
                preload_adapters,
            )?)
        }
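        // Without an X-LoRA classifier and without preloaded adapters, the LoRA deltas
        // can be folded into the base weights once, so inference needs no per-forward
        // adapter computation.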
        if xlora_config.is_none() && preload_adapters.is_none() {
            info!("Merging LoRA adapters.");
            for layer in layers.iter_mut().tqdm() {
                Arc::get_mut(&mut layer.self_attn.k_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.self_attn.o_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.self_attn.q_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.self_attn.v_proj)
                    .unwrap()
                    .merge_weights()?;

                Arc::get_mut(&mut layer.mlp.c_fc).unwrap().merge_weights()?;
                Arc::get_mut(&mut layer.mlp.c_proj)
                    .unwrap()
                    .merge_weights()?;
            }
        }
        let norm = layer_norm(
            cfg.hidden_size,
            cfg.norm_epsilon,
            mapper.set_nm_device(vb_m.pp("norm"), false),
        )?;
        let lm_head = linear_no_bias(
            embed_tokens.embeddings().dim(1)?,
            embed_tokens.embeddings().dim(0)?,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), normal_loading_metadata.loading_isq),
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
            lora_config,
            &mut count,
            &xlora_ordering,
            preload_adapters,
        )?;
        if xlora_config.is_some() && lm_head.is_lora() {
            candle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA.");
        }
        Ok(Self {
            embed_tokens,
            layers,
            norm,
            lm_head,
            sliding_window: cfg.sliding_window,
            device: normal_loading_metadata.real_device,
            cache: EitherCache::Full(Cache::new(cfg.num_hidden_layers, true)),
            max_seq_len: cfg.max_position_embeddings,
            mapper,
            dtype: vb.dtype(),
            xlora_classifier: xlora_config.map(|xlora_config| {
                XLoraClassifier::new(xlora_config, count, lora_config.len(), vb, false).unwrap()
            }),
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_kv_heads: cfg.num_key_value_heads,
                num_attn_heads: cfg.num_attention_heads,
                sliding_window: cfg.sliding_window,
                k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
                v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            },
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn inner_forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Option<Tensor>,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut xs = self.embed_tokens.forward(input_ids)?;

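        // A full pass uses the dedicated X-LoRA KV cache (cleared first when the KV cache
        // is disabled); other passes, such as the scaling pass, use the regular cache.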
        let mut cache = if is_full_pass {
            if no_kv_cache {
                let mut new_cache = Vec::new();
                for _ in 0..self.cache.full().xlora_lock().len() {
                    new_cache.push(None);
                }

                self.cache.full().xlora_lock().clone_from(&new_cache);
            }
            self.cache.full().xlora_lock()
        } else {
            self.cache.full().lock()
        };
        let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
            input_ids,
            &*cache,
            self.sliding_window,
            xs.dtype(),
            self.cfg.num_attn_heads,
        )?;

        for (i, layer) in self.layers.iter().enumerate() {
            xs = self.mapper.map(xs, i)?;
            xs = layer.forward(
                &xs,
                attention_mask
                    .as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                &mut cache[i],
                scalings.clone(),
                self.xlora_classifier
                    .as_ref()
                    .map(|classifier| classifier.get_global_scaling_weight())
                    .unwrap_or(1.0),
                is_scaling_pass,
                flash_params,
            )?
        }
        let xs = xs.to_device(&self.device)?;
        xs.apply(&self.norm)
    }

    #[allow(clippy::too_many_arguments)]
    pub fn forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        if self.xlora_classifier.is_some() {
            let scalings = self.get_scalings(
                input_ids,
                input_ids_full,
                seqlen_offsets,
                seqlen_offsets_full,
                no_kv_cache,
                non_granular_state,
                &vec![usize::MAX; context_lens.len()],
                flash_params,
                flash_params_full,
            )?;

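            // Second pass: rerun the decoder with the computed scalings, then project to
            // logits through the LM head.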
            if no_kv_cache {
                let mut res = self
                    .inner_forward(
                        input_ids_full,
                        seqlen_offsets_full,
                        Some(scalings),
                        true,
                        no_kv_cache,
                        None,
                        flash_params_full,
                    )?
                    .contiguous()?;
                if let Some(t) = self.lm_head.quantized_act_type() {
                    res = res.to_dtype(t)?;
                }
                extract_logits(
                    &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                    context_lens,
                )
            } else {
                let mut res = self
                    .inner_forward(
                        input_ids,
                        seqlen_offsets,
                        Some(scalings),
                        true,
                        no_kv_cache,
                        None,
                        flash_params,
                    )?
                    .contiguous()?;
                if let Some(t) = self.lm_head.quantized_act_type() {
                    res = res.to_dtype(t)?;
                }
                extract_logits(
                    &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                    context_lens,
                )
            }
        } else {
            let mut res = self
                .inner_forward(
                    input_ids,
                    seqlen_offsets,
                    None,
                    false,
                    no_kv_cache,
                    None,
                    flash_params,
                )?
                .contiguous()?;
            if let Some(t) = self.lm_head.quantized_act_type() {
                res = res.to_dtype(t)?;
            }
            extract_logits(
                &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                context_lens,
            )
        }
    }
}

impl IsqModel for Model {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((Arc::get_mut(&mut self.lm_head).unwrap().quant_inner(), None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.q_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.k_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.v_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.o_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.mlp.c_fc).unwrap().quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.mlp.c_proj).unwrap().quant_inner(),
                Some(i),
            ));
        }
        (tensors, &*self.mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        panic!("Cannot generate UQFF for an adapter model.")
    }
}

impl NormalModel for Model {
    fn forward(
        &self,
        _input_ids: &Tensor,
        _seqlen_offsets: &[usize],
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        _flash_params: &FlashParams,
    ) -> Result<Tensor> {
        unimplemented!()
    }
    fn xlora_forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<crate::xlora_models::NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        self.forward(
            input_ids,
            input_ids_full,
            seqlen_offsets,
            seqlen_offsets_full,
            no_kv_cache,
            non_granular_state,
            context_lens,
            flash_params,
            flash_params_full,
        )
    }
    fn cache(&self) -> &EitherCache {
        &self.cache
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        &mut self.cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn is_xlora(&self) -> bool {
        // This adapter model is driven through `xlora_forward`; the plain `forward`
        // above is unimplemented, so report `true` here.
        true
    }
    fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
}

impl ScalingsMaker for Model {
    fn dtype(&self) -> DType {
        self.dtype
    }
    fn get_cache(&self) -> &EitherCache {
        &self.cache
    }
    fn get_classifier(&self) -> &XLoraClassifier {
        self.xlora_classifier.as_ref().unwrap()
    }
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Tensor,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        _context_lens: &[usize],
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.inner_forward(
            input_ids,
            seqlen_offsets,
            Some(scalings),
            is_full_pass,
            no_kv_cache,
            is_scaling_pass,
            flash_params,
        )
    }
}

impl AnyMoeBaseModelMixin for Model {}