mistralrs_core/xlora_models/phi3.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

// This implementation is based on:
// https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/modeling_phi3.py
use crate::{
    amoe::AnyMoeBaseModelMixin,
    attention::SdpaParams,
    layers::{self, Activation, Sdpa},
    lora::{linear_no_bias, LinearLayerLike, LoraConfig, Ordering},
    paged_attention::ModelConfigMetadata,
    pipeline::{
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel, NormalLoadingMetadata,
    },
    utils::progress::NiceProgressBar,
};
use candle_core::{DType, Device, Module, Result, Tensor, D};
use mistralrs_quant::{QuantMethod, ShardedVarBuilder};
use std::{collections::HashMap, sync::Arc};
use tqdm::Iter;
use tracing::info;

use crate::{
    device_map::DeviceMapper,
    layers::{CausalMasker, PhiRotaryEmbedding, RmsNorm},
    models::phi3::Config,
    pipeline::{extract_logits, NormalModel},
};

use crate::pipeline::Cache;

use super::{classifier::XLoraClassifier, NonGranularState, ScalingsMaker, XLoraConfig};

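/// Self-attention with a fused QKV projection, grouped-query attention
/// (`num_heads / num_kv_heads` KV groups), Phi rotary embeddings, and an
/// optional sliding-window KV cache.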
struct Attention {
    qkv_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    o_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    rotary_emb: Arc<PhiRotaryEmbedding>,
    sliding_window: Option<usize>,
    sdpa_params: SdpaParams,
}

impl Attention {
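    /// Builds the fused `qkv_proj` and `o_proj` as LoRA-wrapped linear layers,
    /// threading the running adapter `count`, the `Ordering`, and any preloaded
    /// adapters through to `linear_no_bias`.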
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<PhiRotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        let head_dim = cfg.head_dim();
        let op_size = num_heads * head_dim + 2 * num_kv_heads * head_dim;
        let qkv_proj = linear_no_bias(
            cfg.hidden_size,
            op_size,
            mapper.set_device(layer_idx, vb.pp("qkv_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("qkv_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let o_proj = linear_no_bias(
            num_heads * head_dim,
            cfg.hidden_size,
            mapper.set_device(layer_idx, vb.pp("o_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("o_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        Ok(Self {
            qkv_proj,
            o_proj,
            rotary_emb,
            num_heads,
            num_kv_heads,
            head_dim,
            sliding_window: cfg.sliding_window,
            sdpa_params: SdpaParams {
                n_kv_groups: num_heads / num_kv_heads,
                softcap: None,
                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                sliding_window: cfg.sliding_window,
            },
        })
    }

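    /// Attention forward pass: project to fused QKV, split into Q/K/V, apply the
    /// rotary embeddings, update the (optionally sliding-window) KV cache, run
    /// scaled dot-product attention, and project the output through `o_proj`.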
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
        kv_cache: &mut Option<(Tensor, Tensor)>,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.qkv_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut qkv = self.qkv_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.qkv_proj.quantized_act_type().is_some() {
            qkv = qkv.to_dtype(original_dtype)?;
        }
        let query_pos = self.num_heads * self.head_dim;
        let q = qkv.narrow(D::Minus1, 0, query_pos)?;
        let k = qkv.narrow(D::Minus1, query_pos, self.num_kv_heads * self.head_dim)?;
        let v = qkv.narrow(
            D::Minus1,
            query_pos + self.num_kv_heads * self.head_dim,
            self.num_kv_heads * self.head_dim,
        )?;

        let (q, k, v) = if q_len != 1 {
            let q = q
                .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, k) = self
            .rotary_emb
            .forward(&q, &k, seqlen_offsets, position_ids)?;

        let (k, v, attn_mask) = Cache::update_kv_cache_sliding_window(
            kv_cache,
            k,
            v,
            attention_mask,
            self.sliding_window,
            true,
        )?;

        let mut attn_output = Sdpa.run_attention(
            &q,
            &k,
            &v,
            attn_mask.as_ref(),
            Some(flash_params),
            &self.sdpa_params,
        )?;

        if let Some(t) = self.qkv_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        let mut res = self.o_proj.lora_forward(
            &attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.qkv_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

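/// Phi-3 MLP with a fused `gate_up_proj` (gate and up projections concatenated
/// along the last dimension, hence the `2 * i_size` output) followed by a gated
/// activation and `down_proj`.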
#[derive(Clone)]
struct Mlp {
    gate_up_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    down_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    act_fn: Activation,
    i_size: usize,
}

impl Mlp {
    #[allow(clippy::too_many_arguments)]
    fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let hidden_size = cfg.hidden_size;
        let i_size = cfg.intermediate_size;
        let gate_up_proj = linear_no_bias(
            hidden_size,
            2 * i_size,
            mapper.set_device(layer_idx, vb.pp("gate_up_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("gate_up_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let down_proj = linear_no_bias(
            i_size,
            hidden_size,
            mapper.set_device(layer_idx, vb.pp("down_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("down_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        Ok(Self {
            gate_up_proj,
            down_proj,
            act_fn: cfg.hidden_act,
            i_size,
        })
    }

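    /// Splits the fused projection output into gate (first half) and up (second
    /// half), applies `act_fn` to the gate, multiplies elementwise, then projects
    /// back down.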
    fn forward(
        &self,
        xs: &Tensor,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
    ) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.gate_up_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let up_states = self.gate_up_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let gate = up_states.narrow(D::Minus1, 0, self.i_size)?;
        let up_states = up_states.narrow(D::Minus1, self.i_size, self.i_size)?;
        let up_states = (up_states * gate.apply(&self.act_fn))?;
        let mut res = self.down_proj.lora_forward(
            &up_states,
            scalings,
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.gate_up_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

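/// A single pre-norm transformer block: RMSNorm -> self-attention -> residual,
/// then RMSNorm -> MLP -> residual.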
struct DecoderLayer {
    self_attn: Attention,
    mlp: Mlp,
    input_layernorm: RmsNorm,
    post_attention_layernorm: RmsNorm,
}

impl DecoderLayer {
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<PhiRotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let self_attn = Attention::new(
            rotary_emb,
            cfg,
            vb.pp("self_attn"),
            lora_config,
            count,
            ord,
            mapper,
            layer_idx,
            loading_isq,
            preload_adapters,
        )?;
        let mlp = Mlp::new(
            cfg,
            vb.pp("mlp"),
            lora_config,
            count,
            ord,
            mapper,
            layer_idx,
            loading_isq,
            preload_adapters,
        )?;
        let input_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let post_attention_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        Ok(Self {
            self_attn,
            mlp,
            input_layernorm,
            post_attention_layernorm,
        })
    }

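    /// Applies the block with residual connections around both the attention and
    /// MLP sub-layers, threading the X-LoRA scalings through each.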
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
        kv_cache: &mut Option<(Tensor, Tensor)>,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = self.input_layernorm.forward(xs)?;
        let xs = self.self_attn.forward(
            &xs,
            attention_mask,
            seqlen_offsets,
            position_ids,
            kv_cache,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
            flash_params,
        )?;
        let xs = (xs + residual)?;
        let residual = &xs;
        let xs = self.mlp.forward(
            &xs.apply(&self.post_attention_layernorm)?,
            scalings,
            global_scaling_weight,
            is_scaling_pass,
        );
        residual + xs
    }
}

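/// X-LoRA Phi-3 model: token embedding, a stack of decoder layers, a final
/// RMSNorm, and the `lm_head`, plus an optional `XLoraClassifier` that produces
/// per-adapter scalings.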
pub struct Model {
    embed_tokens: candle_nn::Embedding,
    layers: Vec<DecoderLayer>,
    norm: RmsNorm,
    lm_head: Arc<dyn LinearLayerLike + Send + Sync>,
    dtype: DType,
    device: Device,
    cache: EitherCache,
    max_seq_len: usize,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    xlora_classifier: Option<XLoraClassifier>,
    sliding_window: Option<usize>,
    cfg: ModelConfigMetadata,
}

impl Model {
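    /// Loads the model weights, constructing one `PhiRotaryEmbedding` per device
    /// used by the layer mapper. For plain LoRA (no X-LoRA classifier and no
    /// preloaded adapters) the adapter weights are merged into the base weights.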
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        _is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        if let Some(ref quant_cfg) = &cfg.quantization_config {
            tracing::info!(
                "Using {} quantization: {}.",
                quant_cfg.name(),
                quant_cfg.get_bits_name(&vb)
            );
        }
        let mapper = normal_loading_metadata.mapper;
        let vb_m = vb.pp("model");

        let embed_tokens = layers::embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
            &cfg.quantization_config,
        )?;
        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        let vb_l = vb_m.pp("layers");
        let mut ropes = HashMap::new();
        for layer_idx in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(PhiRotaryEmbedding::new(vb.dtype(), cfg.clone(), device)?),
            );
        }
        let mut count = 0;
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        ) {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            let layer = DecoderLayer::new(
                rotary_emb.clone(),
                cfg,
                vb_l.pp(layer_idx),
                lora_config,
                &mut count,
                &xlora_ordering,
                &*mapper,
                layer_idx,
                normal_loading_metadata.loading_isq,
                preload_adapters,
            )?;
            layers.push(layer)
        }
        if xlora_config.is_none() && preload_adapters.is_none() {
            // Plain LoRA (no X-LoRA classifier, no preloaded adapters): merge the
            // adapter weights into the base weights.
            info!("Merging LoRA adapters.");
            for layer in layers.iter_mut().tqdm() {
                Arc::get_mut(&mut layer.self_attn.qkv_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.self_attn.o_proj)
                    .unwrap()
                    .merge_weights()?;

                Arc::get_mut(&mut layer.mlp.down_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.mlp.gate_up_proj)
                    .unwrap()
                    .merge_weights()?;
            }
        }
        let norm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb_m.pp("norm"), false),
        )?;
        let lm_head = linear_no_bias(
            cfg.hidden_size,
            cfg.vocab_size,
            mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
            mapper.set_nm_device(vb.pp("lm_head"), false),
            lora_config,
            &mut count,
            &xlora_ordering,
            preload_adapters,
        )?;
        if xlora_config.is_some() && lm_head.is_lora() {
            // Rejecting an adapter `lm_head` here is why the later `lora_forward`
            // calls on it can pass dummy scaling values (..., None, 1.0, None).
            candle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA.");
        }
        Ok(Self {
            embed_tokens,
            layers,
            norm,
            lm_head,
            device: normal_loading_metadata.real_device,
            dtype: vb.dtype(),
            cache: EitherCache::Full(Cache::new(cfg.num_hidden_layers, true)),
            max_seq_len: cfg.max_position_embeddings,
            mapper,
            xlora_classifier: xlora_config.map(|xlora_config| {
                XLoraClassifier::new(xlora_config, count, lora_config.len(), vb, false).unwrap()
            }),
            sliding_window: cfg.sliding_window,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_kv_heads: cfg.num_key_value_heads,
                num_attn_heads: cfg.num_attention_heads,
                sliding_window: cfg.sliding_window,
                k_head_dim: cfg.head_dim(),
                v_head_dim: cfg.head_dim(),
            },
        })
    }

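    /// Shared forward body used by both the scaling pass and the real pass.
    /// Selects the X-LoRA cache on full passes (clearing it first when the KV
    /// cache is disabled) or the regular cache otherwise, builds the
    /// sliding-window causal mask, and runs the layer stack.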
    #[allow(clippy::too_many_arguments)]
    fn inner_forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
        scalings: Option<Tensor>,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut xs = self.embed_tokens.forward(input_ids)?;
        let mut cache = if is_full_pass {
            if no_kv_cache {
                let mut new_cache = Vec::new();
                for _ in 0..self.cache.full().xlora_lock().len() {
                    new_cache.push(None);
                }

                self.cache.full().xlora_lock().clone_from(&new_cache);
            }
            self.cache.full().xlora_lock()
        } else {
            self.cache.full().lock()
        };
        let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
            input_ids,
            &*cache,
            self.sliding_window,
            xs.dtype(),
            self.cfg.num_attn_heads,
        )?;

        for (i, layer) in self.layers.iter().enumerate() {
            xs = self.mapper.map(xs, i)?;
            xs = layer.forward(
                &xs,
                attention_mask
                    .as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                position_ids,
                &mut cache[i],
                scalings.clone(),
                self.xlora_classifier
                    .as_ref()
                    .map(|classifier| classifier.get_global_scaling_weight())
                    .unwrap_or(1.0),
                is_scaling_pass,
                flash_params,
            )?
        }
        let xs = xs.to_device(&self.device)?;
        xs.apply(&self.norm)
    }

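    /// Full X-LoRA forward: when a classifier is present, first compute the
    /// adapter scalings (see the `ScalingsMaker` impl), then run the scaled model
    /// pass; otherwise fall back to a single un-scaled pass. Logits are extracted
    /// through the `lm_head` in all branches.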
    #[allow(clippy::too_many_arguments)]
    pub fn forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        if self.xlora_classifier.is_some() {
            let scalings = self.get_scalings(
                input_ids,
                input_ids_full,
                seqlen_offsets,
                seqlen_offsets_full,
                no_kv_cache,
                non_granular_state,
                &position_ids,
                flash_params,
                flash_params_full,
            )?;

            if no_kv_cache {
                let mut res = self
                    .inner_forward(
                        input_ids_full,
                        seqlen_offsets_full,
                        &position_ids,
                        Some(scalings),
                        true,
                        no_kv_cache,
                        None,
                        flash_params_full,
                    )?
                    .contiguous()?;
                if let Some(t) = self.lm_head.quantized_act_type() {
                    res = res.to_dtype(t)?;
                }
                extract_logits(
                    &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                    context_lens,
                )
            } else {
                // is_full_pass=true is ok because no_kv_cache=false
                let mut res = self
                    .inner_forward(
                        input_ids,
                        seqlen_offsets,
                        &position_ids,
                        Some(scalings),
                        true,
                        no_kv_cache,
                        None,
                        flash_params,
                    )?
                    .contiguous()?;
                if let Some(t) = self.lm_head.quantized_act_type() {
                    res = res.to_dtype(t)?;
                }
                extract_logits(
                    &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                    context_lens,
                )
            }
        } else {
            let mut res = self
                .inner_forward(
                    input_ids,
                    seqlen_offsets,
                    &position_ids,
                    None,
                    false,
                    no_kv_cache,
                    None,
                    flash_params,
                )?
                .contiguous()?;
            if let Some(t) = self.lm_head.quantized_act_type() {
                res = res.to_dtype(t)?;
            }
            extract_logits(
                &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                context_lens,
            )
        }
    }
}

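// Exposes the quantizable linear layers (the lm_head plus each layer's fused
// QKV, output, and MLP projections) for in-situ quantization.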
impl IsqModel for Model {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((Arc::get_mut(&mut self.lm_head).unwrap().quant_inner(), None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.qkv_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.o_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.mlp.gate_up_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.mlp.down_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
        }
        (tensors, &*self.mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        panic!("Cannot generate UQFF for an adapter model.")
    }
}

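// The plain `forward` is unreachable: this model reports `is_xlora() == true`
// and is always driven through `xlora_forward`.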
impl NormalModel for Model {
    fn forward(
        &self,
        _input_ids: &Tensor,
        _seqlen_offsets: &[usize],
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        _flash_params: &FlashParams,
    ) -> Result<Tensor> {
        unreachable!()
    }
    fn xlora_forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<crate::xlora_models::NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        self.forward(
            input_ids,
            input_ids_full,
            seqlen_offsets,
            seqlen_offsets_full,
            no_kv_cache,
            non_granular_state,
            context_lens,
            position_ids,
            flash_params,
            flash_params_full,
        )
    }
    fn cache(&self) -> &EitherCache {
        &self.cache
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        &mut self.cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn is_xlora(&self) -> bool {
        true
    }
    fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
}

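// Implements the classifier (scaling) pass by reusing `inner_forward` with the
// candidate scalings.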
impl ScalingsMaker for Model {
    fn dtype(&self) -> DType {
        self.dtype
    }
    fn get_cache(&self) -> &EitherCache {
        &self.cache
    }
    fn get_classifier(&self) -> &XLoraClassifier {
        self.xlora_classifier.as_ref().unwrap()
    }
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Tensor,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        context_lens: &[usize],
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        // NOTE(EricLBuehler): hacky yes, but passing the context lens to start the position ids calculation works
        self.inner_forward(
            input_ids,
            seqlen_offsets,
            context_lens,
            Some(scalings),
            is_full_pass,
            no_kv_cache,
            is_scaling_pass,
            flash_params,
        )
    }
}

impl AnyMoeBaseModelMixin for Model {}