#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use candle_core::{DType, Device, Module, Result, Tensor};
use candle_nn::LayerNorm;
use mistralrs_quant::{QuantMethod, ShardedVarBuilder};
use std::{collections::HashMap, sync::Arc};
use tqdm::Iter;
use tracing::info;

use crate::{
    amoe::AnyMoeBaseModelMixin,
    attention::SdpaParams,
    device_map::DeviceMapper,
    layers::{self, layer_norm, Activation, CausalMasker, RotaryEmbedding, Sdpa},
    lora::{linear_b, linear_no_bias, LinearLayerLike, LoraConfig},
    models::starcoder2::Config,
    paged_attention::ModelConfigMetadata,
    pipeline::{
        extract_logits,
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        Cache, EitherCache, IsqModel, NormalLoadingMetadata, NormalModel,
    },
    utils::progress::NiceProgressBar,
    Ordering,
};

use super::{classifier::XLoraClassifier, NonGranularState, ScalingsMaker, XLoraConfig};

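// Feed-forward block of a Starcoder2 decoder layer: `c_fc` expands hidden_size to
// intermediate_size, the activation is applied, and `c_proj` projects back. Both
// projections are LoRA-capable `LinearLayerLike` layers.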
#[derive(Clone)]
#[allow(clippy::upper_case_acronyms)]
struct MLP {
    c_fc: Arc<dyn LinearLayerLike + Send + Sync>,
    c_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    act: Activation,
}

impl MLP {
    #[allow(clippy::too_many_arguments)]
    fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let (h_size, i_size) = (cfg.hidden_size, cfg.intermediate_size);
        let c_fc = linear_b(
            h_size,
            i_size,
            cfg.use_bias,
            mapper.set_device(layer_idx, vb.pp("c_fc"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("c_fc"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let c_proj = linear_b(
            i_size,
            h_size,
            cfg.use_bias,
            mapper.set_device(layer_idx, vb.pp("c_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("c_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        Ok(Self {
            c_fc,
            c_proj,
            act: cfg.hidden_act,
        })
    }

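    // Applies c_fc -> activation -> c_proj with the X-LoRA scalings. If the projections
    // are quantized, activations are cast to the quantized activation dtype for the
    // matmuls and cast back to the original dtype afterwards.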
    fn forward(
        &self,
        xs: &Tensor,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
    ) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.c_fc.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut res = self.c_proj.lora_forward(
            &self
                .c_fc
                .lora_forward(
                    &xs,
                    scalings.clone(),
                    global_scaling_weight,
                    is_scaling_pass,
                )?
                .apply(&self.act)?,
            scalings,
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.c_fc.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

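// Self-attention for one decoder layer: LoRA-capable q/k/v/o projections, rotary
// position embeddings, and an optional sliding-window KV cache.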
struct Attention {
    q_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    k_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    v_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    o_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    hidden_size: usize,
    rotary_emb: Arc<RotaryEmbedding>,
    sliding_window: Option<usize>,
    sdpa_params: SdpaParams,
}

impl Attention {
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<RotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        let head_dim = hidden_sz / num_heads;
        let b = cfg.use_bias;
        let q_proj = linear_b(
            hidden_sz,
            num_heads * head_dim,
            b,
            mapper.set_device(layer_idx, vb.pp("q_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("q_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let k_proj = linear_b(
            hidden_sz,
            num_kv_heads * head_dim,
            b,
            mapper.set_device(layer_idx, vb.pp("k_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("k_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let v_proj = linear_b(
            hidden_sz,
            num_kv_heads * head_dim,
            b,
            mapper.set_device(layer_idx, vb.pp("v_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("v_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let o_proj = linear_b(
            num_heads * head_dim,
            hidden_sz,
            b,
            mapper.set_device(layer_idx, vb.pp("o_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("o_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            num_heads,
            num_kv_heads,
            head_dim,
            hidden_size: hidden_sz,
            rotary_emb,
            sliding_window: cfg.sliding_window,
            sdpa_params: SdpaParams {
                n_kv_groups: num_heads / num_kv_heads,
                softcap: None,
                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                sliding_window: cfg.sliding_window,
            },
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut Option<(Tensor, Tensor)>,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.q_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut q = self.q_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let mut k = self.k_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let mut v = self.v_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.q_proj.quantized_act_type().is_some() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

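        // Prefill (q_len > 1): reshape to (b, n_heads, seq, head_dim) by transposing the
        // sequence and head dims. Decode (q_len == 1): the transpose would be a no-op,
        // so reshape directly into the target layout.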
        let (q, k, v) = if q_len != 1 {
            let q = q
                .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, k) = self.rotary_emb.forward(&q, &k, seqlen_offsets)?;

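        // Append the new k/v to this layer's KV cache, respecting the sliding window if
        // one is configured, and get back the attention mask to use.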
        let (k, v, attn_mask) = Cache::update_kv_cache_sliding_window(
            kv_cache,
            k,
            v,
            attention_mask,
            self.sliding_window,
        )?;

        let mut attn_output = Sdpa.run_attention(
            &q,
            &k,
            &v,
            attn_mask.as_ref(),
            Some(flash_params),
            &self.sdpa_params,
        )?;

        if let Some(t) = self.q_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        let mut res = self.o_proj.lora_forward(
            &attn_output
                .transpose(1, 2)?
                .reshape((b_sz, q_len, self.hidden_size))?,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.q_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

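// A single Starcoder2 decoder layer: pre-norm self-attention followed by a pre-norm MLP,
// each wrapped in a residual connection.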
struct DecoderLayer {
    self_attn: Attention,
    mlp: MLP,
    input_layernorm: LayerNorm,
    post_attention_layernorm: LayerNorm,
}

impl DecoderLayer {
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<RotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let self_attn = Attention::new(
            rotary_emb,
            cfg,
            vb.pp("self_attn"),
            lora_config,
            count,
            ord,
            mapper,
            layer_idx,
            loading_isq,
            preload_adapters,
        )?;
        let mlp = MLP::new(
            cfg,
            vb.pp("mlp"),
            lora_config,
            count,
            ord,
            mapper,
            layer_idx,
            loading_isq,
            preload_adapters,
        )?;
        let input_layernorm = layer_norm(
            cfg.hidden_size,
            cfg.norm_epsilon,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let post_attention_layernorm = layer_norm(
            cfg.hidden_size,
            cfg.norm_epsilon,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        Ok(Self {
            self_attn,
            mlp,
            input_layernorm,
            post_attention_layernorm,
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut Option<(Tensor, Tensor)>,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = self.input_layernorm.forward(xs)?;
        let xs = self.self_attn.forward(
            &xs,
            attention_mask,
            seqlen_offsets,
            kv_cache,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
            flash_params,
        )?;
        let xs = (xs + residual)?;
        let residual = &xs;
        let xs = self.mlp.forward(
            &xs.apply(&self.post_attention_layernorm)?,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        residual + xs
    }
}

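// The full Starcoder2 model with X-LoRA support: token embeddings, the decoder stack,
// a final LayerNorm, and a LoRA-capable LM head. When `xlora_classifier` is present, it
// produces per-adapter scalings that are threaded through every LoRA layer.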
pub struct Model {
    embed_tokens: candle_nn::Embedding,
    layers: Vec<DecoderLayer>,
    norm: LayerNorm,
    lm_head: Arc<dyn LinearLayerLike + Send + Sync>,
    sliding_window: Option<usize>,
    device: Device,
    cache: EitherCache,
    max_seq_len: usize,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    xlora_classifier: Option<XLoraClassifier>,
    dtype: DType,
    cfg: ModelConfigMetadata,
}

impl Model {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        if let Some(ref quant_cfg) = &cfg.quantization_config {
            tracing::info!(
                "Using {} quantization: {}.",
                quant_cfg.name(),
                quant_cfg.get_bits_name(&vb)
            );
        }
        let mapper = normal_loading_metadata.mapper;
        let vb_m = vb.pp("model");

        let embed_tokens = layers::embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
            &cfg.quantization_config,
        )?;
        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        let vb_l = vb_m.pp("layers");
        let head_dim = cfg.hidden_size / cfg.num_attention_heads;
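        // Build one RotaryEmbedding per device so that all layers mapped to the same
        // device share a single set of cached cos/sin tables.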
        let mut ropes = HashMap::new();
        for layer_idx in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(RotaryEmbedding::new(
                    cfg.rope_theta as f32,
                    head_dim,
                    cfg.max_position_embeddings,
                    device,
                    is_gptx,
                    vb_m.dtype(),
                )?),
            );
        }
        let mut count = 0;
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        ) {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            layers.push(DecoderLayer::new(
                rotary_emb.clone(),
                cfg,
                vb_l.pp(layer_idx),
                lora_config,
                &mut count,
                &xlora_ordering,
                &*mapper,
                layer_idx,
                normal_loading_metadata.loading_isq,
                preload_adapters,
            )?)
        }
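        // Plain LoRA (no X-LoRA classifier and no preloaded adapters): fold the adapter
        // weights into the base projections up front so they need not be applied at runtime.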
        if xlora_config.is_none() && preload_adapters.is_none() {
            info!("Merging LoRA adapters.");
            for layer in layers.iter_mut().tqdm() {
                Arc::get_mut(&mut layer.self_attn.k_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.self_attn.o_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.self_attn.q_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.self_attn.v_proj)
                    .unwrap()
                    .merge_weights()?;

                Arc::get_mut(&mut layer.mlp.c_fc).unwrap().merge_weights()?;
                Arc::get_mut(&mut layer.mlp.c_proj)
                    .unwrap()
                    .merge_weights()?;
            }
        }
        let norm = layer_norm(
            cfg.hidden_size,
            cfg.norm_epsilon,
            mapper.set_nm_device(vb_m.pp("norm"), false),
        )?;
        let lm_head = linear_no_bias(
            embed_tokens.embeddings().dim(1)?,
            embed_tokens.embeddings().dim(0)?,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), normal_loading_metadata.loading_isq),
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
            lora_config,
            &mut count,
            &xlora_ordering,
            preload_adapters,
        )?;
        if xlora_config.is_some() && lm_head.is_lora() {
            candle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA.");
        }
        Ok(Self {
            embed_tokens,
            layers,
            norm,
            lm_head,
            sliding_window: cfg.sliding_window,
            device: normal_loading_metadata.real_device,
            cache: EitherCache::Full(Cache::new(cfg.num_hidden_layers, true)),
            max_seq_len: cfg.max_position_embeddings,
            mapper,
            dtype: vb.dtype(),
            xlora_classifier: xlora_config.map(|xlora_config| {
                XLoraClassifier::new(xlora_config, count, lora_config.len(), vb, false).unwrap()
            }),
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_kv_heads: cfg.num_key_value_heads,
                num_attn_heads: cfg.num_attention_heads,
                sliding_window: cfg.sliding_window,
                k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
                v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            },
        })
    }

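    // Embeds the tokens and runs the decoder stack plus the final LayerNorm. `scalings`
    // (from the X-LoRA classifier) and `is_scaling_pass` are threaded through every
    // LoRA-capable layer.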
    #[allow(clippy::too_many_arguments)]
    fn inner_forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Option<Tensor>,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut xs = self.embed_tokens.forward(input_ids)?;

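        // A full pass uses the dedicated X-LoRA KV cache (reset to empty entries first
        // when the KV cache is disabled); otherwise the regular cache is used.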
        let mut cache = if is_full_pass {
            if no_kv_cache {
                let mut new_cache = Vec::new();
                for _ in 0..self.cache.full().xlora_lock().len() {
                    new_cache.push(None);
                }

                self.cache.full().xlora_lock().clone_from(&new_cache);
            }
            self.cache.full().xlora_lock()
        } else {
            self.cache.full().lock()
        };
        let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
            input_ids,
            &*cache,
            self.sliding_window,
            xs.dtype(),
            self.cfg.num_attn_heads,
        )?;

        for (i, layer) in self.layers.iter().enumerate() {
            xs = self.mapper.map(xs, i)?;
            xs = layer.forward(
                &xs,
                attention_mask
                    .as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                &mut cache[i],
                scalings.clone(),
                self.xlora_classifier
                    .as_ref()
                    .map(|classifier| classifier.get_global_scaling_weight())
                    .unwrap_or(1.0),
                is_scaling_pass,
                flash_params,
            )?
        }
        let xs = xs.to_device(&self.device)?;
        xs.apply(&self.norm)
    }

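    // With an X-LoRA classifier, first compute the per-adapter scalings (via
    // `get_scalings`, which runs the scaling pass internally) and then do the full
    // forward pass with them; without a classifier this is a plain LoRA forward pass.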
    #[allow(clippy::too_many_arguments)]
    pub fn forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        if self.xlora_classifier.is_some() {
            let scalings = self.get_scalings(
                input_ids,
                input_ids_full,
                seqlen_offsets,
                seqlen_offsets_full,
                no_kv_cache,
                non_granular_state,
                &vec![usize::MAX; context_lens.len()],
                flash_params,
                flash_params_full,
            )?;

            if no_kv_cache {
                let mut res = self
                    .inner_forward(
                        input_ids_full,
                        seqlen_offsets_full,
                        Some(scalings),
                        true,
                        no_kv_cache,
                        None,
                        flash_params_full,
                    )?
                    .contiguous()?;
                if let Some(t) = self.lm_head.quantized_act_type() {
                    res = res.to_dtype(t)?;
                }
                extract_logits(
                    &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                    context_lens,
                )
            } else {
                let mut res = self
                    .inner_forward(
                        input_ids,
                        seqlen_offsets,
                        Some(scalings),
                        true,
                        no_kv_cache,
                        None,
                        flash_params,
                    )?
                    .contiguous()?;
                if let Some(t) = self.lm_head.quantized_act_type() {
                    res = res.to_dtype(t)?;
                }
                extract_logits(
                    &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                    context_lens,
                )
            }
        } else {
            let mut res = self
                .inner_forward(
                    input_ids,
                    seqlen_offsets,
                    None,
                    false,
                    no_kv_cache,
                    None,
                    flash_params,
                )?
                .contiguous()?;
            if let Some(t) = self.lm_head.quantized_act_type() {
                res = res.to_dtype(t)?;
            }
            extract_logits(
                &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                context_lens,
            )
        }
    }
}

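// ISQ support: expose every quantizable projection (the LM head plus each layer's
// attention and MLP projections) together with its layer index for device mapping.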
impl IsqModel for Model {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((Arc::get_mut(&mut self.lm_head).unwrap().quant_inner(), None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.q_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.k_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.v_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.o_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.mlp.c_fc).unwrap().quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.mlp.c_proj).unwrap().quant_inner(),
                Some(i),
            ));
        }
        (tensors, &*self.mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        panic!("Cannot generate UQFF for an adapter model.")
    }
}

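// NormalModel hooks: this model is driven through the X-LoRA path, so the plain
// `forward` is intentionally unimplemented and `xlora_forward` delegates to
// `Model::forward` above.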
impl NormalModel for Model {
    fn forward(
        &self,
        _input_ids: &Tensor,
        _seqlen_offsets: &[usize],
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        _flash_params: &FlashParams,
    ) -> Result<Tensor> {
        unimplemented!()
    }
    fn xlora_forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<crate::xlora_models::NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        self.forward(
            input_ids,
            input_ids_full,
            seqlen_offsets,
            seqlen_offsets_full,
            no_kv_cache,
            non_granular_state,
            context_lens,
            flash_params,
            flash_params_full,
        )
    }
    fn cache(&self) -> &EitherCache {
        &self.cache
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        &mut self.cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn is_xlora(&self) -> bool {
        false
    }
    fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
}

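// ScalingsMaker drives the X-LoRA scaling (dummy) pass: it reuses `inner_forward`, so
// the classifier sees the same decoder stack that the full pass uses.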
impl ScalingsMaker for Model {
    fn dtype(&self) -> DType {
        self.dtype
    }
    fn get_cache(&self) -> &EitherCache {
        &self.cache
    }
    fn get_classifier(&self) -> &XLoraClassifier {
        self.xlora_classifier.as_ref().unwrap()
    }
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Tensor,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        _context_lens: &[usize],
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.inner_forward(
            input_ids,
            seqlen_offsets,
            Some(scalings),
            is_full_pass,
            no_kv_cache,
            is_scaling_pass,
            flash_params,
        )
    }
}

impl AnyMoeBaseModelMixin for Model {}