mistralrs_core/xlora_models/phi3.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

// This implementation is based on:
// https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/modeling_phi3.py
use crate::{
    amoe::AnyMoeBaseModelMixin,
    attention::SdpaParams,
    layers::{self, Activation, Sdpa},
    lora::{linear_no_bias, LinearLayerLike, LoraConfig, Ordering},
    paged_attention::ModelConfigMetadata,
    pipeline::{
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel, NormalLoadingMetadata,
    },
    utils::progress::NiceProgressBar,
};
use candle_core::{DType, Device, Module, Result, Tensor, D};
use mistralrs_quant::{QuantMethod, ShardedVarBuilder};
use std::{collections::HashMap, sync::Arc};
use tqdm::Iter;
use tracing::info;

use crate::{
    device_map::DeviceMapper,
    layers::{CausalMasker, PhiRotaryEmbedding, RmsNorm},
    models::phi3::Config,
    pipeline::{extract_logits, NormalModel},
};

use crate::pipeline::Cache;

use super::{classifier::XLoraClassifier, NonGranularState, ScalingsMaker, XLoraConfig};

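/// Self-attention block for the X-LoRA Phi-3 model.
///
/// The Q, K and V projections are fused into a single `qkv_proj` linear, and both
/// `qkv_proj` and `o_proj` are `LinearLayerLike` so they can carry LoRA adapters.
/// Grouped-query attention is expressed through `num_kv_heads`, and `sliding_window`
/// bounds the attention span when the config sets one.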
struct Attention {
    qkv_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    o_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    rotary_emb: Arc<PhiRotaryEmbedding>,
    sliding_window: Option<usize>,
    sdpa_params: SdpaParams,
}

impl Attention {
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<PhiRotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        let head_dim = cfg.head_dim();
        let op_size = num_heads * head_dim + 2 * num_kv_heads * head_dim;
        let qkv_proj = linear_no_bias(
            cfg.hidden_size,
            op_size,
            mapper.set_device(layer_idx, vb.pp("qkv_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("qkv_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let o_proj = linear_no_bias(
            num_heads * head_dim,
            cfg.hidden_size,
            mapper.set_device(layer_idx, vb.pp("o_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("o_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        Ok(Self {
            qkv_proj,
            o_proj,
            rotary_emb,
            num_heads,
            num_kv_heads,
            head_dim,
            sliding_window: cfg.sliding_window,
            sdpa_params: SdpaParams {
                n_kv_groups: num_heads / num_kv_heads,
                softcap: None,
                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                sliding_window: cfg.sliding_window,
            },
        })
    }

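    /// Attention forward pass for a single decoder layer.
    ///
    /// Casts to the quantized activation dtype if needed, applies the fused `qkv_proj`
    /// (LoRA-aware), splits the result into Q/K/V with `narrow`, reshapes to
    /// `(batch, heads, seq, head_dim)`, applies the Phi rotary embedding, updates the
    /// (optionally sliding-window) KV cache, runs SDPA, and projects back with `o_proj`.
    /// `scalings`, `global_scaling_weight` and `is_scaling_pass` are the X-LoRA inputs
    /// threaded through to the LoRA-aware linears.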
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
        kv_cache: &mut Option<(Tensor, Tensor)>,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.qkv_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut qkv = self.qkv_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.qkv_proj.quantized_act_type().is_some() {
            qkv = qkv.to_dtype(original_dtype)?;
        }
        let query_pos = self.num_heads * self.head_dim;
        let q = qkv.narrow(D::Minus1, 0, query_pos)?;
        let k = qkv.narrow(D::Minus1, query_pos, self.num_kv_heads * self.head_dim)?;
        let v = qkv.narrow(
            D::Minus1,
            query_pos + self.num_kv_heads * self.head_dim,
            self.num_kv_heads * self.head_dim,
        )?;

        let (q, k, v) = if q_len != 1 {
            let q = q
                .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, k) = self
            .rotary_emb
            .forward(&q, &k, seqlen_offsets, position_ids)?;

        let (k, v, attn_mask) = Cache::update_kv_cache_sliding_window(
            kv_cache,
            k,
            v,
            attention_mask,
            self.sliding_window,
        )?;

        let mut attn_output = Sdpa.run_attention(
            &q,
            &k,
            &v,
            attn_mask.as_ref(),
            Some(flash_params),
            &self.sdpa_params,
        )?;

        if let Some(t) = self.qkv_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        let mut res = self.o_proj.lora_forward(
            &attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.qkv_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

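/// Phi-3 MLP with a fused `gate_up_proj`.
///
/// The fused projection produces `2 * intermediate_size` features: the first half is the
/// gate (which receives the activation) and the second half is the up projection. Their
/// elementwise product is passed through `down_proj`. Both linears are LoRA-aware.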
#[derive(Clone)]
struct Mlp {
    gate_up_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    down_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    act_fn: Activation,
    i_size: usize,
}

impl Mlp {
    #[allow(clippy::too_many_arguments)]
    fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let hidden_size = cfg.hidden_size;
        let i_size = cfg.intermediate_size;
        let gate_up_proj = linear_no_bias(
            hidden_size,
            2 * i_size,
            mapper.set_device(layer_idx, vb.pp("gate_up_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("gate_up_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let down_proj = linear_no_bias(
            i_size,
            hidden_size,
            mapper.set_device(layer_idx, vb.pp("down_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("down_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        Ok(Self {
            gate_up_proj,
            down_proj,
            act_fn: cfg.hidden_act,
            i_size,
        })
    }

    fn forward(
        &self,
        xs: &Tensor,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
    ) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.gate_up_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let up_states = self.gate_up_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let gate = up_states.narrow(D::Minus1, 0, self.i_size)?;
        let up_states = up_states.narrow(D::Minus1, self.i_size, self.i_size)?;
        let up_states = (up_states * gate.apply(&self.act_fn))?;
        let mut res = self.down_proj.lora_forward(
            &up_states,
            scalings,
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.gate_up_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

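/// One pre-norm Phi-3 decoder layer: RMSNorm -> self-attention -> residual add, then
/// RMSNorm -> MLP -> residual add.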
struct DecoderLayer {
    self_attn: Attention,
    mlp: Mlp,
    input_layernorm: RmsNorm,
    post_attention_layernorm: RmsNorm,
}

impl DecoderLayer {
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<PhiRotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let self_attn = Attention::new(
            rotary_emb,
            cfg,
            vb.pp("self_attn"),
            lora_config,
            count,
            ord,
            mapper,
            layer_idx,
            loading_isq,
            preload_adapters,
        )?;
        let mlp = Mlp::new(
            cfg,
            vb.pp("mlp"),
            lora_config,
            count,
            ord,
            mapper,
            layer_idx,
            loading_isq,
            preload_adapters,
        )?;
        let input_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let post_attention_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        Ok(Self {
            self_attn,
            mlp,
            input_layernorm,
            post_attention_layernorm,
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
        kv_cache: &mut Option<(Tensor, Tensor)>,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = self.input_layernorm.forward(xs)?;
        let xs = self.self_attn.forward(
            &xs,
            attention_mask,
            seqlen_offsets,
            position_ids,
            kv_cache,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
            flash_params,
        )?;
        let xs = (xs + residual)?;
        let residual = &xs;
        let xs = self.mlp.forward(
            &xs.apply(&self.post_attention_layernorm)?,
            scalings,
            global_scaling_weight,
            is_scaling_pass,
        );
        residual + xs
    }
}

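/// X-LoRA Phi-3 model: token embedding, a stack of decoder layers, a final RMSNorm, and
/// a LoRA-aware `lm_head`.
///
/// With an `XLoraConfig`, `xlora_classifier` produces the adapter scalings used during
/// generation. When `xlora_config` is `None` and no adapters are preloaded, the LoRA
/// weights are merged into the base weights at load time. The KV cache is the full
/// (non-paged) cache, created with a second X-LoRA cache for scaling passes.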
pub struct Model {
    embed_tokens: candle_nn::Embedding,
    layers: Vec<DecoderLayer>,
    norm: RmsNorm,
    lm_head: Arc<dyn LinearLayerLike + Send + Sync>,
    dtype: DType,
    device: Device,
    cache: EitherCache,
    max_seq_len: usize,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    xlora_classifier: Option<XLoraClassifier>,
    sliding_window: Option<usize>,
    cfg: ModelConfigMetadata,
}

impl Model {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        _is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        if let Some(ref quant_cfg) = &cfg.quantization_config {
            tracing::info!(
                "Using {} quantization: {}.",
                quant_cfg.name(),
                quant_cfg.get_bits_name(&vb)
            );
        }
        let mapper = normal_loading_metadata.mapper;
        let vb_m = vb.pp("model");

        let embed_tokens = layers::embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
            &cfg.quantization_config,
        )?;
        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        let vb_l = vb_m.pp("layers");
        let mut ropes = HashMap::new();
        for layer_idx in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(PhiRotaryEmbedding::new(vb.dtype(), cfg.clone(), device)?),
            );
        }
        let mut count = 0;
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        ) {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            let layer = DecoderLayer::new(
                rotary_emb.clone(),
                cfg,
                vb_l.pp(layer_idx),
                lora_config,
                &mut count,
                &xlora_ordering,
                &*mapper,
                layer_idx,
                normal_loading_metadata.loading_isq,
                preload_adapters,
            )?;
            layers.push(layer)
        }
        if xlora_config.is_none() && preload_adapters.is_none() {
            // No X-LoRA classifier and no preloaded adapters: this is a plain LoRA model, so merge the adapter weights into the base weights.
            info!("Merging LoRA adapters.");
            for layer in layers.iter_mut().tqdm() {
                Arc::get_mut(&mut layer.self_attn.qkv_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.self_attn.o_proj)
                    .unwrap()
                    .merge_weights()?;

                Arc::get_mut(&mut layer.mlp.down_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.mlp.gate_up_proj)
                    .unwrap()
                    .merge_weights()?;
            }
        }
        let norm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb_m.pp("norm"), false),
        )?;
        let lm_head = linear_no_bias(
            cfg.hidden_size,
            cfg.vocab_size,
            mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
            mapper.set_nm_device(vb.pp("lm_head"), false),
            lora_config,
            &mut count,
            &xlora_ordering,
            preload_adapters,
        )?;
        if xlora_config.is_some() && lm_head.is_lora() {
            // Since X-LoRA never applies an adapter to `lm_head` (rejected here), the dummy values (..., None, 1.0, None) passed to its `lora_forward` calls below are safe.
            candle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA.");
        }
        Ok(Self {
            embed_tokens,
            layers,
            norm,
            lm_head,
            device: normal_loading_metadata.real_device,
            dtype: vb.dtype(),
            cache: EitherCache::Full(Cache::new(cfg.num_hidden_layers, true)),
            max_seq_len: cfg.max_position_embeddings,
            mapper,
            xlora_classifier: xlora_config.map(|xlora_config| {
                XLoraClassifier::new(xlora_config, count, lora_config.len(), vb, false).unwrap()
            }),
            sliding_window: cfg.sliding_window,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_kv_heads: cfg.num_key_value_heads,
                num_attn_heads: cfg.num_attention_heads,
                sliding_window: cfg.sliding_window,
                k_head_dim: cfg.head_dim(),
                v_head_dim: cfg.head_dim(),
            },
        })
    }

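    /// Shared pass over the decoder stack.
    ///
    /// Uses the X-LoRA scratch cache when `is_full_pass` is set (clearing it first if
    /// `no_kv_cache`), otherwise the regular KV cache. Builds the sliding-window causal
    /// mask, maps the hidden states to each layer's device, then moves the result back to
    /// the main device and applies the final RMSNorm.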
    #[allow(clippy::too_many_arguments)]
    fn inner_forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
        scalings: Option<Tensor>,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut xs = self.embed_tokens.forward(input_ids)?;
        let mut cache = if is_full_pass {
            if no_kv_cache {
                let mut new_cache = Vec::new();
                for _ in 0..self.cache.full().xlora_lock().len() {
                    new_cache.push(None);
                }

                self.cache.full().xlora_lock().clone_from(&new_cache);
            }
            self.cache.full().xlora_lock()
        } else {
            self.cache.full().lock()
        };
        let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
            input_ids,
            &*cache,
            self.sliding_window,
            xs.dtype(),
            self.cfg.num_attn_heads,
        )?;

        for (i, layer) in self.layers.iter().enumerate() {
            xs = self.mapper.map(xs, i)?;
            xs = layer.forward(
                &xs,
                attention_mask
                    .as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                position_ids,
                &mut cache[i],
                scalings.clone(),
                self.xlora_classifier
                    .as_ref()
                    .map(|classifier| classifier.get_global_scaling_weight())
                    .unwrap_or(1.0),
                is_scaling_pass,
                flash_params,
            )?
        }
        let xs = xs.to_device(&self.device)?;
        xs.apply(&self.norm)
    }

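    /// X-LoRA forward pass.
    ///
    /// With a classifier configured, the scalings are computed first via `get_scalings`,
    /// then a second pass runs over the decoder stack with those scalings (using
    /// `input_ids_full` when `no_kv_cache` is set, otherwise `input_ids`). Without a
    /// classifier, a single plain pass is run. Logits are produced by `lm_head` with dummy
    /// LoRA arguments, which is safe because an adapter `lm_head` is rejected at load time.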
    #[allow(clippy::too_many_arguments)]
    pub fn forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        if self.xlora_classifier.is_some() {
            let scalings = self.get_scalings(
                input_ids,
                input_ids_full,
                seqlen_offsets,
                seqlen_offsets_full,
                no_kv_cache,
                non_granular_state,
                &position_ids,
                flash_params,
                flash_params_full,
            )?;

            if no_kv_cache {
                let mut res = self
                    .inner_forward(
                        input_ids_full,
                        seqlen_offsets_full,
                        &position_ids,
                        Some(scalings),
                        true,
                        no_kv_cache,
                        None,
                        flash_params_full,
                    )?
                    .contiguous()?;
                if let Some(t) = self.lm_head.quantized_act_type() {
                    res = res.to_dtype(t)?;
                }
                extract_logits(
                    &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                    context_lens,
                )
            } else {
                // is_full_pass=true is ok because no_kv_cache=false
                let mut res = self
                    .inner_forward(
                        input_ids,
                        seqlen_offsets,
                        &position_ids,
                        Some(scalings),
                        true,
                        no_kv_cache,
                        None,
                        flash_params,
                    )?
                    .contiguous()?;
                if let Some(t) = self.lm_head.quantized_act_type() {
                    res = res.to_dtype(t)?;
                }
                extract_logits(
                    &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                    context_lens,
                )
            }
        } else {
            let mut res = self
                .inner_forward(
                    input_ids,
                    seqlen_offsets,
                    &position_ids,
                    None,
                    false,
                    no_kv_cache,
                    None,
                    flash_params,
                )?
                .contiguous()?;
            if let Some(t) = self.lm_head.quantized_act_type() {
                res = res.to_dtype(t)?;
            }
            extract_logits(
                &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                context_lens,
            )
        }
    }
}

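// ISQ support: expose the quantizable inner linears (`lm_head` plus each layer's
// `qkv_proj`, `o_proj`, `gate_up_proj` and `down_proj`) along with their layer indices.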
impl IsqModel for Model {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((Arc::get_mut(&mut self.lm_head).unwrap().quant_inner(), None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.qkv_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.o_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.mlp.gate_up_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.mlp.down_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
        }
        (tensors, &*self.mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        panic!("Cannot generate UQFF for an adapter model.")
    }
}

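// `NormalModel` integration: X-LoRA models are driven exclusively through
// `xlora_forward`, so the plain `forward` entry point is unreachable here.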
impl NormalModel for Model {
    fn forward(
        &self,
        _input_ids: &Tensor,
        _seqlen_offsets: &[usize],
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        _flash_params: &FlashParams,
    ) -> Result<Tensor> {
        unreachable!()
    }
    fn xlora_forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<crate::xlora_models::NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        self.forward(
            input_ids,
            input_ids_full,
            seqlen_offsets,
            seqlen_offsets_full,
            no_kv_cache,
            non_granular_state,
            context_lens,
            position_ids,
            flash_params,
            flash_params_full,
        )
    }
    fn cache(&self) -> &EitherCache {
        &self.cache
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        &mut self.cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn is_xlora(&self) -> bool {
        true
    }
    fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
}

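// `ScalingsMaker` lets the X-LoRA classifier run its scaling pass through
// `inner_forward` to produce the scalings consumed by the real forward pass.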
impl ScalingsMaker for Model {
    fn dtype(&self) -> DType {
        self.dtype
    }
    fn get_cache(&self) -> &EitherCache {
        &self.cache
    }
    fn get_classifier(&self) -> &XLoraClassifier {
        self.xlora_classifier.as_ref().unwrap()
    }
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Tensor,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        context_lens: &[usize],
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        // NOTE(EricLBuehler): hacky yes, but passing the context lens to start the position ids calculation works
        self.inner_forward(
            input_ids,
            seqlen_offsets,
            context_lens,
            Some(scalings),
            is_full_pass,
            no_kv_cache,
            is_scaling_pass,
            flash_params,
        )
    }
}

impl AnyMoeBaseModelMixin for Model {}