#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

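//! X-LoRA (and plain LoRA) wrapper around the Phi-3 architecture: the decoder from
//! `models::phi3` with adapter-capable linear layers on every projection and an
//! optional `XLoraClassifier` that predicts per-token adapter scalings.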
use crate::{
    amoe::AnyMoeBaseModelMixin,
    attention::SdpaParams,
    layers::{self, Activation, Sdpa},
    lora::{linear_no_bias, LinearLayerLike, LoraConfig, Ordering},
    paged_attention::ModelConfigMetadata,
    pipeline::{
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel, NormalLoadingMetadata,
    },
    utils::progress::NiceProgressBar,
};
use candle_core::{DType, Device, Module, Result, Tensor, D};
use mistralrs_quant::{QuantMethod, ShardedVarBuilder};
use std::{collections::HashMap, sync::Arc};
use tqdm::Iter;
use tracing::info;

use crate::{
    device_map::DeviceMapper,
    layers::{CausalMasker, PhiRotaryEmbedding, RmsNorm},
    models::phi3::Config,
    pipeline::{extract_logits, NormalModel},
};

use crate::pipeline::Cache;

use super::{classifier::XLoraClassifier, NonGranularState, ScalingsMaker, XLoraConfig};

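/// Multi-head attention with a fused QKV projection and grouped-query attention
/// (`num_kv_heads` <= `num_heads`). The projection output is laid out as
/// `[q | k | v]` with widths `num_heads * head_dim`, `num_kv_heads * head_dim`
/// and `num_kv_heads * head_dim`, and is split with `narrow`. For example, with
/// hypothetical values `num_heads = 32`, `num_kv_heads = 8`, `head_dim = 96`,
/// the slices are `q = qkv[..3072]`, `k = qkv[3072..3840]`, `v = qkv[3840..4608]`.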
struct Attention {
    qkv_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    o_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    rotary_emb: Arc<PhiRotaryEmbedding>,
    sliding_window: Option<usize>,
    sdpa_params: SdpaParams,
}

impl Attention {
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<PhiRotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        let head_dim = cfg.head_dim();
        let op_size = num_heads * head_dim + 2 * num_kv_heads * head_dim;
        let qkv_proj = linear_no_bias(
            cfg.hidden_size,
            op_size,
            mapper.set_device(layer_idx, vb.pp("qkv_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("qkv_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let o_proj = linear_no_bias(
            num_heads * head_dim,
            cfg.hidden_size,
            mapper.set_device(layer_idx, vb.pp("o_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("o_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        Ok(Self {
            qkv_proj,
            o_proj,
            rotary_emb,
            num_heads,
            num_kv_heads,
            head_dim,
            sliding_window: cfg.sliding_window,
            sdpa_params: SdpaParams {
                n_kv_groups: num_heads / num_kv_heads,
                softcap: None,
                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                sliding_window: cfg.sliding_window,
            },
        })
    }

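    /// Attention forward pass: fused QKV projection, split, rotary embeddings,
    /// (sliding-window) KV-cache update, SDPA, then the output projection.
    /// X-LoRA scalings and the global scaling weight are threaded through every
    /// `lora_forward` call.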
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
        kv_cache: &mut Option<(Tensor, Tensor)>,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.qkv_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut qkv = self.qkv_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.qkv_proj.quantized_act_type().is_some() {
            qkv = qkv.to_dtype(original_dtype)?;
        }
        let query_pos = self.num_heads * self.head_dim;
        let q = qkv.narrow(D::Minus1, 0, query_pos)?;
        let k = qkv.narrow(D::Minus1, query_pos, self.num_kv_heads * self.head_dim)?;
        let v = qkv.narrow(
            D::Minus1,
            query_pos + self.num_kv_heads * self.head_dim,
            self.num_kv_heads * self.head_dim,
        )?;

        let (q, k, v) = if q_len != 1 {
            let q = q
                .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, k) = self
            .rotary_emb
            .forward(&q, &k, seqlen_offsets, position_ids)?;

        let (k, v, attn_mask) = Cache::update_kv_cache_sliding_window(
            kv_cache,
            k,
            v,
            attention_mask,
            self.sliding_window,
        )?;

        let mut attn_output = Sdpa.run_attention(
            &q,
            &k,
            &v,
            attn_mask.as_ref(),
            Some(flash_params),
            &self.sdpa_params,
        )?;

        if let Some(t) = self.qkv_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        let mut res = self.o_proj.lora_forward(
            &attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.qkv_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

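/// Gated MLP with a fused `gate_up_proj`: the first `i_size` outputs are the
/// gate, the next `i_size` are the up projection; the activated gate multiplies
/// the up states before `down_proj`.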
#[derive(Clone)]
struct Mlp {
    gate_up_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    down_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    act_fn: Activation,
    i_size: usize,
}

impl Mlp {
    #[allow(clippy::too_many_arguments)]
    fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let hidden_size = cfg.hidden_size;
        let i_size = cfg.intermediate_size;
        let gate_up_proj = linear_no_bias(
            hidden_size,
            2 * i_size,
            mapper.set_device(layer_idx, vb.pp("gate_up_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("gate_up_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let down_proj = linear_no_bias(
            i_size,
            hidden_size,
            mapper.set_device(layer_idx, vb.pp("down_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("down_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        Ok(Self {
            gate_up_proj,
            down_proj,
            act_fn: cfg.hidden_act,
            i_size,
        })
    }

    fn forward(
        &self,
        xs: &Tensor,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
    ) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.gate_up_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let up_states = self.gate_up_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let gate = up_states.narrow(D::Minus1, 0, self.i_size)?;
        let up_states = up_states.narrow(D::Minus1, self.i_size, self.i_size)?;
        let up_states = (up_states * gate.apply(&self.act_fn))?;
        let mut res = self.down_proj.lora_forward(
            &up_states,
            scalings,
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.gate_up_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

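/// A single pre-norm decoder block: input RMSNorm, self-attention, residual add;
/// then post-attention RMSNorm, gated MLP, residual add.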
struct DecoderLayer {
    self_attn: Attention,
    mlp: Mlp,
    input_layernorm: RmsNorm,
    post_attention_layernorm: RmsNorm,
}

impl DecoderLayer {
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<PhiRotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let self_attn = Attention::new(
            rotary_emb,
            cfg,
            vb.pp("self_attn"),
            lora_config,
            count,
            ord,
            mapper,
            layer_idx,
            loading_isq,
            preload_adapters,
        )?;
        let mlp = Mlp::new(
            cfg,
            vb.pp("mlp"),
            lora_config,
            count,
            ord,
            mapper,
            layer_idx,
            loading_isq,
            preload_adapters,
        )?;
        let input_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let post_attention_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        Ok(Self {
            self_attn,
            mlp,
            input_layernorm,
            post_attention_layernorm,
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
        kv_cache: &mut Option<(Tensor, Tensor)>,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = self.input_layernorm.forward(xs)?;
        let xs = self.self_attn.forward(
            &xs,
            attention_mask,
            seqlen_offsets,
            position_ids,
            kv_cache,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
            flash_params,
        )?;
        let xs = (xs + residual)?;
        let residual = &xs;
        let xs = self.mlp.forward(
            &xs.apply(&self.post_attention_layernorm)?,
            scalings,
            global_scaling_weight,
            is_scaling_pass,
        );
        residual + xs
    }
}

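/// X-LoRA Phi-3 model: token embeddings, a stack of `DecoderLayer`s, a final
/// RMSNorm and the LM head, plus an optional `XLoraClassifier` that produces
/// per-token adapter scalings.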
pub struct Model {
    embed_tokens: candle_nn::Embedding,
    layers: Vec<DecoderLayer>,
    norm: RmsNorm,
    lm_head: Arc<dyn LinearLayerLike + Send + Sync>,
    dtype: DType,
    device: Device,
    cache: EitherCache,
    max_seq_len: usize,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    xlora_classifier: Option<XLoraClassifier>,
    sliding_window: Option<usize>,
    cfg: ModelConfigMetadata,
}

impl Model {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        _is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        if let Some(ref quant_cfg) = &cfg.quantization_config {
            tracing::info!(
                "Using {} quantization: {}.",
                quant_cfg.name(),
                quant_cfg.get_bits_name(&vb)
            );
        }
        let mapper = normal_loading_metadata.mapper;
        let vb_m = vb.pp("model");

        let embed_tokens = layers::embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
            &cfg.quantization_config,
        )?;
        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        let vb_l = vb_m.pp("layers");
        let mut ropes = HashMap::new();
        for layer_idx in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(PhiRotaryEmbedding::new(vb.dtype(), cfg.clone(), device)?),
            );
        }
        let mut count = 0;
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        ) {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            let layer = DecoderLayer::new(
                rotary_emb.clone(),
                cfg,
                vb_l.pp(layer_idx),
                lora_config,
                &mut count,
                &xlora_ordering,
                &*mapper,
                layer_idx,
                normal_loading_metadata.loading_isq,
                preload_adapters,
            )?;
            layers.push(layer)
        }
        if xlora_config.is_none() && preload_adapters.is_none() {
            // This is a plain LoRA model, so merge the adapter weights into the base weights.
            info!("Merging LoRA adapters.");
            for layer in layers.iter_mut().tqdm() {
                Arc::get_mut(&mut layer.self_attn.qkv_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.self_attn.o_proj)
                    .unwrap()
                    .merge_weights()?;

                Arc::get_mut(&mut layer.mlp.down_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.mlp.gate_up_proj)
                    .unwrap()
                    .merge_weights()?;
            }
        }
        let norm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb_m.pp("norm"), false),
        )?;
        let lm_head = linear_no_bias(
            cfg.hidden_size,
            cfg.vocab_size,
            mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
            mapper.set_nm_device(vb.pp("lm_head"), false),
            lora_config,
            &mut count,
            &xlora_ordering,
            preload_adapters,
        )?;
        if xlora_config.is_some() && lm_head.is_lora() {
            candle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA.");
        }
        Ok(Self {
            embed_tokens,
            layers,
            norm,
            lm_head,
            device: normal_loading_metadata.real_device,
            dtype: vb.dtype(),
            cache: EitherCache::Full(Cache::new(cfg.num_hidden_layers, true)),
            max_seq_len: cfg.max_position_embeddings,
            mapper,
            xlora_classifier: xlora_config.map(|xlora_config| {
                XLoraClassifier::new(xlora_config, count, lora_config.len(), vb, false).unwrap()
            }),
            sliding_window: cfg.sliding_window,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_kv_heads: cfg.num_key_value_heads,
                num_attn_heads: cfg.num_attention_heads,
                sliding_window: cfg.sliding_window,
                k_head_dim: cfg.head_dim(),
                v_head_dim: cfg.head_dim(),
            },
        })
    }

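    /// Runs the decoder stack: embed the tokens, build the (optionally
    /// sliding-window) causal mask, apply every layer with the given X-LoRA
    /// scalings, then apply the final norm. `is_full_pass` selects the X-LoRA
    /// KV cache instead of the regular one; `is_scaling_pass` is forwarded to
    /// the LoRA layers for the scaling (classifier) pass.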
    #[allow(clippy::too_many_arguments)]
    fn inner_forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
        scalings: Option<Tensor>,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut xs = self.embed_tokens.forward(input_ids)?;
        let mut cache = if is_full_pass {
            if no_kv_cache {
                let mut new_cache = Vec::new();
                for _ in 0..self.cache.full().xlora_lock().len() {
                    new_cache.push(None);
                }

                self.cache.full().xlora_lock().clone_from(&new_cache);
            }
            self.cache.full().xlora_lock()
        } else {
            self.cache.full().lock()
        };
        let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
            input_ids,
            &*cache,
            self.sliding_window,
            xs.dtype(),
            self.cfg.num_attn_heads,
        )?;

        for (i, layer) in self.layers.iter().enumerate() {
            xs = self.mapper.map(xs, i)?;
            xs = layer.forward(
                &xs,
                attention_mask
                    .as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                position_ids,
                &mut cache[i],
                scalings.clone(),
                self.xlora_classifier
                    .as_ref()
                    .map(|classifier| classifier.get_global_scaling_weight())
                    .unwrap_or(1.0),
                is_scaling_pass,
                flash_params,
            )?
        }
        let xs = xs.to_device(&self.device)?;
        xs.apply(&self.norm)
    }

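    /// Entry point used by `NormalModel::xlora_forward`. With a classifier this
    /// follows the X-LoRA two-pass scheme: `get_scalings` first computes
    /// per-token adapter scalings, then the real pass reruns the model with
    /// those scalings; without a classifier it is a single plain pass with no
    /// scalings.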
    #[allow(clippy::too_many_arguments)]
    pub fn forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        if self.xlora_classifier.is_some() {
            let scalings = self.get_scalings(
                input_ids,
                input_ids_full,
                seqlen_offsets,
                seqlen_offsets_full,
                no_kv_cache,
                non_granular_state,
                &position_ids,
                flash_params,
                flash_params_full,
            )?;

            if no_kv_cache {
                let mut res = self
                    .inner_forward(
                        input_ids_full,
                        seqlen_offsets_full,
                        &position_ids,
                        Some(scalings),
                        true,
                        no_kv_cache,
                        None,
                        flash_params_full,
                    )?
                    .contiguous()?;
                if let Some(t) = self.lm_head.quantized_act_type() {
                    res = res.to_dtype(t)?;
                }
                extract_logits(
                    &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                    context_lens,
                )
            } else {
                // With a KV cache, the real pass runs over `input_ids` only and uses the
                // X-LoRA cache (is_full_pass = true).
                let mut res = self
                    .inner_forward(
                        input_ids,
                        seqlen_offsets,
                        &position_ids,
                        Some(scalings),
                        true,
                        no_kv_cache,
                        None,
                        flash_params,
                    )?
                    .contiguous()?;
                if let Some(t) = self.lm_head.quantized_act_type() {
                    res = res.to_dtype(t)?;
                }
                extract_logits(
                    &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                    context_lens,
                )
            }
        } else {
            let mut res = self
                .inner_forward(
                    input_ids,
                    seqlen_offsets,
                    &position_ids,
                    None,
                    false,
                    no_kv_cache,
                    None,
                    flash_params,
                )?
                .contiguous()?;
            if let Some(t) = self.lm_head.quantized_act_type() {
                res = res.to_dtype(t)?;
            }
            extract_logits(
                &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                context_lens,
            )
        }
    }
}

impl IsqModel for Model {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((Arc::get_mut(&mut self.lm_head).unwrap().quant_inner(), None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.qkv_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.o_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.mlp.gate_up_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.mlp.down_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
        }
        (tensors, &*self.mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        panic!("Cannot generate UQFF for an adapter model.")
    }
}

impl NormalModel for Model {
    fn forward(
        &self,
        _input_ids: &Tensor,
        _seqlen_offsets: &[usize],
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        _flash_params: &FlashParams,
    ) -> Result<Tensor> {
        unreachable!()
    }
    fn xlora_forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<crate::xlora_models::NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        self.forward(
            input_ids,
            input_ids_full,
            seqlen_offsets,
            seqlen_offsets_full,
            no_kv_cache,
            non_granular_state,
            context_lens,
            position_ids,
            flash_params,
            flash_params_full,
        )
    }
    fn cache(&self) -> &EitherCache {
        &self.cache
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        &mut self.cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn is_xlora(&self) -> bool {
        true
    }
    fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
}

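// `get_scalings`, as used in `Model::forward` above, is supplied by the
// `ScalingsMaker` trait; it relies on the `forward` and `get_classifier`
// implementations below to run the scaling pass.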
impl ScalingsMaker for Model {
    fn dtype(&self) -> DType {
        self.dtype
    }
    fn get_cache(&self) -> &EitherCache {
        &self.cache
    }
    fn get_classifier(&self) -> &XLoraClassifier {
        self.xlora_classifier.as_ref().unwrap()
    }
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Tensor,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        context_lens: &[usize],
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.inner_forward(
            input_ids,
            seqlen_offsets,
            context_lens,
            Some(scalings),
            is_full_pass,
            no_kv_cache,
            is_scaling_pass,
            flash_params,
        )
    }
}

impl AnyMoeBaseModelMixin for Model {}