mistralrs_core/xlora_models/phi2.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::{collections::HashMap, sync::Arc};

use crate::{
    amoe::AnyMoeBaseModelMixin,
    attention::SdpaParams,
    layers::{Activation, RotaryEmbedding, Sdpa},
    lora::{linear, LinearLayerLike, LoraConfig, Ordering},
    paged_attention::ModelConfigMetadata,
    pipeline::{
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel, NormalLoadingMetadata,
    },
    utils::progress::NiceProgressBar,
};
/// Phi model.
/// https://huggingface.co/microsoft/phi-2
/// There is an alternative implementation of the phi model in mixformers.rs.
/// This corresponds to the model update made with the following commit:
/// https://huggingface.co/microsoft/phi-2/commit/cb2f4533604d8b67de604e7df03bfe6f3ca22869
use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{Embedding, LayerNorm};
use mistralrs_quant::{QuantMethod, ShardedVarBuilder};
use tqdm::Iter;
use tracing::info;

use crate::{
    device_map::DeviceMapper,
    layers::{embedding, layer_norm, CausalMasker},
    models::phi2::Config,
    pipeline::{extract_logits, NormalModel},
};

use super::{classifier::XLoraClassifier, Cache, NonGranularState, ScalingsMaker, XLoraConfig};

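/// Phi-2 MLP block: `fc1 -> activation -> fc2`, with both projections wrapped as
/// LoRA-capable `LinearLayerLike` layers so adapter scalings can be threaded through.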
#[derive(Clone)]
#[allow(clippy::upper_case_acronyms)]
struct MLP {
    fc1: Arc<dyn LinearLayerLike + Send + Sync>,
    fc2: Arc<dyn LinearLayerLike + Send + Sync>,
    act: Activation,
}

impl MLP {
    #[allow(clippy::too_many_arguments)]
    fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let fc1 = linear(
            cfg.hidden_size,
            cfg.intermediate_size,
            mapper.set_device(layer_idx, vb.pp("fc1"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("fc1"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let fc2 = linear(
            cfg.intermediate_size,
            cfg.hidden_size,
            mapper.set_device(layer_idx, vb.pp("fc2"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("fc2"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        Ok(Self {
            fc1,
            fc2,
            // This does not match the mixformers implementation where Gelu is used rather than
            // GeluNew.
            act: cfg.hidden_act,
        })
    }

    fn forward(
        &self,
        xs: &Tensor,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
    ) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.fc1.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut res = self.fc2.lora_forward(
            &self
                .fc1
                .lora_forward(
                    &xs,
                    scalings.clone(),
                    global_scaling_weight,
                    is_scaling_pass,
                )?
                .apply(&self.act)?,
            scalings,
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.fc1.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

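/// Self-attention with LoRA-capable q/k/v/dense projections, optional Q/K layer norms
/// (`qk_layernorm`), partial rotary embeddings, and SDPA-based attention.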
struct Attention {
    q_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    k_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    v_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    dense: Arc<dyn LinearLayerLike + Send + Sync>,
    q_layernorm: Option<LayerNorm>,
    k_layernorm: Option<LayerNorm>,
    rotary_emb: Arc<RotaryEmbedding>,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    sdpa_params: SdpaParams,
}

impl Attention {
    #[allow(clippy::too_many_arguments)]
    fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        rope: Arc<RotaryEmbedding>,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads();
        let head_dim = cfg.head_dim();
        let q_proj = linear(
            cfg.hidden_size,
            num_heads * head_dim,
            mapper.set_device(layer_idx, vb.pp("q_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("q_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let k_proj = linear(
            cfg.hidden_size,
            num_kv_heads * head_dim,
            mapper.set_device(layer_idx, vb.pp("k_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("k_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let v_proj = linear(
            cfg.hidden_size,
            num_kv_heads * head_dim,
            mapper.set_device(layer_idx, vb.pp("v_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("v_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let dense = linear(
            num_heads * head_dim,
            cfg.hidden_size,
            mapper.set_device(layer_idx, vb.pp("dense"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("dense"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let (q_layernorm, k_layernorm) = if cfg.qk_layernorm {
            let q_layernorm = layer_norm(head_dim, cfg.layer_norm_eps, vb.pp("q_layernorm"))?;
            let k_layernorm = layer_norm(head_dim, cfg.layer_norm_eps, vb.pp("k_layernorm"))?;
            (Some(q_layernorm), Some(k_layernorm))
        } else {
            (None, None)
        };
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            dense,
            q_layernorm,
            k_layernorm,
            rotary_emb: rope,
            num_heads,
            num_kv_heads,
            head_dim,
            sdpa_params: SdpaParams {
                n_kv_groups: num_heads / num_kv_heads,
                use_flash_attn: cfg.use_flash_attn,
                softcap: None,
                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                sliding_window: None,
            },
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut Option<(Tensor, Tensor)>,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_size, seq_len, _n_embd) = xs.dims3()?;
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.q_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut q = self.q_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let mut k = self.k_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let mut v = self.v_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.q_proj.quantized_act_type().is_some() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

        let q = match &self.q_layernorm {
            None => q,
            Some(ln) => q.apply(ln)?,
        };
        let k = match &self.k_layernorm {
            None => k,
            Some(ln) => k.apply(ln)?,
        };

        let (q, k, v) = if seq_len != 1 {
            let q = q
                .reshape((b_size, seq_len, self.num_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_size, seq_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_size, seq_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_size, self.num_heads, seq_len, self.head_dim))?;
            let k = k.reshape((b_size, self.num_kv_heads, seq_len, self.head_dim))?;
            let v = v.reshape((b_size, self.num_kv_heads, seq_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, k) = self.rotary_emb.forward(&q, &k, seqlen_offsets)?;

        let (k, v) = Cache::update_kv_cache(kv_cache, k, v, false)?;

        let attn_output =
            Sdpa.run_attention(&q, &k, &v, mask, Some(flash_params), &self.sdpa_params)?;

        let mut attn_output = attn_output
            .transpose(1, 2)?
            .reshape((b_size, seq_len, ()))?;
        if let Some(t) = self.q_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        let mut res = self.dense.lora_forward(
            &attn_output,
            scalings,
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.q_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

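/// Decoder layer using Phi-2's parallel residual form: attention and MLP both consume the
/// same layer-normed input, and the block returns `attn + mlp + residual`.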
struct DecoderLayer {
    self_attn: Attention,
    mlp: MLP,
    input_layernorm: LayerNorm,
}

impl DecoderLayer {
    #[allow(clippy::too_many_arguments)]
    fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        rope: Arc<RotaryEmbedding>,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let self_attn = Attention::new(
            cfg,
            vb.pp("self_attn"),
            lora_config,
            count,
            ord,
            mapper,
            layer_idx,
            loading_isq,
            rope,
            preload_adapters,
        )?;
        let mlp = MLP::new(
            cfg,
            vb.pp("mlp"),
            lora_config,
            count,
            ord,
            mapper,
            layer_idx,
            loading_isq,
            preload_adapters,
        )?;
        let input_layernorm = layer_norm(
            cfg.hidden_size,
            cfg.layer_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        Ok(Self {
            self_attn,
            mlp,
            input_layernorm,
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut Option<(Tensor, Tensor)>,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = xs.apply(&self.input_layernorm)?;
        let attn_outputs = self.self_attn.forward(
            &xs,
            mask,
            seqlen_offsets,
            kv_cache,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
            flash_params,
        )?;
        let feed_forward_hidden_states =
            self.mlp
                .forward(&xs, scalings, global_scaling_weight, is_scaling_pass)?;
        attn_outputs + feed_forward_hidden_states + residual
    }
}

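/// X-LoRA Phi-2 model: embeddings, decoder layers, final layer norm, and a LoRA-capable
/// `lm_head`, plus an optional `XLoraClassifier` that produces per-token adapter scalings.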
pub struct Model {
    embed_tokens: Embedding,
    layers: Vec<DecoderLayer>,
    final_layernorm: LayerNorm,
    lm_head: Arc<dyn LinearLayerLike + Send + Sync>,
    cache: EitherCache,
    device: Device,
    max_seq_len: usize,
    xlora_classifier: Option<XLoraClassifier>,
    dtype: DType,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    cfg: ModelConfigMetadata,
}

impl Model {
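    /// Builds the model from a `ShardedVarBuilder`, wiring LoRA adapters into every linear
    /// layer. When neither an X-LoRA classifier nor preloaded adapters are present, the
    /// adapter weights are merged into the base weights after loading.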
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        if let Some(ref quant_cfg) = &cfg.quantization_config {
            tracing::info!(
                "Using {} quantization: {}.",
                quant_cfg.name(),
                quant_cfg.get_bits_name(&vb)
            );
        }
        let mapper = normal_loading_metadata.mapper;
        let vb_m = vb.pp("model");

        let embed_tokens = embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
            &cfg.quantization_config,
        )?;
        let final_layernorm = layer_norm(
            cfg.hidden_size,
            cfg.layer_norm_eps,
            mapper.set_nm_device(vb_m.pp("final_layernorm"), false),
        )?;
        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        let vb_m = vb_m.pp("layers");
        let mut ropes = HashMap::new();
        for layer_idx in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            // Alternative rope scalings are not supported
            ropes.insert(
                device.location(),
                Arc::new(RotaryEmbedding::new_partial(
                    cfg.rope_theta,
                    (cfg.partial_rotary_factor * cfg.head_dim() as f64) as usize,
                    cfg.max_position_embeddings,
                    device,
                    is_gptx,
                    vb.dtype(),
                )?),
            );
        }
        let mut count = 0;
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        ) {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            let layer = DecoderLayer::new(
                cfg,
                vb_m.pp(layer_idx),
                lora_config,
                &mut count,
                &xlora_ordering,
                &*mapper,
                layer_idx,
                normal_loading_metadata.loading_isq,
                rotary_emb,
                preload_adapters,
            )?;
            layers.push(layer)
        }
        if xlora_config.is_none() && preload_adapters.is_none() {
            // No X-LoRA classifier and no preloaded adapters: this is a plain LoRA model, so
            // merge the adapter weights into the base weights.
            info!("Merging LoRA adapters.");
            for layer in layers.iter_mut().tqdm() {
                Arc::get_mut(&mut layer.self_attn.k_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.self_attn.dense)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.self_attn.q_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.self_attn.v_proj)
                    .unwrap()
                    .merge_weights()?;

                Arc::get_mut(&mut layer.mlp.fc1).unwrap().merge_weights()?;
                Arc::get_mut(&mut layer.mlp.fc2).unwrap().merge_weights()?;
            }
        }
        let lm_head = linear(
            cfg.hidden_size,
            cfg.vocab_size,
            mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
            mapper.set_nm_device(vb.pp("lm_head"), false),
            lora_config,
            &mut count,
            &xlora_ordering,
            preload_adapters,
        )?;
        if xlora_config.is_some() && lm_head.is_lora() {
            // Rejecting an lm_head adapter here is what allows forward() to pass dummy
            // values (..., None, 1.0, None) to lm_head.lora_forward below.
            candle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA.");
        }
        Ok(Self {
            embed_tokens,
            layers,
            final_layernorm,
            lm_head,
            cache: EitherCache::Full(Cache::new(cfg.num_hidden_layers, true)),
            device: normal_loading_metadata.real_device,
            max_seq_len: cfg.max_position_embeddings,
            dtype: vb.dtype(),
            xlora_classifier: xlora_config.map(|xlora_config| {
                XLoraClassifier::new(xlora_config, count, lora_config.len(), vb, false).unwrap()
            }),
            mapper,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_kv_heads: cfg.num_key_value_heads(),
                num_attn_heads: cfg.num_attention_heads,
                sliding_window: None,
                k_head_dim: cfg.head_dim(),
                v_head_dim: cfg.head_dim(),
            },
        })
    }

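    /// Shared forward pass over the decoder stack. `is_full_pass` selects the X-LoRA cache
    /// (resetting it when `no_kv_cache` is set) instead of the regular KV cache; `scalings`
    /// and `is_scaling_pass` are threaded into every LoRA layer.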
    #[allow(clippy::too_many_arguments)]
    fn inner_forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Option<Tensor>,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut xs = input_ids.apply(&self.embed_tokens)?;
        let mut cache = if is_full_pass {
            if no_kv_cache {
                let mut new_cache = Vec::new();
                for _ in 0..self.cache.full().xlora_lock().len() {
                    new_cache.push(None);
                }

                self.cache.full().xlora_lock().clone_from(&new_cache);
            }
            self.cache.full().xlora_lock()
        } else {
            self.cache.full().lock()
        };
        let mask = CausalMasker.make_causal_mask_matrix(
            input_ids,
            &*cache,
            xs.dtype(),
            self.cfg.num_attn_heads,
        )?;
        for (i, layer) in self.layers.iter().enumerate() {
            xs = self.mapper.map(xs, i)?;
            xs = layer.forward(
                &xs,
                mask.as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                &mut cache[i],
                scalings.clone(),
                self.xlora_classifier
                    .as_ref()
                    .map(|classifier| classifier.get_global_scaling_weight())
                    .unwrap_or(1.0),
                is_scaling_pass,
                flash_params,
            )?;
        }
        let xs = xs.to_device(&self.device)?;
        xs.apply(&self.final_layernorm)
    }

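    /// Top-level X-LoRA forward. With a classifier, a scaling pass (`get_scalings`) runs
    /// first and the resulting scalings drive the full pass; without one, this degenerates
    /// to a plain LoRA forward. Logits are extracted from `lm_head` via `extract_logits`.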
    #[allow(clippy::too_many_arguments)]
    pub fn forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        if self.xlora_classifier.is_some() {
            let scalings = self.get_scalings(
                input_ids,
                input_ids_full,
                seqlen_offsets,
                seqlen_offsets_full,
                no_kv_cache,
                non_granular_state,
                &vec![usize::MAX; context_lens.len()],
                flash_params,
                flash_params_full,
            )?;

            if no_kv_cache {
                let mut res = self
                    .inner_forward(
                        input_ids_full,
                        seqlen_offsets_full,
                        Some(scalings),
                        true,
                        no_kv_cache,
                        None,
                        flash_params_full,
                    )?
                    .contiguous()?;
                if let Some(t) = self.lm_head.quantized_act_type() {
                    res = res.to_dtype(t)?;
                }
                extract_logits(
                    &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                    context_lens,
                )
            } else {
                // is_full_pass=true is ok because no_kv_cache=false
                let mut res = self
                    .inner_forward(
                        input_ids,
                        seqlen_offsets,
                        Some(scalings),
                        true,
                        no_kv_cache,
                        None,
                        flash_params,
                    )?
                    .contiguous()?;
                if let Some(t) = self.lm_head.quantized_act_type() {
                    res = res.to_dtype(t)?;
                }
                extract_logits(
                    &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                    context_lens,
                )
            }
        } else {
            let mut res = self
                .inner_forward(
                    input_ids,
                    seqlen_offsets,
                    None,
                    false,
                    no_kv_cache,
                    None,
                    flash_params,
                )?
                .contiguous()?;
            if let Some(t) = self.lm_head.quantized_act_type() {
                res = res.to_dtype(t)?;
            }
            extract_logits(
                &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                context_lens,
            )
        }
    }
}

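// In-situ quantization support: expose the quantizable inner layers (lm_head plus each
// layer's q/k/v/dense and fc1/fc2) together with the device mapper.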
impl IsqModel for Model {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((Arc::get_mut(&mut self.lm_head).unwrap().quant_inner(), None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.q_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.k_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.v_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.dense)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.mlp.fc1).unwrap().quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.mlp.fc2).unwrap().quant_inner(),
                Some(i),
            ));
        }
        (tensors, &*self.mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        panic!("Cannot generate UQFF for an adapter model.")
    }
}

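// Only the X-LoRA entry point is meaningful here: `forward` is unreachable for this model
// and `xlora_forward` delegates to `Model::forward` above.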
impl NormalModel for Model {
    fn forward(
        &self,
        _input_ids: &Tensor,
        _seqlen_offsets: &[usize],
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        _flash_params: &FlashParams,
    ) -> Result<Tensor> {
        unreachable!()
    }
    fn xlora_forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<crate::xlora_models::NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        self.forward(
            input_ids,
            input_ids_full,
            seqlen_offsets,
            seqlen_offsets_full,
            no_kv_cache,
            non_granular_state,
            context_lens,
            flash_params,
            flash_params_full,
        )
    }
    fn cache(&self) -> &EitherCache {
        &self.cache
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        &mut self.cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn is_xlora(&self) -> bool {
        true
    }
    fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
}

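// Hooks used by the X-LoRA classifier to run the scaling pass: it needs the dtype, the
// cache, the classifier itself, and a forward that routes scalings through `inner_forward`.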
impl ScalingsMaker for Model {
    fn dtype(&self) -> DType {
        self.dtype
    }
    fn get_cache(&self) -> &EitherCache {
        &self.cache
    }
    fn get_classifier(&self) -> &XLoraClassifier {
        self.xlora_classifier.as_ref().unwrap()
    }
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Tensor,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        _context_lens: &[usize],
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.inner_forward(
            input_ids,
            seqlen_offsets,
            Some(scalings),
            is_full_pass,
            no_kv_cache,
            is_scaling_pass,
            flash_params,
        )
    }
}

impl AnyMoeBaseModelMixin for Model {}