#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use crate::{
    amoe::AnyMoeBaseModelMixin,
    attention::SdpaParams,
    layers::{self, Activation, Sdpa},
    lora::{linear_no_bias, LinearLayerLike, LoraConfig, Ordering},
    paged_attention::ModelConfigMetadata,
    pipeline::{
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel, NormalLoadingMetadata,
    },
    utils::progress::NiceProgressBar,
};
use candle_core::{DType, Device, Module, Result, Tensor, D};
use mistralrs_quant::{QuantMethod, ShardedVarBuilder};
use std::{collections::HashMap, sync::Arc};
use tqdm::Iter;
use tracing::info;

use crate::{
    device_map::DeviceMapper,
    layers::{CausalMasker, PhiRotaryEmbedding, RmsNorm},
    models::phi3::Config,
    pipeline::{extract_logits, NormalModel},
};

use crate::pipeline::Cache;

use super::{classifier::XLoraClassifier, NonGranularState, ScalingsMaker, XLoraConfig};

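/// Multi-head self-attention for one decoder layer. The fused `qkv_proj` and the
/// `o_proj` are adapter-aware linear layers (`LinearLayerLike`), so each forward
/// pass can apply per-adapter X-LoRA scalings.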
struct Attention {
    qkv_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    o_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    rotary_emb: Arc<PhiRotaryEmbedding>,
    sliding_window: Option<usize>,
    sdpa_params: SdpaParams,
}

impl Attention {
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<PhiRotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        let head_dim = cfg.head_dim();
        let op_size = num_heads * head_dim + 2 * num_kv_heads * head_dim;
        let qkv_proj = linear_no_bias(
            cfg.hidden_size,
            op_size,
            mapper.set_device(layer_idx, vb.pp("qkv_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("qkv_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let o_proj = linear_no_bias(
            num_heads * head_dim,
            cfg.hidden_size,
            mapper.set_device(layer_idx, vb.pp("o_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("o_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        Ok(Self {
            qkv_proj,
            o_proj,
            rotary_emb,
            num_heads,
            num_kv_heads,
            head_dim,
            sliding_window: cfg.sliding_window,
            sdpa_params: SdpaParams {
                n_kv_groups: num_heads / num_kv_heads,
                use_flash_attn: cfg.use_flash_attn,
                softcap: None,
                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                sliding_window: cfg.sliding_window,
            },
        })
    }

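    /// Runs one attention pass: split the fused QKV output into Q/K/V, apply rotary
    /// embeddings, update the (optionally sliding-window) KV cache, and compute SDPA
    /// before projecting back through `o_proj`. Activations are temporarily cast to
    /// the quantized activation dtype when the projections are quantized.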
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
        kv_cache: &mut Option<(Tensor, Tensor)>,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.qkv_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut qkv = self.qkv_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.qkv_proj.quantized_act_type().is_some() {
            qkv = qkv.to_dtype(original_dtype)?;
        }
        let query_pos = self.num_heads * self.head_dim;
        let q = qkv.narrow(D::Minus1, 0, query_pos)?;
        let k = qkv.narrow(D::Minus1, query_pos, self.num_kv_heads * self.head_dim)?;
        let v = qkv.narrow(
            D::Minus1,
            query_pos + self.num_kv_heads * self.head_dim,
            self.num_kv_heads * self.head_dim,
        )?;

        let (q, k, v) = if q_len != 1 {
            let q = q
                .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, k) = self
            .rotary_emb
            .forward(&q, &k, seqlen_offsets, position_ids)?;

        let (k, v, attn_mask) = Cache::update_kv_cache_sliding_window(
            kv_cache,
            k,
            v,
            attention_mask,
            self.sliding_window,
            true,
        )?;

        let mut attn_output = Sdpa.run_attention(
            &q,
            &k,
            &v,
            attn_mask.as_ref(),
            Some(flash_params),
            &self.sdpa_params,
        )?;

        if let Some(t) = self.qkv_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        let mut res = self.o_proj.lora_forward(
            &attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.qkv_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

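/// Gated feed-forward block. `gate_up_proj` produces the gate and up projections in a
/// single fused matmul; the activated gate half is multiplied with the up half before
/// `down_proj` maps back to the hidden size.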
#[derive(Clone)]
struct Mlp {
    gate_up_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    down_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    act_fn: Activation,
    i_size: usize,
}

impl Mlp {
    #[allow(clippy::too_many_arguments)]
    fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let hidden_size = cfg.hidden_size;
        let i_size = cfg.intermediate_size;
        let gate_up_proj = linear_no_bias(
            hidden_size,
            2 * i_size,
            mapper.set_device(layer_idx, vb.pp("gate_up_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("gate_up_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let down_proj = linear_no_bias(
            i_size,
            hidden_size,
            mapper.set_device(layer_idx, vb.pp("down_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("down_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        Ok(Self {
            gate_up_proj,
            down_proj,
            act_fn: cfg.hidden_act,
            i_size,
        })
    }

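    /// Applies the fused gate/up projection, multiplies the activated gate half with
    /// the up half, and projects back down, casting through the quantized activation
    /// dtype when the projections are quantized.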
    fn forward(
        &self,
        xs: &Tensor,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
    ) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.gate_up_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let up_states = self.gate_up_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let gate = up_states.narrow(D::Minus1, 0, self.i_size)?;
        let up_states = up_states.narrow(D::Minus1, self.i_size, self.i_size)?;
        let up_states = (up_states * gate.apply(&self.act_fn))?;
        let mut res = self.down_proj.lora_forward(
            &up_states,
            scalings,
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.gate_up_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

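/// One Phi-3 decoder layer: pre-norm attention followed by a pre-norm MLP, each with a
/// residual connection.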
struct DecoderLayer {
    self_attn: Attention,
    mlp: Mlp,
    input_layernorm: RmsNorm,
    post_attention_layernorm: RmsNorm,
}

impl DecoderLayer {
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<PhiRotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let self_attn = Attention::new(
            rotary_emb,
            cfg,
            vb.pp("self_attn"),
            lora_config,
            count,
            ord,
            mapper,
            layer_idx,
            loading_isq,
            preload_adapters,
        )?;
        let mlp = Mlp::new(
            cfg,
            vb.pp("mlp"),
            lora_config,
            count,
            ord,
            mapper,
            layer_idx,
            loading_isq,
            preload_adapters,
        )?;
        let input_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let post_attention_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        Ok(Self {
            self_attn,
            mlp,
            input_layernorm,
            post_attention_layernorm,
        })
    }

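    /// Pre-norm residual forward: `x + attn(norm(x))`, then `y + mlp(norm(y))`.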
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
        kv_cache: &mut Option<(Tensor, Tensor)>,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = self.input_layernorm.forward(xs)?;
        let xs = self.self_attn.forward(
            &xs,
            attention_mask,
            seqlen_offsets,
            position_ids,
            kv_cache,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
            flash_params,
        )?;
        let xs = (xs + residual)?;
        let residual = &xs;
        let xs = self.mlp.forward(
            &xs.apply(&self.post_attention_layernorm)?,
            scalings,
            global_scaling_weight,
            is_scaling_pass,
        );
        residual + xs
    }
}

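/// X-LoRA variant of the Phi-3 model: token embedding, a stack of decoder layers, a
/// final `RmsNorm`, and an adapter-aware `lm_head`, plus an optional `XLoraClassifier`
/// that produces per-token adapter scalings.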
pub struct Model {
    embed_tokens: candle_nn::Embedding,
    layers: Vec<DecoderLayer>,
    norm: RmsNorm,
    lm_head: Arc<dyn LinearLayerLike + Send + Sync>,
    dtype: DType,
    device: Device,
    cache: EitherCache,
    max_seq_len: usize,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    xlora_classifier: Option<XLoraClassifier>,
    sliding_window: Option<usize>,
    cfg: ModelConfigMetadata,
}

impl Model {
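    /// Loads the model from `vb`, wiring LoRA adapters into every projection. When
    /// neither an X-LoRA classifier nor preloaded adapters are in use, the adapter
    /// weights are merged into the base weights after loading.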
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        _is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        if let Some(ref quant_cfg) = &cfg.quantization_config {
            tracing::info!(
                "Using {} quantization: {}.",
                quant_cfg.name(),
                quant_cfg.get_bits_name(&vb)
            );
        }
        let mapper = normal_loading_metadata.mapper;
        let vb_m = vb.pp("model");

        let embed_tokens = layers::embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
            &cfg.quantization_config,
        )?;
        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        let vb_l = vb_m.pp("layers");
        let mut ropes = HashMap::new();
        for layer_idx in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(PhiRotaryEmbedding::new(vb.dtype(), cfg.clone(), device)?),
            );
        }
        let mut count = 0;
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        ) {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            let layer = DecoderLayer::new(
                rotary_emb.clone(),
                cfg,
                vb_l.pp(layer_idx),
                lora_config,
                &mut count,
                &xlora_ordering,
                &*mapper,
                layer_idx,
                normal_loading_metadata.loading_isq,
                preload_adapters,
            )?;
            layers.push(layer)
        }
        if xlora_config.is_none() && preload_adapters.is_none() {
            info!("Merging LoRA adapters.");
            for layer in layers.iter_mut().tqdm() {
                Arc::get_mut(&mut layer.self_attn.qkv_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.self_attn.o_proj)
                    .unwrap()
                    .merge_weights()?;

                Arc::get_mut(&mut layer.mlp.down_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.mlp.gate_up_proj)
                    .unwrap()
                    .merge_weights()?;
            }
        }
        let norm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb_m.pp("norm"), false),
        )?;
        let lm_head = linear_no_bias(
            cfg.hidden_size,
            cfg.vocab_size,
            mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
            mapper.set_nm_device(vb.pp("lm_head"), false),
            lora_config,
            &mut count,
            &xlora_ordering,
            preload_adapters,
        )?;
        if xlora_config.is_some() && lm_head.is_lora() {
            candle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA.");
        }
        Ok(Self {
            embed_tokens,
            layers,
            norm,
            lm_head,
            device: normal_loading_metadata.real_device,
            dtype: vb.dtype(),
            cache: EitherCache::Full(Cache::new(cfg.num_hidden_layers, true)),
            max_seq_len: cfg.max_position_embeddings,
            mapper,
            xlora_classifier: xlora_config.map(|xlora_config| {
                XLoraClassifier::new(xlora_config, count, lora_config.len(), vb, false).unwrap()
            }),
            sliding_window: cfg.sliding_window,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_kv_heads: cfg.num_key_value_heads,
                num_attn_heads: cfg.num_attention_heads,
                sliding_window: cfg.sliding_window,
                k_head_dim: cfg.head_dim(),
                v_head_dim: cfg.head_dim(),
            },
        })
    }

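    /// Shared forward pass over the decoder stack. `is_full_pass` selects the X-LoRA
    /// cache (cleared first when `no_kv_cache` is set) instead of the regular KV cache;
    /// the caller supplies the scalings, and the global scaling weight is taken from
    /// the classifier when one is present.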
    #[allow(clippy::too_many_arguments)]
    fn inner_forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
        scalings: Option<Tensor>,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut xs = self.embed_tokens.forward(input_ids)?;
        let mut cache = if is_full_pass {
            if no_kv_cache {
                let mut new_cache = Vec::new();
                for _ in 0..self.cache.full().xlora_lock().len() {
                    new_cache.push(None);
                }

                self.cache.full().xlora_lock().clone_from(&new_cache);
            }
            self.cache.full().xlora_lock()
        } else {
            self.cache.full().lock()
        };
        let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
            input_ids,
            &*cache,
            self.sliding_window,
            xs.dtype(),
            self.cfg.num_attn_heads,
        )?;

        for (i, layer) in self.layers.iter().enumerate() {
            xs = self.mapper.map(xs, i)?;
            xs = layer.forward(
                &xs,
                attention_mask
                    .as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                position_ids,
                &mut cache[i],
                scalings.clone(),
                self.xlora_classifier
                    .as_ref()
                    .map(|classifier| classifier.get_global_scaling_weight())
                    .unwrap_or(1.0),
                is_scaling_pass,
                flash_params,
            )?
        }
        let xs = xs.to_device(&self.device)?;
        xs.apply(&self.norm)
    }

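    /// Full X-LoRA forward: when a classifier is present, first obtain scalings via
    /// `get_scalings` (the scaling pass), then run the scaled pass and project through
    /// `lm_head`; without a classifier this degenerates to a single ordinary pass.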
    #[allow(clippy::too_many_arguments)]
    pub fn forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        if self.xlora_classifier.is_some() {
            let scalings = self.get_scalings(
                input_ids,
                input_ids_full,
                seqlen_offsets,
                seqlen_offsets_full,
                no_kv_cache,
                non_granular_state,
                &position_ids,
                flash_params,
                flash_params_full,
            )?;

            if no_kv_cache {
                let mut res = self
                    .inner_forward(
                        input_ids_full,
                        seqlen_offsets_full,
                        &position_ids,
                        Some(scalings),
                        true,
                        no_kv_cache,
                        None,
                        flash_params_full,
                    )?
                    .contiguous()?;
                if let Some(t) = self.lm_head.quantized_act_type() {
                    res = res.to_dtype(t)?;
                }
                extract_logits(
                    &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                    context_lens,
                )
            } else {
                let mut res = self
                    .inner_forward(
                        input_ids,
                        seqlen_offsets,
                        &position_ids,
                        Some(scalings),
                        true,
                        no_kv_cache,
                        None,
                        flash_params,
                    )?
                    .contiguous()?;
                if let Some(t) = self.lm_head.quantized_act_type() {
                    res = res.to_dtype(t)?;
                }
                extract_logits(
                    &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                    context_lens,
                )
            }
        } else {
            let mut res = self
                .inner_forward(
                    input_ids,
                    seqlen_offsets,
                    &position_ids,
                    None,
                    false,
                    no_kv_cache,
                    None,
                    flash_params,
                )?
                .contiguous()?;
            if let Some(t) = self.lm_head.quantized_act_type() {
                res = res.to_dtype(t)?;
            }
            extract_logits(
                &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                context_lens,
            )
        }
    }
}

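// ISQ support: expose every quantizable projection (lm_head, plus qkv/o and
// gate_up/down per layer) together with its layer index for device mapping.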
impl IsqModel for Model {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((Arc::get_mut(&mut self.lm_head).unwrap().quant_inner(), None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.qkv_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.o_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.mlp.gate_up_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.mlp.down_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
        }
        (tensors, &*self.mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        panic!("Cannot generate UQFF for an adapter model.")
    }
}

impl NormalModel for Model {
    fn forward(
        &self,
        _input_ids: &Tensor,
        _seqlen_offsets: &[usize],
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        _flash_params: &FlashParams,
    ) -> Result<Tensor> {
        unreachable!()
    }
    fn xlora_forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<crate::xlora_models::NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        self.forward(
            input_ids,
            input_ids_full,
            seqlen_offsets,
            seqlen_offsets_full,
            no_kv_cache,
            non_granular_state,
            context_lens,
            position_ids,
            flash_params,
            flash_params_full,
        )
    }
    fn cache(&self) -> &EitherCache {
        &self.cache
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        &mut self.cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn is_xlora(&self) -> bool {
        true
    }
    fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
}

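// Hook for the X-LoRA scaling pass: `get_scalings` (provided by the `ScalingsMaker`
// trait) calls this `forward`, which delegates to `inner_forward` with the candidate
// scalings to produce the hidden states fed to the classifier.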
impl ScalingsMaker for Model {
    fn dtype(&self) -> DType {
        self.dtype
    }
    fn get_cache(&self) -> &EitherCache {
        &self.cache
    }
    fn get_classifier(&self) -> &XLoraClassifier {
        self.xlora_classifier.as_ref().unwrap()
    }
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Tensor,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        context_lens: &[usize],
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.inner_forward(
            input_ids,
            seqlen_offsets,
            context_lens,
            Some(scalings),
            is_full_pass,
            no_kv_cache,
            is_scaling_pass,
            flash_params,
        )
    }
}

impl AnyMoeBaseModelMixin for Model {}