#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
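//! X-LoRA implementation of the Phi-2 architecture: every linear projection is a
//! LoRA-capable layer, and an optional [`XLoraClassifier`] produces per-token
//! adapter scalings that are threaded through the whole forward pass.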

use std::{collections::HashMap, sync::Arc};

use crate::{
    amoe::AnyMoeBaseModelMixin,
    attention::SdpaParams,
    layers::{Activation, RotaryEmbedding, Sdpa},
    lora::{linear, LinearLayerLike, LoraConfig, Ordering},
    paged_attention::ModelConfigMetadata,
    pipeline::{
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel, NormalLoadingMetadata,
    },
    utils::progress::NiceProgressBar,
};
use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{Embedding, LayerNorm};
use mistralrs_quant::{QuantMethod, ShardedVarBuilder};
use tqdm::Iter;
use tracing::info;

use crate::{
    device_map::DeviceMapper,
    layers::{embedding, layer_norm, CausalMasker},
    models::phi2::Config,
    pipeline::{extract_logits, NormalModel},
};

use super::{classifier::XLoraClassifier, Cache, NonGranularState, ScalingsMaker, XLoraConfig};

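/// Phi-2 feed-forward block (`fc1` -> activation -> `fc2`). Both projections are
/// LoRA-capable so X-LoRA scalings can be applied per adapter.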
#[derive(Clone)]
#[allow(clippy::upper_case_acronyms)]
struct MLP {
    fc1: Arc<dyn LinearLayerLike + Send + Sync>,
    fc2: Arc<dyn LinearLayerLike + Send + Sync>,
    act: Activation,
}

impl MLP {
    #[allow(clippy::too_many_arguments)]
    fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let fc1 = linear(
            cfg.hidden_size,
            cfg.intermediate_size,
            mapper.set_device(layer_idx, vb.pp("fc1"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("fc1"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let fc2 = linear(
            cfg.intermediate_size,
            cfg.hidden_size,
            mapper.set_device(layer_idx, vb.pp("fc2"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("fc2"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        Ok(Self {
            fc1,
            fc2,
            act: cfg.hidden_act,
        })
    }

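    /// Applies the MLP, casting to the quantized activation dtype when `fc1` is
    /// quantized and restoring the original dtype afterwards.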
    fn forward(
        &self,
        xs: &Tensor,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
    ) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.fc1.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut res = self.fc2.lora_forward(
            &self
                .fc1
                .lora_forward(
                    &xs,
                    scalings.clone(),
                    global_scaling_weight,
                    is_scaling_pass,
                )?
                .apply(&self.act)?,
            scalings,
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.fc1.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

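/// Phi-2 self-attention: LoRA-capable q/k/v and `dense` projections, optional
/// QK layer norm, partial rotary embeddings, and SDPA.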
struct Attention {
    q_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    k_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    v_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    dense: Arc<dyn LinearLayerLike + Send + Sync>,
    q_layernorm: Option<LayerNorm>,
    k_layernorm: Option<LayerNorm>,
    rotary_emb: Arc<RotaryEmbedding>,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    sdpa_params: SdpaParams,
}

impl Attention {
    #[allow(clippy::too_many_arguments)]
    fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        rope: Arc<RotaryEmbedding>,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads();
        let head_dim = cfg.head_dim();
        let q_proj = linear(
            cfg.hidden_size,
            num_heads * head_dim,
            mapper.set_device(layer_idx, vb.pp("q_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("q_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let k_proj = linear(
            cfg.hidden_size,
            num_kv_heads * head_dim,
            mapper.set_device(layer_idx, vb.pp("k_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("k_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let v_proj = linear(
            cfg.hidden_size,
            num_kv_heads * head_dim,
            mapper.set_device(layer_idx, vb.pp("v_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("v_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let dense = linear(
            num_heads * head_dim,
            cfg.hidden_size,
            mapper.set_device(layer_idx, vb.pp("dense"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("dense"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let (q_layernorm, k_layernorm) = if cfg.qk_layernorm {
            let q_layernorm = layer_norm(head_dim, cfg.layer_norm_eps, vb.pp("q_layernorm"))?;
            let k_layernorm = layer_norm(head_dim, cfg.layer_norm_eps, vb.pp("k_layernorm"))?;
            (Some(q_layernorm), Some(k_layernorm))
        } else {
            (None, None)
        };
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            dense,
            q_layernorm,
            k_layernorm,
            rotary_emb: rope,
            num_heads,
            num_kv_heads,
            head_dim,
            sdpa_params: SdpaParams {
                n_kv_groups: num_heads / num_kv_heads,
                use_flash_attn: cfg.use_flash_attn,
                softcap: None,
                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                sliding_window: None,
            },
        })
    }

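    /// Attention forward pass: LoRA-aware q/k/v projections, optional QK layer norm,
    /// RoPE, KV-cache update, SDPA, then the `dense` output projection.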
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut Option<(Tensor, Tensor)>,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_size, seq_len, _n_embd) = xs.dims3()?;
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.q_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut q = self.q_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let mut k = self.k_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let mut v = self.v_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.q_proj.quantized_act_type().is_some() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

        let q = match &self.q_layernorm {
            None => q,
            Some(ln) => q.apply(ln)?,
        };
        let k = match &self.k_layernorm {
            None => k,
            Some(ln) => k.apply(ln)?,
        };

        let (q, k, v) = if seq_len != 1 {
            let q = q
                .reshape((b_size, seq_len, self.num_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_size, seq_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_size, seq_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_size, self.num_heads, seq_len, self.head_dim))?;
            let k = k.reshape((b_size, self.num_kv_heads, seq_len, self.head_dim))?;
            let v = v.reshape((b_size, self.num_kv_heads, seq_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, k) = self.rotary_emb.forward(&q, &k, seqlen_offsets)?;

        let (k, v) = Cache::update_kv_cache(kv_cache, k, v, false)?;

        let attn_output =
            Sdpa.run_attention(&q, &k, &v, mask, Some(flash_params), &self.sdpa_params)?;

        let mut attn_output = attn_output
            .transpose(1, 2)?
            .reshape((b_size, seq_len, ()))?;
        if let Some(t) = self.q_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        let mut res = self.dense.lora_forward(
            &attn_output,
            scalings,
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.q_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

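/// A single Phi-2 decoder layer. Phi-2 uses a parallel residual: attention and MLP
/// both consume the same layer-normed input and their outputs are summed with the residual.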
struct DecoderLayer {
    self_attn: Attention,
    mlp: MLP,
    input_layernorm: LayerNorm,
}

impl DecoderLayer {
    #[allow(clippy::too_many_arguments)]
    fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        rope: Arc<RotaryEmbedding>,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let self_attn = Attention::new(
            cfg,
            vb.pp("self_attn"),
            lora_config,
            count,
            ord,
            mapper,
            layer_idx,
            loading_isq,
            rope,
            preload_adapters,
        )?;
        let mlp = MLP::new(
            cfg,
            vb.pp("mlp"),
            lora_config,
            count,
            ord,
            mapper,
            layer_idx,
            loading_isq,
            preload_adapters,
        )?;
        let input_layernorm = layer_norm(
            cfg.hidden_size,
            cfg.layer_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        Ok(Self {
            self_attn,
            mlp,
            input_layernorm,
        })
    }

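    /// Parallel-residual forward: `attn(ln(x)) + mlp(ln(x)) + x`.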
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut Option<(Tensor, Tensor)>,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = xs.apply(&self.input_layernorm)?;
        let attn_outputs = self.self_attn.forward(
            &xs,
            mask,
            seqlen_offsets,
            kv_cache,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
            flash_params,
        )?;
        let feed_forward_hidden_states =
            self.mlp
                .forward(&xs, scalings, global_scaling_weight, is_scaling_pass)?;
        attn_outputs + feed_forward_hidden_states + residual
    }
}

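/// X-LoRA Phi-2 model. The weight layout follows `models::phi2::Config`; when an
/// `XLoraClassifier` is present, its per-token scalings modulate every LoRA layer.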
pub struct Model {
    embed_tokens: Embedding,
    layers: Vec<DecoderLayer>,
    final_layernorm: LayerNorm,
    lm_head: Arc<dyn LinearLayerLike + Send + Sync>,
    cache: EitherCache,
    device: Device,
    max_seq_len: usize,
    xlora_classifier: Option<XLoraClassifier>,
    dtype: DType,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    cfg: ModelConfigMetadata,
}

impl Model {
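    /// Loads the model from a `ShardedVarBuilder`, wiring LoRA adapters into every
    /// linear layer. When neither an X-LoRA classifier nor preloaded adapters are
    /// requested, the adapters are merged into the base weights after loading.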
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        if let Some(ref quant_cfg) = &cfg.quantization_config {
            tracing::info!(
                "Using {} quantization: {}.",
                quant_cfg.name(),
                quant_cfg.get_bits_name(&vb)
            );
        }
        let mapper = normal_loading_metadata.mapper;
        let vb_m = vb.pp("model");

        let embed_tokens = embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
            &cfg.quantization_config,
        )?;
        let final_layernorm = layer_norm(
            cfg.hidden_size,
            cfg.layer_norm_eps,
            mapper.set_nm_device(vb_m.pp("final_layernorm"), false),
        )?;
        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        let vb_m = vb_m.pp("layers");
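        // Precompute one rotary embedding per device so that layers mapped to the same
        // device share it; Phi-2 uses partial RoPE (`partial_rotary_factor`).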
        let mut ropes = HashMap::new();
        for layer_idx in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(RotaryEmbedding::new_partial(
                    cfg.rope_theta,
                    (cfg.partial_rotary_factor * cfg.head_dim() as f64) as usize,
                    cfg.max_position_embeddings,
                    device,
                    is_gptx,
                    vb.dtype(),
                )?),
            );
        }
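        // Load the decoder layers. `count` tracks how many LoRA-wrapped layers have been
        // created; it is handed to the X-LoRA classifier below.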
        let mut count = 0;
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        ) {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            let layer = DecoderLayer::new(
                cfg,
                vb_m.pp(layer_idx),
                lora_config,
                &mut count,
                &xlora_ordering,
                &*mapper,
                layer_idx,
                normal_loading_metadata.loading_isq,
                rotary_emb,
                preload_adapters,
            )?;
            layers.push(layer)
        }
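        // Plain LoRA (no X-LoRA classifier, no preloaded adapters): fold the adapter
        // weights into the base weights so inference pays no adapter overhead.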
        if xlora_config.is_none() && preload_adapters.is_none() {
            info!("Merging LoRA adapters.");
            for layer in layers.iter_mut().tqdm() {
                Arc::get_mut(&mut layer.self_attn.k_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.self_attn.dense)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.self_attn.q_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.self_attn.v_proj)
                    .unwrap()
                    .merge_weights()?;

                Arc::get_mut(&mut layer.mlp.fc1).unwrap().merge_weights()?;
                Arc::get_mut(&mut layer.mlp.fc2).unwrap().merge_weights()?;
            }
        }
        let lm_head = linear(
            cfg.hidden_size,
            cfg.vocab_size,
            mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
            mapper.set_nm_device(vb.pp("lm_head"), false),
            lora_config,
            &mut count,
            &xlora_ordering,
            preload_adapters,
        )?;
        if xlora_config.is_some() && lm_head.is_lora() {
            candle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA.");
        }
        Ok(Self {
            embed_tokens,
            layers,
            final_layernorm,
            lm_head,
            cache: EitherCache::Full(Cache::new(cfg.num_hidden_layers, true)),
            device: normal_loading_metadata.real_device,
            max_seq_len: cfg.max_position_embeddings,
            dtype: vb.dtype(),
            xlora_classifier: xlora_config.map(|xlora_config| {
                XLoraClassifier::new(xlora_config, count, lora_config.len(), vb, false).unwrap()
            }),
            mapper,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_kv_heads: cfg.num_key_value_heads(),
                num_attn_heads: cfg.num_attention_heads,
                sliding_window: None,
                k_head_dim: cfg.head_dim(),
                v_head_dim: cfg.head_dim(),
            },
        })
    }

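    /// Shared pass over the decoder stack. `is_full_pass` selects the X-LoRA cache
    /// (cleared first when `no_kv_cache` is set) instead of the regular cache; any
    /// scalings are threaded through every LoRA layer.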
    #[allow(clippy::too_many_arguments)]
    fn inner_forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Option<Tensor>,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut xs = input_ids.apply(&self.embed_tokens)?;
        let mut cache = if is_full_pass {
            if no_kv_cache {
                let mut new_cache = Vec::new();
                for _ in 0..self.cache.full().xlora_lock().len() {
                    new_cache.push(None);
                }

                self.cache.full().xlora_lock().clone_from(&new_cache);
            }
            self.cache.full().xlora_lock()
        } else {
            self.cache.full().lock()
        };
        let mask = CausalMasker.make_causal_mask_matrix(
            input_ids,
            &*cache,
            xs.dtype(),
            self.cfg.num_attn_heads,
        )?;
        for (i, layer) in self.layers.iter().enumerate() {
            xs = self.mapper.map(xs, i)?;
            xs = layer.forward(
                &xs,
                mask.as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                &mut cache[i],
                scalings.clone(),
                self.xlora_classifier
                    .as_ref()
                    .map(|classifier| classifier.get_global_scaling_weight())
                    .unwrap_or(1.0),
                is_scaling_pass,
                flash_params,
            )?;
        }
        let xs = xs.to_device(&self.device)?;
        xs.apply(&self.final_layernorm)
    }

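    /// X-LoRA forward. With a classifier, a scaling pass first computes per-token
    /// adapter scalings and the model is then re-run with them; without one, this
    /// reduces to a single plain LoRA pass.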
    #[allow(clippy::too_many_arguments)]
    pub fn forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        if self.xlora_classifier.is_some() {
            let scalings = self.get_scalings(
                input_ids,
                input_ids_full,
                seqlen_offsets,
                seqlen_offsets_full,
                no_kv_cache,
                non_granular_state,
                &vec![usize::MAX; context_lens.len()],
                flash_params,
                flash_params_full,
            )?;

            if no_kv_cache {
                let mut res = self
                    .inner_forward(
                        input_ids_full,
                        seqlen_offsets_full,
                        Some(scalings),
                        true,
                        no_kv_cache,
                        None,
                        flash_params_full,
                    )?
                    .contiguous()?;
                if let Some(t) = self.lm_head.quantized_act_type() {
                    res = res.to_dtype(t)?;
                }
                extract_logits(
                    &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                    context_lens,
                )
            } else {
                let mut res = self
                    .inner_forward(
                        input_ids,
                        seqlen_offsets,
                        Some(scalings),
                        true,
                        no_kv_cache,
                        None,
                        flash_params,
                    )?
                    .contiguous()?;
                if let Some(t) = self.lm_head.quantized_act_type() {
                    res = res.to_dtype(t)?;
                }
                extract_logits(
                    &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                    context_lens,
                )
            }
        } else {
            let mut res = self
                .inner_forward(
                    input_ids,
                    seqlen_offsets,
                    None,
                    false,
                    no_kv_cache,
                    None,
                    flash_params,
                )?
                .contiguous()?;
            if let Some(t) = self.lm_head.quantized_act_type() {
                res = res.to_dtype(t)?;
            }
            extract_logits(
                &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                context_lens,
            )
        }
    }
}

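// In-situ quantization support: expose every quantizable linear (the lm_head plus each
// layer's attention and MLP projections) along with its layer index.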
impl IsqModel for Model {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((Arc::get_mut(&mut self.lm_head).unwrap().quant_inner(), None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.q_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.k_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.v_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.dense)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.mlp.fc1).unwrap().quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.mlp.fc2).unwrap().quant_inner(),
                Some(i),
            ));
        }
        (tensors, &*self.mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        panic!("Cannot generate UQFF for an adapter model.")
    }
}

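// `NormalModel` integration: the non-X-LoRA `forward` is never called for this model;
// inference goes through `xlora_forward`.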
impl NormalModel for Model {
    fn forward(
        &self,
        _input_ids: &Tensor,
        _seqlen_offsets: &[usize],
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        _flash_params: &FlashParams,
    ) -> Result<Tensor> {
        unreachable!()
    }
    fn xlora_forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<crate::xlora_models::NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        self.forward(
            input_ids,
            input_ids_full,
            seqlen_offsets,
            seqlen_offsets_full,
            no_kv_cache,
            non_granular_state,
            context_lens,
            flash_params,
            flash_params_full,
        )
    }
    fn cache(&self) -> &EitherCache {
        &self.cache
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        &mut self.cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn is_xlora(&self) -> bool {
        true
    }
    fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
}

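// `ScalingsMaker` lets the X-LoRA classifier drive scaling passes through `inner_forward`.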
impl ScalingsMaker for Model {
    fn dtype(&self) -> DType {
        self.dtype
    }
    fn get_cache(&self) -> &EitherCache {
        &self.cache
    }
    fn get_classifier(&self) -> &XLoraClassifier {
        self.xlora_classifier.as_ref().unwrap()
    }
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Tensor,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        _context_lens: &[usize],
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.inner_forward(
            input_ids,
            seqlen_offsets,
            Some(scalings),
            is_full_pass,
            no_kv_cache,
            is_scaling_pass,
            flash_params,
        )
    }
}

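// No AnyMoE-specific behavior; the default (empty) mixin implementation is used.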
impl AnyMoeBaseModelMixin for Model {}