mistralrs_core/xlora_models/phi2.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::{collections::HashMap, sync::Arc};

use crate::{
    amoe::AnyMoeBaseModelMixin,
    attention::SdpaParams,
    layers::{Activation, RotaryEmbedding, Sdpa},
    lora::{linear, LinearLayerLike, LoraConfig, Ordering},
    paged_attention::ModelConfigMetadata,
    pipeline::{
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel, NormalLoadingMetadata,
    },
    utils::progress::NiceProgressBar,
};
/// Phi model.
/// https://huggingface.co/microsoft/phi-2
/// There is an alternative implementation of the phi model in mixformers.rs.
/// This corresponds to the model update made with the following commit:
/// https://huggingface.co/microsoft/phi-2/commit/cb2f4533604d8b67de604e7df03bfe6f3ca22869
use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{Embedding, LayerNorm};
use mistralrs_quant::{QuantMethod, ShardedVarBuilder};
use tqdm::Iter;
use tracing::info;

use crate::{
    device_map::DeviceMapper,
    layers::{embedding, layer_norm, CausalMasker},
    models::phi2::Config,
    pipeline::{extract_logits, NormalModel},
};

use super::{classifier::XLoraClassifier, Cache, NonGranularState, ScalingsMaker, XLoraConfig};

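/// Feed-forward block of a Phi-2 decoder layer: `fc1 -> activation -> fc2`, where both
/// linear layers are LoRA-wrapped so X-LoRA scalings can be applied on every forward pass.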
#[derive(Clone)]
#[allow(clippy::upper_case_acronyms)]
struct MLP {
    fc1: Arc<dyn LinearLayerLike + Send + Sync>,
    fc2: Arc<dyn LinearLayerLike + Send + Sync>,
    act: Activation,
}

impl MLP {
    #[allow(clippy::too_many_arguments)]
    fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let fc1 = linear(
            cfg.hidden_size,
            cfg.intermediate_size,
            mapper.set_device(layer_idx, vb.pp("fc1"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("fc1"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let fc2 = linear(
            cfg.intermediate_size,
            cfg.hidden_size,
            mapper.set_device(layer_idx, vb.pp("fc2"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("fc2"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        Ok(Self {
            fc1,
            fc2,
            // This does not match the mixformers implementation where Gelu is used rather than
            // GeluNew.
            act: cfg.hidden_act,
        })
    }

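    // Runs fc1 -> act -> fc2 through the LoRA-aware `lora_forward`. If the layer reports a
    // quantized activation dtype, the input is cast to it first and the result is cast back
    // to the original dtype afterwards.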
    fn forward(
        &self,
        xs: &Tensor,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
    ) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.fc1.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut res = self.fc2.lora_forward(
            &self
                .fc1
                .lora_forward(
                    &xs,
                    scalings.clone(),
                    global_scaling_weight,
                    is_scaling_pass,
                )?
                .apply(&self.act)?,
            scalings,
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.fc1.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

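/// Multi-head attention with LoRA-wrapped q/k/v/dense projections, optional Q/K layer norm
/// (`qk_layernorm`), and the (partial) rotary position embeddings used by Phi-2.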
struct Attention {
    q_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    k_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    v_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    dense: Arc<dyn LinearLayerLike + Send + Sync>,
    q_layernorm: Option<LayerNorm>,
    k_layernorm: Option<LayerNorm>,
    rotary_emb: Arc<RotaryEmbedding>,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    sdpa_params: SdpaParams,
}

impl Attention {
    #[allow(clippy::too_many_arguments)]
    fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        rope: Arc<RotaryEmbedding>,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads();
        let head_dim = cfg.head_dim();
        let q_proj = linear(
            cfg.hidden_size,
            num_heads * head_dim,
            mapper.set_device(layer_idx, vb.pp("q_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("q_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let k_proj = linear(
            cfg.hidden_size,
            num_kv_heads * head_dim,
            mapper.set_device(layer_idx, vb.pp("k_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("k_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let v_proj = linear(
            cfg.hidden_size,
            num_kv_heads * head_dim,
            mapper.set_device(layer_idx, vb.pp("v_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("v_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let dense = linear(
            num_heads * head_dim,
            cfg.hidden_size,
            mapper.set_device(layer_idx, vb.pp("dense"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("dense"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let (q_layernorm, k_layernorm) = if cfg.qk_layernorm {
            let q_layernorm = layer_norm(head_dim, cfg.layer_norm_eps, vb.pp("q_layernorm"))?;
            let k_layernorm = layer_norm(head_dim, cfg.layer_norm_eps, vb.pp("k_layernorm"))?;
            (Some(q_layernorm), Some(k_layernorm))
        } else {
            (None, None)
        };
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            dense,
            q_layernorm,
            k_layernorm,
            rotary_emb: rope,
            num_heads,
            num_kv_heads,
            head_dim,
            sdpa_params: SdpaParams {
                n_kv_groups: num_heads / num_kv_heads,
                use_flash_attn: cfg.use_flash_attn,
                softcap: None,
                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                sliding_window: None,
            },
        })
    }

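    // Projects q/k/v through the LoRA layers, applies the optional Q/K layer norms, reshapes
    // to (batch, heads, seq, head_dim), applies rotary embeddings, updates the KV cache, runs
    // scaled dot-product attention, and finally applies the LoRA-wrapped `dense` projection.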
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut Option<(Tensor, Tensor)>,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_size, seq_len, _n_embd) = xs.dims3()?;
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.q_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut q = self.q_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let mut k = self.k_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let mut v = self.v_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.q_proj.quantized_act_type().is_some() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

        let q = match &self.q_layernorm {
            None => q,
            Some(ln) => q.apply(ln)?,
        };
        let k = match &self.k_layernorm {
            None => k,
            Some(ln) => k.apply(ln)?,
        };

        let (q, k, v) = if seq_len != 1 {
            let q = q
                .reshape((b_size, seq_len, self.num_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_size, seq_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_size, seq_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_size, self.num_heads, seq_len, self.head_dim))?;
            let k = k.reshape((b_size, self.num_kv_heads, seq_len, self.head_dim))?;
            let v = v.reshape((b_size, self.num_kv_heads, seq_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, k) = self.rotary_emb.forward(&q, &k, seqlen_offsets)?;

        let (k, v) = Cache::update_kv_cache(kv_cache, k, v, false)?;

        let attn_output =
            Sdpa.run_attention(&q, &k, &v, mask, Some(flash_params), &self.sdpa_params)?;

        let mut attn_output = attn_output
            .transpose(1, 2)?
            .reshape((b_size, seq_len, ()))?;
        if let Some(t) = self.q_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        let mut res = self.dense.lora_forward(
            &attn_output,
            scalings,
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.q_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

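/// A single Phi-2 decoder layer. Phi-2 uses a parallel residual layout: the attention and MLP
/// branches both read the same layer-normed input and their outputs are summed with the
/// residual, i.e. `attn(ln(x)) + mlp(ln(x)) + x`.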
struct DecoderLayer {
    self_attn: Attention,
    mlp: MLP,
    input_layernorm: LayerNorm,
}

impl DecoderLayer {
    #[allow(clippy::too_many_arguments)]
    fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        rope: Arc<RotaryEmbedding>,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let self_attn = Attention::new(
            cfg,
            vb.pp("self_attn"),
            lora_config,
            count,
            ord,
            mapper,
            layer_idx,
            loading_isq,
            rope,
            preload_adapters,
        )?;
        let mlp = MLP::new(
            cfg,
            vb.pp("mlp"),
            lora_config,
            count,
            ord,
            mapper,
            layer_idx,
            loading_isq,
            preload_adapters,
        )?;
        let input_layernorm = layer_norm(
            cfg.hidden_size,
            cfg.layer_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        Ok(Self {
            self_attn,
            mlp,
            input_layernorm,
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut Option<(Tensor, Tensor)>,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = xs.apply(&self.input_layernorm)?;
        let attn_outputs = self.self_attn.forward(
            &xs,
            mask,
            seqlen_offsets,
            kv_cache,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
            flash_params,
        )?;
        let feed_forward_hidden_states =
            self.mlp
                .forward(&xs, scalings, global_scaling_weight, is_scaling_pass)?;
        attn_outputs + feed_forward_hidden_states + residual
    }
}

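/// X-LoRA Phi-2 model: token embedding, a stack of decoder layers, a final layer norm, and a
/// LoRA-wrapped `lm_head`. When an `XLoraClassifier` is present, it produces per-token adapter
/// scalings that are threaded through every `lora_forward` call.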
pub struct Model {
    embed_tokens: Embedding,
    layers: Vec<DecoderLayer>,
    final_layernorm: LayerNorm,
    lm_head: Arc<dyn LinearLayerLike + Send + Sync>,
    cache: EitherCache,
    device: Device,
    max_seq_len: usize,
    xlora_classifier: Option<XLoraClassifier>,
    dtype: DType,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    cfg: ModelConfigMetadata,
}

impl Model {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        if let Some(ref quant_cfg) = &cfg.quantization_config {
            tracing::info!(
                "Using {} quantization: {}.",
                quant_cfg.quant_method.to_string(),
                quant_cfg.get_bits_name(&vb)
            );
        }
        let mapper = normal_loading_metadata.mapper;
        let vb_m = vb.pp("model");

        let embed_tokens = embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
        )?;
        let final_layernorm = layer_norm(
            cfg.hidden_size,
            cfg.layer_norm_eps,
            mapper.set_nm_device(vb_m.pp("final_layernorm"), false),
        )?;
        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        let vb_m = vb_m.pp("layers");
        let mut ropes = HashMap::new();
        for layer_idx in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            // Alternative rope scalings are not supported
            ropes.insert(
                device.location(),
                Arc::new(RotaryEmbedding::new_partial(
                    cfg.rope_theta,
                    (cfg.partial_rotary_factor * cfg.head_dim() as f64) as usize,
                    cfg.max_position_embeddings,
                    device,
                    is_gptx,
                    vb.dtype(),
                )?),
            );
        }
        let mut count = 0;
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        ) {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            let layer = DecoderLayer::new(
                cfg,
                vb_m.pp(layer_idx),
                lora_config,
                &mut count,
                &xlora_ordering,
                &*mapper,
                layer_idx,
                normal_loading_metadata.loading_isq,
                rotary_emb,
                preload_adapters,
            )?;
            layers.push(layer)
        }
        if xlora_config.is_none() && preload_adapters.is_none() {
            // No X-LoRA classifier and no preloaded adapters: this is a plain LoRA model,
            // so merge the adapter weights into the base weights.
            info!("Merging LoRA adapters.");
            for layer in layers.iter_mut().tqdm() {
                Arc::get_mut(&mut layer.self_attn.k_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.self_attn.dense)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.self_attn.q_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.self_attn.v_proj)
                    .unwrap()
                    .merge_weights()?;

                Arc::get_mut(&mut layer.mlp.fc1).unwrap().merge_weights()?;
                Arc::get_mut(&mut layer.mlp.fc2).unwrap().merge_weights()?;
            }
        }
        let lm_head = linear(
            cfg.hidden_size,
            cfg.vocab_size,
            mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
            mapper.set_nm_device(vb.pp("lm_head"), false),
            lora_config,
            &mut count,
            &xlora_ordering,
            preload_adapters,
        )?;
        if xlora_config.is_some() && lm_head.is_lora() {
            // X-LoRA does not support an adapter on `lm_head`; this is also why dummy values
            // (..., None, 1.0, None) can be passed to `lm_head.lora_forward` below.
            candle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA.");
        }
        Ok(Self {
            embed_tokens,
            layers,
            final_layernorm,
            lm_head,
            cache: EitherCache::Full(Cache::new(cfg.num_hidden_layers, true)),
            device: normal_loading_metadata.real_device,
            max_seq_len: cfg.max_position_embeddings,
            dtype: vb.dtype(),
            xlora_classifier: xlora_config.map(|xlora_config| {
                XLoraClassifier::new(xlora_config, count, lora_config.len(), vb, false).unwrap()
            }),
            mapper,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_kv_heads: cfg.num_key_value_heads(),
                num_attn_heads: cfg.num_attention_heads,
                sliding_window: None,
                k_head_dim: cfg.head_dim(),
                v_head_dim: cfg.head_dim(),
            },
        })
    }

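    // Shared forward body: embeds the tokens, uses the X-LoRA cache for full passes (resetting
    // it when the KV cache is disabled) and the regular cache otherwise, builds the causal mask,
    // runs every decoder layer via the device mapper, and applies the final layer norm.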
    #[allow(clippy::too_many_arguments)]
    fn inner_forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Option<Tensor>,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut xs = input_ids.apply(&self.embed_tokens)?;
        let mut cache = if is_full_pass {
            if no_kv_cache {
                let mut new_cache = Vec::new();
                for _ in 0..self.cache.full().xlora_lock().len() {
                    new_cache.push(None);
                }

                self.cache.full().xlora_lock().clone_from(&new_cache);
            }
            self.cache.full().xlora_lock()
        } else {
            self.cache.full().lock()
        };
        let mask = CausalMasker.make_causal_mask_matrix(
            input_ids,
            &*cache,
            xs.dtype(),
            self.cfg.num_attn_heads,
        )?;
        for (i, layer) in self.layers.iter().enumerate() {
            xs = self.mapper.map(xs, i)?;
            xs = layer.forward(
                &xs,
                mask.as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                &mut cache[i],
                scalings.clone(),
                self.xlora_classifier
                    .as_ref()
                    .map(|classifier| classifier.get_global_scaling_weight())
                    .unwrap_or(1.0),
                is_scaling_pass,
                flash_params,
            )?;
        }
        let xs = xs.to_device(&self.device)?;
        xs.apply(&self.final_layernorm)
    }

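    // X-LoRA inference is a two-pass scheme: a scaling pass (via `ScalingsMaker::get_scalings`)
    // runs the base model so the classifier can compute per-token adapter scalings, then a full
    // pass re-runs the model with those scalings applied before the `lm_head`. Without a
    // classifier (plain LoRA), a single ordinary forward pass is done instead.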
    #[allow(clippy::too_many_arguments)]
    pub fn forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        if self.xlora_classifier.is_some() {
            let scalings = self.get_scalings(
                input_ids,
                input_ids_full,
                seqlen_offsets,
                seqlen_offsets_full,
                no_kv_cache,
                non_granular_state,
                &vec![usize::MAX; context_lens.len()],
                flash_params,
                flash_params_full,
            )?;

            if no_kv_cache {
                let mut res = self
                    .inner_forward(
                        input_ids_full,
                        seqlen_offsets_full,
                        Some(scalings),
                        true,
                        no_kv_cache,
                        None,
                        flash_params_full,
                    )?
                    .contiguous()?;
                if let Some(t) = self.lm_head.quantized_act_type() {
                    res = res.to_dtype(t)?;
                }
                extract_logits(
                    &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                    context_lens,
                )
            } else {
                // is_full_pass=true is ok because no_kv_cache=false
                let mut res = self
                    .inner_forward(
                        input_ids,
                        seqlen_offsets,
                        Some(scalings),
                        true,
                        no_kv_cache,
                        None,
                        flash_params,
                    )?
                    .contiguous()?;
                if let Some(t) = self.lm_head.quantized_act_type() {
                    res = res.to_dtype(t)?;
                }
                extract_logits(
                    &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                    context_lens,
                )
            }
        } else {
            let mut res = self
                .inner_forward(
                    input_ids,
                    seqlen_offsets,
                    None,
                    false,
                    no_kv_cache,
                    None,
                    flash_params,
                )?
                .contiguous()?;
            if let Some(t) = self.lm_head.quantized_act_type() {
                res = res.to_dtype(t)?;
            }
            extract_logits(
                &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                context_lens,
            )
        }
    }
}

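// ISQ (in-situ quantization) support: exposes every quantizable linear layer (the `lm_head`
// plus each layer's q/k/v/dense and fc1/fc2) together with its layer index.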
impl IsqModel for Model {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((Arc::get_mut(&mut self.lm_head).unwrap().quant_inner(), None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.q_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.k_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.v_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.dense)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.mlp.fc1).unwrap().quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.mlp.fc2).unwrap().quant_inner(),
                Some(i),
            ));
        }
        (tensors, &*self.mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        panic!("Cannot generate UQFF for an adapter model.")
    }
}

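// The plain `NormalModel::forward` entry point is unreachable for this model: X-LoRA models
// are driven through `xlora_forward`, which delegates to `Model::forward` above.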
impl NormalModel for Model {
    fn forward(
        &self,
        _input_ids: &Tensor,
        _seqlen_offsets: &[usize],
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        _flash_params: &FlashParams,
    ) -> Result<Tensor> {
        unreachable!()
    }
    fn xlora_forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<crate::xlora_models::NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        self.forward(
            input_ids,
            input_ids_full,
            seqlen_offsets,
            seqlen_offsets_full,
            no_kv_cache,
            non_granular_state,
            context_lens,
            flash_params,
            flash_params_full,
        )
    }
    fn cache(&self) -> &EitherCache {
        &self.cache
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        &mut self.cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn is_xlora(&self) -> bool {
        true
    }
    fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }
    fn activate_adapters(&mut self, adapter_names: Vec<String>) -> Result<usize> {
        if self.xlora_classifier.is_some() {
            candle_core::bail!("Adapter activation is not supported for X-LoRA models as the adapter set must remain the same.");
        }
        let mut sum = 0;
        for layer in self.layers.iter_mut() {
            sum += Arc::get_mut(&mut layer.self_attn.k_proj)
                .unwrap()
                .activate(&adapter_names)?;
            sum += Arc::get_mut(&mut layer.self_attn.dense)
                .unwrap()
                .activate(&adapter_names)?;
            sum += Arc::get_mut(&mut layer.self_attn.q_proj)
                .unwrap()
                .activate(&adapter_names)?;
            sum += Arc::get_mut(&mut layer.self_attn.v_proj)
                .unwrap()
                .activate(&adapter_names)?;

            sum += Arc::get_mut(&mut layer.mlp.fc1)
                .unwrap()
                .activate(&adapter_names)?;
            sum += Arc::get_mut(&mut layer.mlp.fc2)
                .unwrap()
                .activate(&adapter_names)?;
        }
        Ok(sum)
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
}

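// `ScalingsMaker` lets the shared X-LoRA machinery run this model during the scaling pass;
// its `forward` simply delegates to `inner_forward` with the provided scalings.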
impl ScalingsMaker for Model {
    fn dtype(&self) -> DType {
        self.dtype
    }
    fn get_cache(&self) -> &EitherCache {
        &self.cache
    }
    fn get_classifier(&self) -> &XLoraClassifier {
        self.xlora_classifier.as_ref().unwrap()
    }
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Tensor,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        _context_lens: &[usize],
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.inner_forward(
            input_ids,
            seqlen_offsets,
            Some(scalings),
            is_full_pass,
            no_kv_cache,
            is_scaling_pass,
            flash_params,
        )
    }
}

impl AnyMoeBaseModelMixin for Model {}