mistralrs_core/xlora_models/gemma.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::{collections::HashMap, sync::Arc};

use crate::{
    amoe::AnyMoeBaseModelMixin,
    attention::SdpaParams,
    layers::{self, Activation, RmsNorm, RotaryEmbedding, Sdpa},
    lora::{linear_b as linear, LinearLayerLike, LoraConfig, Ordering},
    paged_attention::ModelConfigMetadata,
    pipeline::{
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel, NormalLoadingMetadata,
    },
    utils::progress::NiceProgressBar,
};
use candle_core::{DType, Device, Module, Result, Tensor};
use mistralrs_quant::{QuantMethod, ShardedVarBuilder};
use tqdm::Iter;
use tracing::info;

use crate::{
    device_map::DeviceMapper,
    layers::CausalMasker,
    models::gemma::Config,
    pipeline::{extract_logits, Cache, NormalModel},
};

use super::{classifier::XLoraClassifier, NonGranularState, ScalingsMaker, XLoraConfig};

fn default_max_position_embeddings() -> usize {
    4096
}

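/// Gated feed-forward block used by each decoder layer. Computes
/// `down_proj(act_fn(gate_proj(x)) * up_proj(x))`, with every projection wrapped as a
/// LoRA-capable linear so X-LoRA scalings can be applied on each forward pass.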
#[derive(Clone)]
#[allow(clippy::upper_case_acronyms)]
struct MLP {
    gate_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    up_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    down_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    act_fn: Activation,
}

impl MLP {
    #[allow(clippy::too_many_arguments)]
    fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let intermediate_sz = cfg.intermediate_size;
        let gate_proj = linear(
            hidden_sz,
            intermediate_sz,
            false,
            mapper.set_device(layer_idx, vb.pp("gate_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("gate_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let up_proj = linear(
            hidden_sz,
            intermediate_sz,
            false,
            mapper.set_device(layer_idx, vb.pp("up_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("up_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let down_proj = linear(
            intermediate_sz,
            hidden_sz,
            false,
            mapper.set_device(layer_idx, vb.pp("down_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("down_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        Ok(Self {
            gate_proj,
            up_proj,
            down_proj,
            act_fn: cfg.hidden_act()?,
        })
    }

    fn forward(
        &self,
        xs: &Tensor,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
    ) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.gate_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let lhs = self
            .gate_proj
            .lora_forward(
                &xs,
                scalings.clone(),
                global_scaling_weight,
                is_scaling_pass,
            )?
            .apply(&self.act_fn)?;
        let rhs = self.up_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let mut res = self.down_proj.lora_forward(
            &(lhs * rhs)?,
            scalings,
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.gate_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

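/// Self-attention with grouped-query KV heads, rotary position embeddings, and LoRA-capable
/// q/k/v/o projections. The KV cache is updated in place and attention is computed through
/// the shared `Sdpa` kernel.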
struct Attention {
    q_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    k_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    v_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    o_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    rotary_emb: Arc<RotaryEmbedding>,
    sdpa_params: SdpaParams,
}

impl Attention {
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<RotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        let head_dim = cfg.head_dim;
        let bias = cfg.attention_bias;
        let q_proj = linear(
            hidden_sz,
            num_heads * head_dim,
            bias,
            mapper.set_device(layer_idx, vb.pp("q_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("q_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let k_proj = linear(
            hidden_sz,
            num_kv_heads * head_dim,
            bias,
            mapper.set_device(layer_idx, vb.pp("k_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("k_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let v_proj = linear(
            hidden_sz,
            num_kv_heads * head_dim,
            bias,
            mapper.set_device(layer_idx, vb.pp("v_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("v_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let o_proj = linear(
            num_heads * head_dim,
            hidden_sz,
            bias,
            mapper.set_device(layer_idx, vb.pp("o_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("o_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            num_heads,
            num_kv_heads,
            head_dim,
            rotary_emb,
            sdpa_params: SdpaParams {
                n_kv_groups: num_heads / num_kv_heads,
                use_flash_attn: cfg.use_flash_attn,
                softcap: None,
                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                sliding_window: None,
            },
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut Option<(Tensor, Tensor)>,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.q_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut q = self.q_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let mut k = self.k_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let mut v = self.v_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.q_proj.quantized_act_type().is_some() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

        let (q, k, v) = if q_len != 1 {
            let q = q
                .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, k) = self.rotary_emb.forward(&q, &k, seqlen_offsets)?;

        let (k, v) = Cache::update_kv_cache(kv_cache, k, v, false)?;

        let mut attn_output = Sdpa.run_attention(
            &q,
            &k,
            &v,
            attention_mask,
            Some(flash_params),
            &self.sdpa_params,
        )?;

        if let Some(t) = self.q_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        let mut res = self.o_proj.lora_forward(
            &attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.q_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

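/// One transformer block: an RMS-normed self-attention sub-layer and an RMS-normed MLP
/// sub-layer, each wrapped in a residual connection.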
struct DecoderLayer {
    self_attn: Attention,
    mlp: MLP,
    input_layernorm: RmsNorm,
    post_attention_layernorm: RmsNorm,
}

impl DecoderLayer {
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<RotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let self_attn = Attention::new(
            rotary_emb,
            cfg,
            vb.pp("self_attn"),
            lora_config,
            count,
            ord,
            mapper,
            layer_idx,
            loading_isq,
            preload_adapters,
        )?;
        let mlp = MLP::new(
            cfg,
            vb.pp("mlp"),
            lora_config,
            count,
            ord,
            mapper,
            layer_idx,
            loading_isq,
            preload_adapters,
        )?;
        let input_layernorm = RmsNorm::new_gemma(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let post_attention_layernorm = RmsNorm::new_gemma(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        Ok(Self {
            self_attn,
            mlp,
            input_layernorm,
            post_attention_layernorm,
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut Option<(Tensor, Tensor)>,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = self.input_layernorm.forward(xs)?;
        let xs = self.self_attn.forward(
            &xs,
            attention_mask,
            seqlen_offsets,
            kv_cache,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
            flash_params,
        )?;
        let xs = (xs + residual)?;
        let residual = &xs;
        let xs = self.mlp.forward(
            &xs.apply(&self.post_attention_layernorm)?,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        residual + xs
    }
}

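/// Gemma decoder stack with X-LoRA (or plain LoRA) adapters on the attention and MLP
/// projections. When an `XLoraClassifier` is present, generation runs a scaling pass followed
/// by the real forward pass; with no classifier and no preloaded adapters, the LoRA weights
/// are merged into the base layers at load time.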
pub struct XLoraModel {
    embed_tokens: candle_nn::Embedding,
    layers: Vec<DecoderLayer>,
    norm: RmsNorm,
    lm_head: Arc<dyn LinearLayerLike + Send + Sync>,
    dtype: DType,
    hidden_size: usize,
    device: Device,
    cache: EitherCache,
    max_seq_len: usize,
    xlora_classifier: Option<XLoraClassifier>,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    cfg: ModelConfigMetadata,
}

impl XLoraModel {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        if let Some(ref quant_cfg) = &cfg.quantization_config {
            tracing::info!(
                "Using {} quantization: {}.",
                quant_cfg.quant_method.to_string(),
                quant_cfg.get_bits_name(&vb)
            );
        }
        let mapper = normal_loading_metadata.mapper;
        let vb_m = vb.pp("model");

        let embed_tokens = layers::embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
        )?;
        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        let vb_l = vb_m.pp("layers");
        let mut ropes = HashMap::new();
        for layer_idx in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(RotaryEmbedding::new(
                    cfg.rope_theta as f32,
                    cfg.head_dim,
                    cfg.max_position_embeddings,
                    device,
                    is_gptx,
                    vb.dtype(),
                )?),
            );
        }

        let mut count = 0;
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        ) {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            let layer = DecoderLayer::new(
                rotary_emb.clone(),
                cfg,
                vb_l.pp(layer_idx),
                lora_config,
                &mut count,
                &xlora_ordering,
                &*mapper,
                layer_idx,
                normal_loading_metadata.loading_isq,
                preload_adapters,
            )?;
            layers.push(layer)
        }
        if xlora_config.is_none() && preload_adapters.is_none() {
            // No X-LoRA classifier and no preloaded adapters: this is a plain LoRA model,
            // so merge the adapter weights into the base layers now.
            info!("Merging LoRA adapters.");
            for layer in layers.iter_mut().tqdm() {
                Arc::get_mut(&mut layer.self_attn.k_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.self_attn.o_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.self_attn.q_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.self_attn.v_proj)
                    .unwrap()
                    .merge_weights()?;

                Arc::get_mut(&mut layer.mlp.down_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.mlp.gate_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.mlp.up_proj)
                    .unwrap()
                    .merge_weights()?;
            }
        }
        let norm = RmsNorm::new_gemma(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb_m.pp("norm"), false),
        )?;
        let lm_head = linear(
            embed_tokens.embeddings().dim(1)?,
            embed_tokens.embeddings().dim(0)?,
            false,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), normal_loading_metadata.loading_isq),
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
            lora_config,
            &mut count,
            &xlora_ordering,
            preload_adapters,
        )?;
        if xlora_config.is_some() && lm_head.is_lora() {
            // Under X-LoRA, `lm_head` never carries an adapter; this is why dummy values
            // (..., None, 1.0, None) can be passed to `lm_head.lora_forward` in `forward` below.
            candle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA.");
        }

        Ok(Self {
            embed_tokens,
            layers,
            norm,
            lm_head,
            device: normal_loading_metadata.real_device,
            dtype: vb.dtype(),
            hidden_size: cfg.hidden_size,
            cache: EitherCache::Full(Cache::new(cfg.num_hidden_layers, true)),
            max_seq_len: default_max_position_embeddings(),
            xlora_classifier: xlora_config.map(|xlora_config| {
                XLoraClassifier::new(xlora_config, count, lora_config.len(), vb, false).unwrap()
            }),
            mapper,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_kv_heads: cfg.num_key_value_heads,
                num_attn_heads: cfg.num_attention_heads,
                sliding_window: None,
                k_head_dim: cfg.head_dim,
                v_head_dim: cfg.head_dim,
            },
        })
    }

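    /// Runs the decoder stack itself: embeds the tokens, scales the embeddings by
    /// `sqrt(hidden_size)`, builds the causal mask, applies every decoder layer (optionally
    /// with X-LoRA scalings), and finishes with the final RMS norm. `is_full_pass` selects the
    /// X-LoRA KV cache instead of the regular one.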
    #[allow(clippy::too_many_arguments)]
    fn inner_forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Option<Tensor>,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut cache = if is_full_pass {
            if no_kv_cache {
                let mut new_cache = Vec::new();
                for _ in 0..self.cache.full().xlora_lock().len() {
                    new_cache.push(None);
                }

                self.cache.full().xlora_lock().clone_from(&new_cache);
            }
            self.cache.full().xlora_lock()
        } else {
            self.cache.full().lock()
        };
        let xs = self.embed_tokens.forward(input_ids)?;
        let attention_mask = CausalMasker.make_causal_mask_matrix(
            input_ids,
            &*cache,
            xs.dtype(),
            self.cfg.num_attn_heads,
        )?;
        let mut xs = (xs * (self.hidden_size as f64).sqrt())?;
        for (i, layer) in self.layers.iter().enumerate() {
            xs = self.mapper.map(xs, i)?;
            xs = layer.forward(
                &xs,
                attention_mask
                    .as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                &mut cache[i],
                scalings.clone(),
                self.xlora_classifier
                    .as_ref()
                    .map(|classifier| classifier.get_global_scaling_weight())
                    .unwrap_or(1.0),
                is_scaling_pass,
                flash_params,
            )?
        }
        let xs = xs.to_device(&self.device)?;
        xs.apply(&self.norm)
    }

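    /// Entry point used during generation. With an X-LoRA classifier, first computes the
    /// per-token adapter scalings via `get_scalings` and then reruns the model with those
    /// scalings applied; without a classifier this reduces to a single plain forward pass.
    /// Logits are extracted from the `lm_head` output according to `context_lens`.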
    #[allow(clippy::too_many_arguments)]
    pub fn forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        if self.xlora_classifier.is_some() {
            let scalings = self.get_scalings(
                input_ids,
                input_ids_full,
                seqlen_offsets,
                seqlen_offsets_full,
                no_kv_cache,
                non_granular_state,
                &vec![usize::MAX; context_lens.len()],
                flash_params,
                flash_params_full,
            )?;

            if no_kv_cache {
                let mut res = self
                    .inner_forward(
                        input_ids_full,
                        seqlen_offsets_full,
                        Some(scalings),
                        true,
                        no_kv_cache,
                        None,
                        flash_params_full,
                    )?
                    .contiguous()?;
                if let Some(t) = self.lm_head.quantized_act_type() {
                    res = res.to_dtype(t)?;
                }
                extract_logits(
                    &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                    context_lens,
                )
            } else {
                // is_full_pass=true is ok because no_kv_cache=false
                let mut res = self
                    .inner_forward(
                        input_ids,
                        seqlen_offsets,
                        Some(scalings),
                        true,
                        no_kv_cache,
                        None,
                        flash_params,
                    )?
                    .contiguous()?;
                if let Some(t) = self.lm_head.quantized_act_type() {
                    res = res.to_dtype(t)?;
                }
                extract_logits(
                    &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                    context_lens,
                )
            }
        } else {
            let mut res = self
                .inner_forward(
                    input_ids,
                    seqlen_offsets,
                    None,
                    false,
                    no_kv_cache,
                    None,
                    flash_params,
                )?
                .contiguous()?;
            if let Some(t) = self.lm_head.quantized_act_type() {
                res = res.to_dtype(t)?;
            }
            extract_logits(
                &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                context_lens,
            )
        }
    }
}

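// In-situ quantization support: expose `lm_head` plus every per-layer q/k/v/o and MLP
// projection as quantizable tensors, paired with their layer index for device mapping.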
impl IsqModel for XLoraModel {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((Arc::get_mut(&mut self.lm_head).unwrap().quant_inner(), None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.q_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.k_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.v_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.o_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.mlp.gate_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.mlp.up_proj).unwrap().quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.mlp.down_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
        }
        (tensors, &*self.mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        panic!("Cannot generate UQFF for an adapter model.")
    }
}

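// X-LoRA models are always driven through `xlora_forward`, so the plain `forward` entry point
// is unreachable.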
impl NormalModel for XLoraModel {
    fn forward(
        &self,
        _input_ids: &Tensor,
        _seqlen_offsets: &[usize],
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        _flash_params: &FlashParams,
    ) -> Result<Tensor> {
        unreachable!()
    }
    fn xlora_forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<crate::xlora_models::NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        self.forward(
            input_ids,
            input_ids_full,
            seqlen_offsets,
            seqlen_offsets_full,
            no_kv_cache,
            non_granular_state,
            context_lens,
            flash_params,
            flash_params_full,
        )
    }
    fn cache(&self) -> &EitherCache {
        &self.cache
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        &mut self.cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn is_xlora(&self) -> bool {
        true
    }
    fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }
    fn activate_adapters(&mut self, adapter_names: Vec<String>) -> Result<usize> {
        if self.xlora_classifier.is_some() {
            candle_core::bail!("Adapter activation is not supported for X-LoRA models as the adapter set must remain the same.");
        }
        let mut sum = 0;
        for layer in self.layers.iter_mut() {
            sum += Arc::get_mut(&mut layer.self_attn.k_proj)
                .unwrap()
                .activate(&adapter_names)?;
            sum += Arc::get_mut(&mut layer.self_attn.o_proj)
                .unwrap()
                .activate(&adapter_names)?;
            sum += Arc::get_mut(&mut layer.self_attn.q_proj)
                .unwrap()
                .activate(&adapter_names)?;
            sum += Arc::get_mut(&mut layer.self_attn.v_proj)
                .unwrap()
                .activate(&adapter_names)?;

            sum += Arc::get_mut(&mut layer.mlp.down_proj)
                .unwrap()
                .activate(&adapter_names)?;
            sum += Arc::get_mut(&mut layer.mlp.gate_proj)
                .unwrap()
                .activate(&adapter_names)?;
            sum += Arc::get_mut(&mut layer.mlp.up_proj)
                .unwrap()
                .activate(&adapter_names)?;
        }
        Ok(sum)
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
}

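// Provides the forward pass used by the X-LoRA classifier when computing adapter scalings.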
impl ScalingsMaker for XLoraModel {
    fn dtype(&self) -> DType {
        self.dtype
    }
    fn get_cache(&self) -> &EitherCache {
        &self.cache
    }
    fn get_classifier(&self) -> &XLoraClassifier {
        self.xlora_classifier.as_ref().unwrap()
    }
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Tensor,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        _context_lens: &[usize],
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.inner_forward(
            input_ids,
            seqlen_offsets,
            Some(scalings),
            is_full_pass,
            no_kv_cache,
            is_scaling_pass,
            flash_params,
        )
    }
}

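// AnyMoE is not customized for this model; the trait's default implementations apply.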
impl AnyMoeBaseModelMixin for XLoraModel {}