#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use crate::{
    amoe::AnyMoeBaseModelMixin,
    attention::SdpaParams,
    layers::{self, Activation, Sdpa},
    lora::{linear_no_bias, LinearLayerLike, LoraConfig, Ordering},
    paged_attention::ModelConfigMetadata,
    pipeline::{
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel, NormalLoadingMetadata,
    },
    utils::progress::NiceProgressBar,
};
use candle_core::{DType, Device, Module, Result, Tensor, D};
use mistralrs_quant::{QuantMethod, ShardedVarBuilder};
use std::{collections::HashMap, sync::Arc};
use tqdm::Iter;
use tracing::info;

use crate::{
    device_map::DeviceMapper,
    layers::{CausalMasker, PhiRotaryEmbedding, RmsNorm},
    models::phi3::Config,
    pipeline::{extract_logits, NormalModel},
};

use crate::pipeline::Cache;

use super::{classifier::XLoraClassifier, NonGranularState, ScalingsMaker, XLoraConfig};

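/// Multi-head self-attention with a fused QKV projection and a separate output
/// projection, both wrapped as LoRA-aware linears (`LinearLayerLike`). Supports
/// grouped-query attention (`num_kv_heads <= num_heads`), Phi-style rotary
/// embeddings, and an optional sliding window for the KV cache.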
struct Attention {
    qkv_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    o_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    rotary_emb: Arc<PhiRotaryEmbedding>,
    sliding_window: Option<usize>,
    sdpa_params: SdpaParams,
}

impl Attention {
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<PhiRotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        let head_dim = cfg.head_dim();
        let op_size = num_heads * head_dim + 2 * num_kv_heads * head_dim;
        let qkv_proj = linear_no_bias(
            cfg.hidden_size,
            op_size,
            mapper.set_device(layer_idx, vb.pp("qkv_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("qkv_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let o_proj = linear_no_bias(
            num_heads * head_dim,
            cfg.hidden_size,
            mapper.set_device(layer_idx, vb.pp("o_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("o_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        Ok(Self {
            qkv_proj,
            o_proj,
            rotary_emb,
            num_heads,
            num_kv_heads,
            head_dim,
            sliding_window: cfg.sliding_window,
            sdpa_params: SdpaParams {
                n_kv_groups: num_heads / num_kv_heads,
                softcap: None,
                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                sliding_window: cfg.sliding_window,
            },
        })
    }

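    /// Run attention for one prefill chunk or decoding step. The fused QKV
    /// output is split with `narrow`, reshaped to `(b, heads, seq, head_dim)`,
    /// rotated with the Phi rotary embedding, then attended over the
    /// (optionally sliding-window) KV cache via SDPA. LoRA scalings and the
    /// global scaling weight are threaded through both projections, and
    /// activations are temporarily cast when a quantized projection requires a
    /// specific activation dtype.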
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
        kv_cache: &mut Option<(Tensor, Tensor)>,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.qkv_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut qkv = self.qkv_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.qkv_proj.quantized_act_type().is_some() {
            qkv = qkv.to_dtype(original_dtype)?;
        }
        let query_pos = self.num_heads * self.head_dim;
        let q = qkv.narrow(D::Minus1, 0, query_pos)?;
        let k = qkv.narrow(D::Minus1, query_pos, self.num_kv_heads * self.head_dim)?;
        let v = qkv.narrow(
            D::Minus1,
            query_pos + self.num_kv_heads * self.head_dim,
            self.num_kv_heads * self.head_dim,
        )?;

        let (q, k, v) = if q_len != 1 {
            let q = q
                .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, k) = self
            .rotary_emb
            .forward(&q, &k, seqlen_offsets, position_ids)?;

        let (k, v, attn_mask) = Cache::update_kv_cache_sliding_window(
            kv_cache,
            k,
            v,
            attention_mask,
            self.sliding_window,
            true,
        )?;

        let mut attn_output = Sdpa.run_attention(
            &q,
            &k,
            &v,
            attn_mask.as_ref(),
            Some(flash_params),
            &self.sdpa_params,
        )?;

        if let Some(t) = self.qkv_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        let mut res = self.o_proj.lora_forward(
            &attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.qkv_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

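/// Phi-3 style gated MLP. A fused `gate_up_proj` produces both the gate and up
/// branches in one matmul; the gate is passed through the activation and
/// multiplied into the up branch before `down_proj` maps back to the hidden
/// size. Both projections are LoRA-aware.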
#[derive(Clone)]
struct Mlp {
    gate_up_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    down_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    act_fn: Activation,
    i_size: usize,
}

impl Mlp {
    #[allow(clippy::too_many_arguments)]
    fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let hidden_size = cfg.hidden_size;
        let i_size = cfg.intermediate_size;
        let gate_up_proj = linear_no_bias(
            hidden_size,
            2 * i_size,
            mapper.set_device(layer_idx, vb.pp("gate_up_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("gate_up_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let down_proj = linear_no_bias(
            i_size,
            hidden_size,
            mapper.set_device(layer_idx, vb.pp("down_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("down_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        Ok(Self {
            gate_up_proj,
            down_proj,
            act_fn: cfg.hidden_act,
            i_size,
        })
    }

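    /// Apply the gated MLP, splitting the fused `gate_up_proj` output into the
    /// gate (first `i_size` features) and up (last `i_size` features) halves.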
    fn forward(
        &self,
        xs: &Tensor,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
    ) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.gate_up_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let up_states = self.gate_up_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let gate = up_states.narrow(D::Minus1, 0, self.i_size)?;
        let up_states = up_states.narrow(D::Minus1, self.i_size, self.i_size)?;
        let up_states = (up_states * gate.apply(&self.act_fn))?;
        let mut res = self.down_proj.lora_forward(
            &up_states,
            scalings,
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.gate_up_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

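/// A single pre-norm transformer block: RMSNorm, attention, residual add,
/// followed by RMSNorm, MLP, residual add.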
struct DecoderLayer {
    self_attn: Attention,
    mlp: Mlp,
    input_layernorm: RmsNorm,
    post_attention_layernorm: RmsNorm,
}

impl DecoderLayer {
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<PhiRotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let self_attn = Attention::new(
            rotary_emb,
            cfg,
            vb.pp("self_attn"),
            lora_config,
            count,
            ord,
            mapper,
            layer_idx,
            loading_isq,
            preload_adapters,
        )?;
        let mlp = Mlp::new(
            cfg,
            vb.pp("mlp"),
            lora_config,
            count,
            ord,
            mapper,
            layer_idx,
            loading_isq,
            preload_adapters,
        )?;
        let input_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let post_attention_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        Ok(Self {
            self_attn,
            mlp,
            input_layernorm,
            post_attention_layernorm,
        })
    }

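    /// Forward one decoder block, threading the LoRA scalings and this layer's
    /// KV cache entry down to the attention and MLP sublayers.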
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
        kv_cache: &mut Option<(Tensor, Tensor)>,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = self.input_layernorm.forward(xs)?;
        let xs = self.self_attn.forward(
            &xs,
            attention_mask,
            seqlen_offsets,
            position_ids,
            kv_cache,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
            flash_params,
        )?;
        let xs = (xs + residual)?;
        let residual = &xs;
        let xs = self.mlp.forward(
            &xs.apply(&self.post_attention_layernorm)?,
            scalings,
            global_scaling_weight,
            is_scaling_pass,
        );
        residual + xs
    }
}

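/// X-LoRA variant of the Phi-3 language model: token embeddings, a stack of
/// decoder layers, a final RMSNorm, and a LoRA-aware `lm_head`, plus an
/// optional `XLoraClassifier` that produces the adapter scalings.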
pub struct Model {
    embed_tokens: candle_nn::Embedding,
    layers: Vec<DecoderLayer>,
    norm: RmsNorm,
    lm_head: Arc<dyn LinearLayerLike + Send + Sync>,
    dtype: DType,
    device: Device,
    cache: EitherCache,
    max_seq_len: usize,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    xlora_classifier: Option<XLoraClassifier>,
    sliding_window: Option<usize>,
    cfg: ModelConfigMetadata,
}

impl Model {
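    /// Load the model from a `ShardedVarBuilder`, wiring LoRA adapters into
    /// every projection. When there is no X-LoRA config and no preloaded
    /// adapters, the LoRA weights are merged into the base weights up front.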
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        _is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        if let Some(ref quant_cfg) = &cfg.quantization_config {
            tracing::info!(
                "Using {} quantization: {}.",
                quant_cfg.name(),
                quant_cfg.get_bits_name(&vb)
            );
        }
        let mapper = normal_loading_metadata.mapper;
        let vb_m = vb.pp("model");

        let embed_tokens = layers::embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
            &cfg.quantization_config,
        )?;
        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        let vb_l = vb_m.pp("layers");
        let mut ropes = HashMap::new();
        for layer_idx in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(PhiRotaryEmbedding::new(vb.dtype(), cfg.clone(), device)?),
            );
        }
        let mut count = 0;
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        ) {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            let layer = DecoderLayer::new(
                rotary_emb.clone(),
                cfg,
                vb_l.pp(layer_idx),
                lora_config,
                &mut count,
                &xlora_ordering,
                &*mapper,
                layer_idx,
                normal_loading_metadata.loading_isq,
                preload_adapters,
            )?;
            layers.push(layer)
        }
        if xlora_config.is_none() && preload_adapters.is_none() {
            info!("Merging LoRA adapters.");
            for layer in layers.iter_mut().tqdm() {
                Arc::get_mut(&mut layer.self_attn.qkv_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.self_attn.o_proj)
                    .unwrap()
                    .merge_weights()?;

                Arc::get_mut(&mut layer.mlp.down_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.mlp.gate_up_proj)
                    .unwrap()
                    .merge_weights()?;
            }
        }
        let norm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb_m.pp("norm"), false),
        )?;
        let lm_head = linear_no_bias(
            cfg.hidden_size,
            cfg.vocab_size,
            mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
            mapper.set_nm_device(vb.pp("lm_head"), false),
            lora_config,
            &mut count,
            &xlora_ordering,
            preload_adapters,
        )?;
        if xlora_config.is_some() && lm_head.is_lora() {
            candle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA.");
        }
        Ok(Self {
            embed_tokens,
            layers,
            norm,
            lm_head,
            device: normal_loading_metadata.real_device,
            dtype: vb.dtype(),
            cache: EitherCache::Full(Cache::new(cfg.num_hidden_layers, true)),
            max_seq_len: cfg.max_position_embeddings,
            mapper,
            xlora_classifier: xlora_config.map(|xlora_config| {
                XLoraClassifier::new(xlora_config, count, lora_config.len(), vb, false).unwrap()
            }),
            sliding_window: cfg.sliding_window,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_kv_heads: cfg.num_key_value_heads,
                num_attn_heads: cfg.num_attention_heads,
                sliding_window: cfg.sliding_window,
                k_head_dim: cfg.head_dim(),
                v_head_dim: cfg.head_dim(),
            },
        })
    }

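    /// Shared decoder pass used by both the scaling pass and the real forward
    /// pass. `is_full_pass` selects the X-LoRA cache (cleared first when
    /// `no_kv_cache` is set) instead of the regular KV cache.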
    #[allow(clippy::too_many_arguments)]
    fn inner_forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
        scalings: Option<Tensor>,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut xs = self.embed_tokens.forward(input_ids)?;
        let mut cache = if is_full_pass {
            if no_kv_cache {
                let mut new_cache = Vec::new();
                for _ in 0..self.cache.full().xlora_lock().len() {
                    new_cache.push(None);
                }

                self.cache.full().xlora_lock().clone_from(&new_cache);
            }
            self.cache.full().xlora_lock()
        } else {
            self.cache.full().lock()
        };
        let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
            input_ids,
            &*cache,
            self.sliding_window,
            xs.dtype(),
            self.cfg.num_attn_heads,
        )?;

        for (i, layer) in self.layers.iter().enumerate() {
            xs = self.mapper.map(xs, i)?;
            xs = layer.forward(
                &xs,
                attention_mask
                    .as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                position_ids,
                &mut cache[i],
                scalings.clone(),
                self.xlora_classifier
                    .as_ref()
                    .map(|classifier| classifier.get_global_scaling_weight())
                    .unwrap_or(1.0),
                is_scaling_pass,
                flash_params,
            )?
        }
        let xs = xs.to_device(&self.device)?;
        xs.apply(&self.norm)
    }

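    /// Full X-LoRA forward. With a classifier present this is a two-pass
    /// scheme: `get_scalings` first runs a scaling pass to obtain the adapter
    /// scalings, then the decoder is re-run with those scalings before the
    /// `lm_head` projection. Without a classifier it degenerates to a plain
    /// (LoRA-only) forward pass.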
    #[allow(clippy::too_many_arguments)]
    pub fn forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        if self.xlora_classifier.is_some() {
            let scalings = self.get_scalings(
                input_ids,
                input_ids_full,
                seqlen_offsets,
                seqlen_offsets_full,
                no_kv_cache,
                non_granular_state,
                &position_ids,
                flash_params,
                flash_params_full,
            )?;

            if no_kv_cache {
                let mut res = self
                    .inner_forward(
                        input_ids_full,
                        seqlen_offsets_full,
                        &position_ids,
                        Some(scalings),
                        true,
                        no_kv_cache,
                        None,
                        flash_params_full,
                    )?
                    .contiguous()?;
                if let Some(t) = self.lm_head.quantized_act_type() {
                    res = res.to_dtype(t)?;
                }
                extract_logits(
                    &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                    context_lens,
                )
            } else {
                let mut res = self
                    .inner_forward(
                        input_ids,
                        seqlen_offsets,
                        &position_ids,
                        Some(scalings),
                        true,
                        no_kv_cache,
                        None,
                        flash_params,
                    )?
                    .contiguous()?;
                if let Some(t) = self.lm_head.quantized_act_type() {
                    res = res.to_dtype(t)?;
                }
                extract_logits(
                    &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                    context_lens,
                )
            }
        } else {
            let mut res = self
                .inner_forward(
                    input_ids,
                    seqlen_offsets,
                    &position_ids,
                    None,
                    false,
                    no_kv_cache,
                    None,
                    flash_params,
                )?
                .contiguous()?;
            if let Some(t) = self.lm_head.quantized_act_type() {
                res = res.to_dtype(t)?;
            }
            extract_logits(
                &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                context_lens,
            )
        }
    }
}

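/// Exposes the quantizable linear layers (the `lm_head` plus each layer's
/// attention and MLP projections, tagged with their layer index) so that
/// in-situ quantization (ISQ) can be applied.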
impl IsqModel for Model {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((Arc::get_mut(&mut self.lm_head).unwrap().quant_inner(), None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.qkv_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.o_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.mlp.gate_up_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.mlp.down_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
        }
        (tensors, &*self.mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        panic!("Cannot generate UQFF for an adapter model.")
    }
}

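/// `NormalModel` integration: the non-adapter `forward` entry point is never
/// called for an X-LoRA model (hence `unreachable!`), and `xlora_forward`
/// delegates to `Model::forward` above.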
impl NormalModel for Model {
    fn forward(
        &self,
        _input_ids: &Tensor,
        _seqlen_offsets: &[usize],
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        _flash_params: &FlashParams,
    ) -> Result<Tensor> {
        unreachable!()
    }
    fn xlora_forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<crate::xlora_models::NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        self.forward(
            input_ids,
            input_ids_full,
            seqlen_offsets,
            seqlen_offsets_full,
            no_kv_cache,
            non_granular_state,
            context_lens,
            position_ids,
            flash_params,
            flash_params_full,
        )
    }
    fn cache(&self) -> &EitherCache {
        &self.cache
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        &mut self.cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn is_xlora(&self) -> bool {
        true
    }
    fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
}

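/// `ScalingsMaker` lets the X-LoRA classifier drive the dedicated scaling pass
/// through `inner_forward` to compute the adapter scalings.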
impl ScalingsMaker for Model {
    fn dtype(&self) -> DType {
        self.dtype
    }
    fn get_cache(&self) -> &EitherCache {
        &self.cache
    }
    fn get_classifier(&self) -> &XLoraClassifier {
        self.xlora_classifier.as_ref().unwrap()
    }
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Tensor,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        context_lens: &[usize],
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.inner_forward(
            input_ids,
            seqlen_offsets,
            context_lens,
            Some(scalings),
            is_full_pass,
            no_kv_cache,
            is_scaling_pass,
            flash_params,
        )
    }
}

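// AnyMoE: no model-specific behavior is provided here; the trait's default
// implementations apply.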
impl AnyMoeBaseModelMixin for Model {}