mistralrs_core/xlora_models/phi3.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

// This implementation is based on:
// https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/modeling_phi3.py
use crate::{
    amoe::AnyMoeBaseModelMixin,
    attention::SdpaParams,
    layers::{self, Activation, Sdpa},
    lora::{linear_no_bias, LinearLayerLike, LoraConfig, Ordering},
    paged_attention::ModelConfigMetadata,
    pipeline::{
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel, NormalLoadingMetadata,
    },
    utils::progress::NiceProgressBar,
};
use candle_core::{DType, Device, Module, Result, Tensor, D};
use mistralrs_quant::{QuantMethod, ShardedVarBuilder};
use std::{collections::HashMap, sync::Arc};
use tqdm::Iter;
use tracing::info;

use crate::{
    device_map::DeviceMapper,
    layers::{CausalMasker, PhiRotaryEmbedding, RmsNorm},
    models::phi3::Config,
    pipeline::{extract_logits, NormalModel},
};

use crate::pipeline::Cache;

use super::{classifier::XLoraClassifier, NonGranularState, ScalingsMaker, XLoraConfig};

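/// Multi-head self-attention with a fused QKV projection. Supports grouped-query
/// attention (`num_kv_heads <= num_heads`), optional sliding-window attention, and
/// LoRA/X-LoRA adapters on both projections.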
struct Attention {
    qkv_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    o_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    rotary_emb: Arc<PhiRotaryEmbedding>,
    sliding_window: Option<usize>,
    sdpa_params: SdpaParams,
}

impl Attention {
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<PhiRotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        let head_dim = cfg.head_dim();
        let op_size = num_heads * head_dim + 2 * num_kv_heads * head_dim;
        let qkv_proj = linear_no_bias(
            cfg.hidden_size,
            op_size,
            mapper.set_device(layer_idx, vb.pp("qkv_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("qkv_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let o_proj = linear_no_bias(
            num_heads * head_dim,
            cfg.hidden_size,
            mapper.set_device(layer_idx, vb.pp("o_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("o_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        Ok(Self {
            qkv_proj,
            o_proj,
            rotary_emb,
            num_heads,
            num_kv_heads,
            head_dim,
            sliding_window: cfg.sliding_window,
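            // Attention scale is 1/sqrt(head_dim); `n_kv_groups` is the GQA
            // replication factor (query heads per KV head).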
            sdpa_params: SdpaParams {
                n_kv_groups: num_heads / num_kv_heads,
                use_flash_attn: cfg.use_flash_attn,
                softcap: None,
                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                sliding_window: cfg.sliding_window,
            },
        })
    }

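    // Project to fused QKV, split into per-head Q/K/V by narrowing the last
    // dimension, apply rotary embeddings, update the (optionally sliding-window)
    // KV cache, run SDPA, then project back through `o_proj`. Every LoRA forward
    // receives the X-LoRA scalings for this pass.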
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
        kv_cache: &mut Option<(Tensor, Tensor)>,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.qkv_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut qkv = self.qkv_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.qkv_proj.quantized_act_type().is_some() {
            qkv = qkv.to_dtype(original_dtype)?;
        }
        let query_pos = self.num_heads * self.head_dim;
        let q = qkv.narrow(D::Minus1, 0, query_pos)?;
        let k = qkv.narrow(D::Minus1, query_pos, self.num_kv_heads * self.head_dim)?;
        let v = qkv.narrow(
            D::Minus1,
            query_pos + self.num_kv_heads * self.head_dim,
            self.num_kv_heads * self.head_dim,
        )?;

        let (q, k, v) = if q_len != 1 {
            let q = q
                .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, k) = self
            .rotary_emb
            .forward(&q, &k, seqlen_offsets, position_ids)?;

        let (k, v, attn_mask) = Cache::update_kv_cache_sliding_window(
            kv_cache,
            k,
            v,
            attention_mask,
            self.sliding_window,
            true,
        )?;

        let mut attn_output = Sdpa.run_attention(
            &q,
            &k,
            &v,
            attn_mask.as_ref(),
            Some(flash_params),
            &self.sdpa_params,
        )?;

        if let Some(t) = self.qkv_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        let mut res = self.o_proj.lora_forward(
            &attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.qkv_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

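/// Phi-3 MLP with a fused `gate_up_proj`: the first `i_size` output features are
/// the gate, the next `i_size` the up projection, combined as `up * act(gate)`
/// before `down_proj`.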
#[derive(Clone)]
struct Mlp {
    gate_up_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    down_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    act_fn: Activation,
    i_size: usize,
}

impl Mlp {
    #[allow(clippy::too_many_arguments)]
    fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let hidden_size = cfg.hidden_size;
        let i_size = cfg.intermediate_size;
        let gate_up_proj = linear_no_bias(
            hidden_size,
            2 * i_size,
            mapper.set_device(layer_idx, vb.pp("gate_up_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("gate_up_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let down_proj = linear_no_bias(
            i_size,
            hidden_size,
            mapper.set_device(layer_idx, vb.pp("down_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("down_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        Ok(Self {
            gate_up_proj,
            down_proj,
            act_fn: cfg.hidden_act,
            i_size,
        })
    }

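    // xs -> gate_up_proj -> split into [gate | up] -> up * act(gate) -> down_proj.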
    fn forward(
        &self,
        xs: &Tensor,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
    ) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.gate_up_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let up_states = self.gate_up_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let gate = up_states.narrow(D::Minus1, 0, self.i_size)?;
        let up_states = up_states.narrow(D::Minus1, self.i_size, self.i_size)?;
        let up_states = (up_states * gate.apply(&self.act_fn))?;
        let mut res = self.down_proj.lora_forward(
            &up_states,
            scalings,
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.gate_up_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

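/// One Phi-3 decoder layer: pre-norm attention and MLP blocks, each wrapped in a
/// residual connection (`x = x + attn(norm(x))`, then `x = x + mlp(norm(x))`).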
struct DecoderLayer {
    self_attn: Attention,
    mlp: Mlp,
    input_layernorm: RmsNorm,
    post_attention_layernorm: RmsNorm,
}

impl DecoderLayer {
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<PhiRotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let self_attn = Attention::new(
            rotary_emb,
            cfg,
            vb.pp("self_attn"),
            lora_config,
            count,
            ord,
            mapper,
            layer_idx,
            loading_isq,
            preload_adapters,
        )?;
        let mlp = Mlp::new(
            cfg,
            vb.pp("mlp"),
            lora_config,
            count,
            ord,
            mapper,
            layer_idx,
            loading_isq,
            preload_adapters,
        )?;
        let input_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let post_attention_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        Ok(Self {
            self_attn,
            mlp,
            input_layernorm,
            post_attention_layernorm,
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
        kv_cache: &mut Option<(Tensor, Tensor)>,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = self.input_layernorm.forward(xs)?;
        let xs = self.self_attn.forward(
            &xs,
            attention_mask,
            seqlen_offsets,
            position_ids,
            kv_cache,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
            flash_params,
        )?;
        let xs = (xs + residual)?;
        let residual = &xs;
        let xs = self.mlp.forward(
            &xs.apply(&self.post_attention_layernorm)?,
            scalings,
            global_scaling_weight,
            is_scaling_pass,
        );
        residual + xs
    }
}

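/// X-LoRA Phi-3 model: token embeddings, a stack of decoder layers, a final
/// RMSNorm, and an adapter-aware `lm_head`, plus an optional `XLoraClassifier`
/// that produces per-token adapter scalings.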
pub struct Model {
    embed_tokens: candle_nn::Embedding,
    layers: Vec<DecoderLayer>,
    norm: RmsNorm,
    lm_head: Arc<dyn LinearLayerLike + Send + Sync>,
    dtype: DType,
    device: Device,
    cache: EitherCache,
    max_seq_len: usize,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    xlora_classifier: Option<XLoraClassifier>,
    sliding_window: Option<usize>,
    cfg: ModelConfigMetadata,
}

impl Model {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        _is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        if let Some(ref quant_cfg) = &cfg.quantization_config {
            tracing::info!(
                "Using {} quantization: {}.",
                quant_cfg.quant_method.to_string(),
                quant_cfg.get_bits_name(&vb)
            );
        }
        let mapper = normal_loading_metadata.mapper;
        let vb_m = vb.pp("model");

        let embed_tokens = layers::embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
        )?;
        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        let vb_l = vb_m.pp("layers");
        let mut ropes = HashMap::new();
        for layer_idx in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(PhiRotaryEmbedding::new(vb.dtype(), cfg.clone(), device)?),
            );
        }
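        // `count` tracks the adapter-bearing linear layers created below; the
        // total is passed to `XLoraClassifier::new`.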
        let mut count = 0;
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        ) {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            let layer = DecoderLayer::new(
                rotary_emb.clone(),
                cfg,
                vb_l.pp(layer_idx),
                lora_config,
                &mut count,
                &xlora_ordering,
                &*mapper,
                layer_idx,
                normal_loading_metadata.loading_isq,
                preload_adapters,
            )?;
            layers.push(layer)
        }
        if xlora_config.is_none() && preload_adapters.is_none() {
            // No X-LoRA classifier and no preloaded adapters: this is a plain LoRA
            // model, so merge the adapter weights into the base weights.
            info!("Merging LoRA adapters.");
            for layer in layers.iter_mut().tqdm() {
                Arc::get_mut(&mut layer.self_attn.qkv_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.self_attn.o_proj)
                    .unwrap()
                    .merge_weights()?;

                Arc::get_mut(&mut layer.mlp.down_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.mlp.gate_up_proj)
                    .unwrap()
                    .merge_weights()?;
            }
        }
        let norm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb_m.pp("norm"), false),
        )?;
        let lm_head = linear_no_bias(
            cfg.hidden_size,
            cfg.vocab_size,
            mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
            mapper.set_nm_device(vb.pp("lm_head"), false),
            lora_config,
            &mut count,
            &xlora_ordering,
            preload_adapters,
        )?;
        if xlora_config.is_some() && lm_head.is_lora() {
            // `lm_head` must never carry an adapter: this invariant is why the
            // `lm_head.lora_forward` calls below can pass dummy values (..., None, 1.0, None).
            candle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA.");
        }
        Ok(Self {
            embed_tokens,
            layers,
            norm,
            lm_head,
            device: normal_loading_metadata.real_device,
            dtype: vb.dtype(),
            cache: EitherCache::Full(Cache::new(cfg.num_hidden_layers, true)),
            max_seq_len: cfg.max_position_embeddings,
            mapper,
            xlora_classifier: xlora_config.map(|xlora_config| {
                XLoraClassifier::new(xlora_config, count, lora_config.len(), vb, false).unwrap()
            }),
            sliding_window: cfg.sliding_window,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_kv_heads: cfg.num_key_value_heads,
                num_attn_heads: cfg.num_attention_heads,
                sliding_window: cfg.sliding_window,
                k_head_dim: cfg.head_dim(),
                v_head_dim: cfg.head_dim(),
            },
        })
    }

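    /// Shared trunk for both the scaling pass and the real pass: embed the tokens,
    /// pick the X-LoRA or regular KV cache, build the (sliding-window) causal mask,
    /// and run every decoder layer before the final RMSNorm. The caller applies
    /// `lm_head`.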
    #[allow(clippy::too_many_arguments)]
    fn inner_forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        position_ids: &[usize],
        scalings: Option<Tensor>,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut xs = self.embed_tokens.forward(input_ids)?;
        let mut cache = if is_full_pass {
            if no_kv_cache {
                let mut new_cache = Vec::new();
                for _ in 0..self.cache.full().xlora_lock().len() {
                    new_cache.push(None);
                }

                self.cache.full().xlora_lock().clone_from(&new_cache);
            }
            self.cache.full().xlora_lock()
        } else {
            self.cache.full().lock()
        };
        let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
            input_ids,
            &*cache,
            self.sliding_window,
            xs.dtype(),
            self.cfg.num_attn_heads,
        )?;

        for (i, layer) in self.layers.iter().enumerate() {
            xs = self.mapper.map(xs, i)?;
            xs = layer.forward(
                &xs,
                attention_mask
                    .as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                position_ids,
                &mut cache[i],
                scalings.clone(),
                self.xlora_classifier
                    .as_ref()
                    .map(|classifier| classifier.get_global_scaling_weight())
                    .unwrap_or(1.0),
                is_scaling_pass,
                flash_params,
            )?
        }
        let xs = xs.to_device(&self.device)?;
        xs.apply(&self.norm)
    }

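    /// Full X-LoRA forward. With a classifier present this runs the two-pass
    /// X-LoRA algorithm: a scaling pass (via `get_scalings`) followed by a pass
    /// that uses the resulting scalings; without one it degenerates to a plain
    /// LoRA forward pass.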
    #[allow(clippy::too_many_arguments)]
    pub fn forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        if self.xlora_classifier.is_some() {
            let scalings = self.get_scalings(
                input_ids,
                input_ids_full,
                seqlen_offsets,
                seqlen_offsets_full,
                no_kv_cache,
                non_granular_state,
                &position_ids,
                flash_params,
                flash_params_full,
            )?;

            if no_kv_cache {
                let mut res = self
                    .inner_forward(
                        input_ids_full,
                        seqlen_offsets_full,
                        &position_ids,
                        Some(scalings),
                        true,
                        no_kv_cache,
                        None,
                        flash_params_full,
                    )?
                    .contiguous()?;
                if let Some(t) = self.lm_head.quantized_act_type() {
                    res = res.to_dtype(t)?;
                }
                extract_logits(
                    &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                    context_lens,
                )
            } else {
                // is_full_pass=true is ok because no_kv_cache=false
                let mut res = self
                    .inner_forward(
                        input_ids,
                        seqlen_offsets,
                        &position_ids,
                        Some(scalings),
                        true,
                        no_kv_cache,
                        None,
                        flash_params,
                    )?
                    .contiguous()?;
                if let Some(t) = self.lm_head.quantized_act_type() {
                    res = res.to_dtype(t)?;
                }
                extract_logits(
                    &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                    context_lens,
                )
            }
        } else {
            let mut res = self
                .inner_forward(
                    input_ids,
                    seqlen_offsets,
                    &position_ids,
                    None,
                    false,
                    no_kv_cache,
                    None,
                    flash_params,
                )?
                .contiguous()?;
            if let Some(t) = self.lm_head.quantized_act_type() {
                res = res.to_dtype(t)?;
            }
            extract_logits(
                &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                context_lens,
            )
        }
    }
}

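// ISQ support: expose every quantizable linear (`lm_head` plus each layer's
// attention and MLP projections) together with its layer index.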
impl IsqModel for Model {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((Arc::get_mut(&mut self.lm_head).unwrap().quant_inner(), None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.qkv_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.o_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.mlp.gate_up_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.mlp.down_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
        }
        (tensors, &*self.mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        panic!("Cannot generate UQFF for an adapter model.")
    }
}

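// `NormalModel` impl: the plain `forward` is unreachable because X-LoRA models
// are only ever driven through `xlora_forward`.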
impl NormalModel for Model {
    fn forward(
        &self,
        _input_ids: &Tensor,
        _seqlen_offsets: &[usize],
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        _flash_params: &FlashParams,
    ) -> Result<Tensor> {
        unreachable!()
    }
    fn xlora_forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<crate::xlora_models::NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        self.forward(
            input_ids,
            input_ids_full,
            seqlen_offsets,
            seqlen_offsets_full,
            no_kv_cache,
            non_granular_state,
            context_lens,
            position_ids,
            flash_params,
            flash_params_full,
        )
    }
    fn cache(&self) -> &EitherCache {
        &self.cache
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        &mut self.cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn is_xlora(&self) -> bool {
        true
    }
    fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }
    fn activate_adapters(&mut self, adapter_names: Vec<String>) -> Result<usize> {
        if self.xlora_classifier.is_some() {
            candle_core::bail!("Adapter activation is not supported for X-LoRA models as the adapter set must remain the same.");
        }
        let mut sum = 0;
        for layer in self.layers.iter_mut() {
            sum += Arc::get_mut(&mut layer.self_attn.qkv_proj)
                .unwrap()
                .activate(&adapter_names)?;
            sum += Arc::get_mut(&mut layer.self_attn.o_proj)
                .unwrap()
                .activate(&adapter_names)?;

            sum += Arc::get_mut(&mut layer.mlp.down_proj)
                .unwrap()
                .activate(&adapter_names)?;
            sum += Arc::get_mut(&mut layer.mlp.gate_up_proj)
                .unwrap()
                .activate(&adapter_names)?;
        }
        Ok(sum)
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
}

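// `ScalingsMaker` lets the X-LoRA classifier drive `inner_forward` to produce
// the per-token adapter scalings consumed by the main forward pass.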
impl ScalingsMaker for Model {
    fn dtype(&self) -> DType {
        self.dtype
    }
    fn get_cache(&self) -> &EitherCache {
        &self.cache
    }
    fn get_classifier(&self) -> &XLoraClassifier {
        self.xlora_classifier.as_ref().unwrap()
    }
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Tensor,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        context_lens: &[usize],
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        // NOTE(EricLBuehler): hacky yes, but passing the context lens to start the position ids calculation works
        self.inner_forward(
            input_ids,
            seqlen_offsets,
            context_lens,
            Some(scalings),
            is_full_pass,
            no_kv_cache,
            is_scaling_pass,
            flash_params,
        )
    }
}

impl AnyMoeBaseModelMixin for Model {}