mistralrs_core/xlora_models/mistral.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use crate::{
    amoe::AnyMoeBaseModelMixin,
    attention::SdpaParams,
    layers::{self, RotaryEmbedding, Sdpa},
    lora::{linear_no_bias, LinearLayerLike, LoraConfig, Ordering},
    paged_attention::ModelConfigMetadata,
    pipeline::{
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel, NormalLoadingMetadata,
    },
    utils::progress::NiceProgressBar,
};
/// Mistral LLM, https://github.com/mistralai/mistral-src
use candle_core::{DType, Device, Module, Result, Tensor};
use mistralrs_quant::{QuantMethod, ShardedVarBuilder};
use std::{collections::HashMap, sync::Arc};
use tqdm::Iter;
use tracing::info;

use crate::{
    device_map::DeviceMapper,
    layers::{Activation, CausalMasker, RmsNorm},
    models::mistral::Config,
    pipeline::{extract_logits, Cache, NormalModel},
};

use super::{classifier::XLoraClassifier, config::XLoraConfig, NonGranularState, ScalingsMaker};

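/// Gated feed-forward block: `down_proj(act_fn(gate_proj(x)) * up_proj(x))`, where every
/// projection is a LoRA-aware linear layer (`LinearLayerLike`). Each forward pass threads
/// the X-LoRA scalings and global scaling weight through `lora_forward`, and temporarily
/// casts activations when the underlying layers are quantized.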
#[derive(Clone)]
#[allow(clippy::upper_case_acronyms)]
struct MLP {
    gate_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    up_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    down_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    act_fn: Activation,
}

impl MLP {
    #[allow(clippy::too_many_arguments)]
    fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let intermediate_sz = cfg.intermediate_size;
        let gate_proj = linear_no_bias(
            hidden_sz,
            intermediate_sz,
            mapper.set_device(layer_idx, vb.pp("gate_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("gate_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let up_proj = linear_no_bias(
            hidden_sz,
            intermediate_sz,
            mapper.set_device(layer_idx, vb.pp("up_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("up_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let down_proj = linear_no_bias(
            intermediate_sz,
            hidden_sz,
            mapper.set_device(layer_idx, vb.pp("down_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("down_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        Ok(Self {
            gate_proj,
            up_proj,
            down_proj,
            act_fn: cfg.hidden_act,
        })
    }

    fn forward(
        &self,
        xs: &Tensor,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
    ) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.gate_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let lhs = self
            .gate_proj
            .lora_forward(
                &xs,
                scalings.clone(),
                global_scaling_weight,
                is_scaling_pass,
            )?
            .apply(&self.act_fn)?;
        let rhs = self.up_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let mut res = self.down_proj.lora_forward(
            &(lhs * rhs)?,
            scalings,
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.gate_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

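/// Self-attention with grouped-query attention: `num_kv_heads` may be smaller than
/// `num_heads`, with the ratio recorded in `SdpaParams::n_kv_groups`. Queries and keys
/// receive rotary position embeddings, the KV cache is updated with optional
/// sliding-window truncation, and all four projections are LoRA-aware linear layers.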
struct Attention {
    q_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    k_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    v_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    o_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    rotary_emb: Arc<RotaryEmbedding>,
    sliding_window: Option<usize>,
    sdpa_params: SdpaParams,
}

impl Attention {
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<RotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        let head_dim = cfg.head_dim();
        let q_proj = linear_no_bias(
            hidden_sz,
            num_heads * head_dim,
            mapper.set_device(layer_idx, vb.pp("q_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("q_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let k_proj = linear_no_bias(
            hidden_sz,
            num_kv_heads * head_dim,
            mapper.set_device(layer_idx, vb.pp("k_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("k_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let v_proj = linear_no_bias(
            hidden_sz,
            num_kv_heads * head_dim,
            mapper.set_device(layer_idx, vb.pp("v_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("v_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let o_proj = linear_no_bias(
            num_heads * head_dim,
            hidden_sz,
            mapper.set_device(layer_idx, vb.pp("o_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("o_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            num_heads,
            num_kv_heads,
            head_dim,
            rotary_emb,
            sliding_window: cfg.sliding_window,
            sdpa_params: SdpaParams {
                n_kv_groups: num_heads / num_kv_heads,
                use_flash_attn: cfg.use_flash_attn,
                softcap: None,
                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                sliding_window: cfg.sliding_window,
            },
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut Option<(Tensor, Tensor)>,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.q_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut q = self.q_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let mut k = self.k_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let mut v = self.v_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.q_proj.quantized_act_type().is_some() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

        let (q, k, v) = if q_len != 1 {
            let q = q
                .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, k) = self.rotary_emb.forward(&q, &k, seqlen_offsets)?;

        let (k, v, attn_mask) = Cache::update_kv_cache_sliding_window(
            kv_cache,
            k,
            v,
            attention_mask,
            self.sliding_window,
            false,
        )?;

        let mut attn_output = Sdpa.run_attention(
            &q,
            &k,
            &v,
            attn_mask.as_ref(),
            Some(flash_params),
            &self.sdpa_params,
        )?;

        if let Some(t) = self.q_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        let mut res = self.o_proj.lora_forward(
            &attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.q_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

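/// One transformer decoder layer: pre-norm attention and MLP sub-blocks, each wrapped in a
/// residual connection (`x + attn(norm(x))`, then `x + mlp(norm(x))`).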
struct DecoderLayer {
    self_attn: Attention,
    mlp: MLP,
    input_layernorm: RmsNorm,
    post_attention_layernorm: RmsNorm,
}

impl DecoderLayer {
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<RotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let self_attn = Attention::new(
            rotary_emb,
            cfg,
            vb.pp("self_attn"),
            lora_config,
            count,
            ord,
            mapper,
            layer_idx,
            loading_isq,
            preload_adapters,
        )?;
        let mlp = MLP::new(
            cfg,
            vb.pp("mlp"),
            lora_config,
            count,
            ord,
            mapper,
            layer_idx,
            loading_isq,
            preload_adapters,
        )?;
        let input_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let post_attention_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        Ok(Self {
            self_attn,
            mlp,
            input_layernorm,
            post_attention_layernorm,
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut Option<(Tensor, Tensor)>,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = self.input_layernorm.forward(xs)?;
        let xs = self.self_attn.forward(
            &xs,
            attention_mask,
            seqlen_offsets,
            kv_cache,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
            flash_params,
        )?;
        let xs = (xs + residual)?;
        let residual = &xs;
        let xs = self.mlp.forward(
            &xs.apply(&self.post_attention_layernorm)?,
            scalings,
            global_scaling_weight,
            is_scaling_pass,
        )?;
        residual + xs
    }
}

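/// X-LoRA variant of the Mistral model. Besides the usual embedding, decoder stack, final
/// norm, and LM head, it optionally holds an `XLoraClassifier` that produces per-token
/// adapter scalings; when no X-LoRA config and no preloaded adapters are given, the LoRA
/// weights are merged into the base layers at load time instead.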
pub struct XLoraModel {
    embed_tokens: candle_nn::Embedding,
    layers: Vec<DecoderLayer>,
    norm: RmsNorm,
    lm_head: Arc<dyn LinearLayerLike + Send + Sync>,
    sliding_window: Option<usize>,
    dtype: DType,
    device: Device,
    cache: EitherCache,
    max_seq_len: usize,
    xlora_classifier: Option<XLoraClassifier>,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    cfg: ModelConfigMetadata,
}

impl XLoraModel {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        if let Some(ref quant_cfg) = &cfg.quantization_config {
            tracing::info!(
                "Using {} quantization: {}.",
                quant_cfg.quant_method.to_string(),
                quant_cfg.get_bits_name(&vb)
            );
        }
        let mapper = normal_loading_metadata.mapper;

        let vb_m = vb.pp("model");
        let embed_tokens = layers::embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
        )?;
        let head_dim = cfg.head_dim();
        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        let vb_l = vb_m.pp("layers");
        let mut ropes = HashMap::new();
        for layer_idx in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(RotaryEmbedding::new(
                    cfg.rope_theta as f32,
                    head_dim,
                    cfg.max_position_embeddings,
                    device,
                    is_gptx,
                    vb_m.dtype(),
                )?),
            );
        }
        let mut count = 0;
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        ) {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            let layer = DecoderLayer::new(
                rotary_emb.clone(),
                cfg,
                vb_l.pp(layer_idx),
                lora_config,
                &mut count,
                &xlora_ordering,
                &*mapper,
                layer_idx,
                normal_loading_metadata.loading_isq,
                preload_adapters,
            )?;
            layers.push(layer)
        }
        if xlora_config.is_none() && preload_adapters.is_none() {
            // We are now a LoRA model so we must merge the weights
            info!("Merging LoRA adapters.");
            for layer in layers.iter_mut().tqdm() {
                Arc::get_mut(&mut layer.self_attn.k_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.self_attn.o_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.self_attn.q_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.self_attn.v_proj)
                    .unwrap()
                    .merge_weights()?;

                Arc::get_mut(&mut layer.mlp.down_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.mlp.gate_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.mlp.up_proj)
                    .unwrap()
                    .merge_weights()?;
            }
        }
        let norm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb_m.pp("norm"), false),
        )?;
        let lm_head = linear_no_bias(
            cfg.hidden_size,
            cfg.vocab_size,
            mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
            mapper.set_nm_device(vb.pp("lm_head"), false),
            lora_config,
            &mut count,
            &xlora_ordering,
            preload_adapters,
        )?;
        if xlora_config.is_some() && lm_head.is_lora() {
            // This is why the later `lm_head.lora_forward(.., None, 1.0, None)` calls can
            // safely pass dummy scaling values: an adapter on `lm_head` is rejected here.
            candle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA.");
        }
        Ok(Self {
            embed_tokens,
            layers,
            norm,
            lm_head,
            sliding_window: cfg.sliding_window,
            device: normal_loading_metadata.real_device,
            dtype: vb.dtype(),
            cache: EitherCache::Full(Cache::new(cfg.num_hidden_layers, true)),
            max_seq_len: cfg.max_position_embeddings,
            xlora_classifier: xlora_config.map(|xlora_config| {
                XLoraClassifier::new(xlora_config, count, lora_config.len(), vb, false).unwrap()
            }),
            mapper,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_kv_heads: cfg.num_key_value_heads,
                num_attn_heads: cfg.num_attention_heads,
                sliding_window: cfg.sliding_window,
                k_head_dim: cfg.head_dim(),
                v_head_dim: cfg.head_dim(),
            },
        })
    }

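    /// Embeds the input ids, applies every decoder layer (placing activations via the
    /// device mapper), and applies the final RMSNorm. When `is_full_pass` is true the
    /// dedicated X-LoRA KV cache is used (reset first if `no_kv_cache`); otherwise the
    /// regular KV cache is used.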
    #[allow(clippy::too_many_arguments)]
    fn inner_forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Option<Tensor>,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut cache = if is_full_pass {
            if no_kv_cache {
                let mut new_cache = Vec::new();
                for _ in 0..self.cache.full().xlora_lock().len() {
                    new_cache.push(None);
                }

                self.cache.full().xlora_lock().clone_from(&new_cache);
            }
            self.cache.full().xlora_lock()
        } else {
            self.cache.full().lock()
        };
        let mut xs = self.embed_tokens.forward(input_ids)?;
        let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
            input_ids,
            &*cache,
            self.sliding_window,
            xs.dtype(),
            self.cfg.num_attn_heads,
        )?;
        for (i, layer) in self.layers.iter().enumerate() {
            xs = self.mapper.map(xs, i)?;
            xs = layer.forward(
                &xs,
                attention_mask
                    .as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                &mut cache[i],
                scalings.clone(),
                self.xlora_classifier
                    .as_ref()
                    .map(|classifier| classifier.get_global_scaling_weight())
                    .unwrap_or(1.0),
                is_scaling_pass,
                flash_params,
            )?
        }
        let xs = xs.to_device(&self.device)?;
        xs.apply(&self.norm)
    }

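    /// Full X-LoRA forward. With a classifier present, `get_scalings` first runs the
    /// scaling pass to obtain per-token adapter scalings, then the model is run again with
    /// those scalings and the result is projected through `lm_head` (dummy scaling
    /// arguments are fine there, see the check in `new`). Without a classifier this
    /// reduces to a single plain forward pass.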
    #[allow(clippy::too_many_arguments)]
    pub fn forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        if self.xlora_classifier.is_some() {
            let scalings = self.get_scalings(
                input_ids,
                input_ids_full,
                seqlen_offsets,
                seqlen_offsets_full,
                no_kv_cache,
                non_granular_state,
                &vec![usize::MAX; context_lens.len()],
                flash_params,
                flash_params_full,
            )?;

            if no_kv_cache {
                let mut res = self
                    .inner_forward(
                        input_ids_full,
                        seqlen_offsets_full,
                        Some(scalings),
                        true,
                        no_kv_cache,
                        None,
                        flash_params_full,
                    )?
                    .contiguous()?;
                if let Some(t) = self.lm_head.quantized_act_type() {
                    res = res.to_dtype(t)?;
                }
                extract_logits(
                    &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                    context_lens,
                )
            } else {
                // is_full_pass=true is ok because no_kv_cache=false
                let mut res = self
                    .inner_forward(
                        input_ids,
                        seqlen_offsets,
                        Some(scalings),
                        true,
                        no_kv_cache,
                        None,
                        flash_params,
                    )?
                    .contiguous()?;
                if let Some(t) = self.lm_head.quantized_act_type() {
                    res = res.to_dtype(t)?;
                }
                extract_logits(
                    &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                    context_lens,
                )
            }
        } else {
            let mut res = self
                .inner_forward(
                    input_ids,
                    seqlen_offsets,
                    None,
                    false,
                    no_kv_cache,
                    None,
                    flash_params,
                )?
                .contiguous()?;
            if let Some(t) = self.lm_head.quantized_act_type() {
                res = res.to_dtype(t)?;
            }
            extract_logits(
                &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                context_lens,
            )
        }
    }
}

impl IsqModel for XLoraModel {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((Arc::get_mut(&mut self.lm_head).unwrap().quant_inner(), None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.q_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.k_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.v_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.o_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.mlp.gate_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.mlp.up_proj).unwrap().quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.mlp.down_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
        }
        (tensors, &*self.mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        panic!("Cannot generate UQFF for an adapter model.")
    }
}

impl NormalModel for XLoraModel {
    fn forward(
        &self,
        _input_ids: &Tensor,
        _seqlen_offsets: &[usize],
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        _flash_params: &FlashParams,
    ) -> Result<Tensor> {
        unreachable!()
    }
    fn xlora_forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<crate::xlora_models::NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        self.forward(
            input_ids,
            input_ids_full,
            seqlen_offsets,
            seqlen_offsets_full,
            no_kv_cache,
            non_granular_state,
            context_lens,
            flash_params,
            flash_params_full,
        )
    }
    fn cache(&self) -> &EitherCache {
        &self.cache
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        &mut self.cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn is_xlora(&self) -> bool {
        true
    }
    fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }
    fn activate_adapters(&mut self, adapter_names: Vec<String>) -> Result<usize> {
        if self.xlora_classifier.is_some() {
            candle_core::bail!("Adapter activation is not supported for X-LoRA models as the adapter set must remain the same.");
        }
        let mut sum = 0;
        for layer in self.layers.iter_mut() {
            sum += Arc::get_mut(&mut layer.self_attn.k_proj)
                .unwrap()
                .activate(&adapter_names)?;
            sum += Arc::get_mut(&mut layer.self_attn.o_proj)
                .unwrap()
                .activate(&adapter_names)?;
            sum += Arc::get_mut(&mut layer.self_attn.q_proj)
                .unwrap()
                .activate(&adapter_names)?;
            sum += Arc::get_mut(&mut layer.self_attn.v_proj)
                .unwrap()
                .activate(&adapter_names)?;

            sum += Arc::get_mut(&mut layer.mlp.down_proj)
                .unwrap()
                .activate(&adapter_names)?;
            sum += Arc::get_mut(&mut layer.mlp.gate_proj)
                .unwrap()
                .activate(&adapter_names)?;
            sum += Arc::get_mut(&mut layer.mlp.up_proj)
                .unwrap()
                .activate(&adapter_names)?;
        }
        Ok(sum)
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
}

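/// Lets the X-LoRA scaling machinery (`get_scalings`) drive this model: the trait's
/// `forward` simply delegates to `inner_forward` with the provided scalings.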
impl ScalingsMaker for XLoraModel {
    fn dtype(&self) -> DType {
        self.dtype
    }
    fn get_cache(&self) -> &EitherCache {
        &self.cache
    }
    fn get_classifier(&self) -> &XLoraClassifier {
        self.xlora_classifier.as_ref().unwrap()
    }
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Tensor,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        _context_lens: &[usize],
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.inner_forward(
            input_ids,
            seqlen_offsets,
            Some(scalings),
            is_full_pass,
            no_kv_cache,
            is_scaling_pass,
            flash_params,
        )
    }
}

impl AnyMoeBaseModelMixin for XLoraModel {}