mistralrs_core/xlora_models/mistral.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use crate::{
    amoe::AnyMoeBaseModelMixin,
    attention::SdpaParams,
    layers::{self, RotaryEmbedding, Sdpa},
    lora::{linear_no_bias, LinearLayerLike, LoraConfig, Ordering},
    paged_attention::ModelConfigMetadata,
    pipeline::{
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel, NormalLoadingMetadata,
    },
    utils::progress::NiceProgressBar,
};
/// Mistral LLM, https://github.com/mistralai/mistral-src
use candle_core::{DType, Device, Module, Result, Tensor};
use mistralrs_quant::{QuantMethod, ShardedVarBuilder};
use std::{collections::HashMap, sync::Arc};
use tqdm::Iter;
use tracing::info;

use crate::{
    device_map::DeviceMapper,
    layers::{Activation, CausalMasker, RmsNorm},
    models::mistral::Config,
    pipeline::{extract_logits, Cache, NormalModel},
};

use super::{classifier::XLoraClassifier, config::XLoraConfig, NonGranularState, ScalingsMaker};

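/// Gated feed-forward block used in every decoder layer; computes
/// `down_proj(act_fn(gate_proj(x)) * up_proj(x))` with each projection wrapped
/// as a LoRA-aware linear layer.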
#[derive(Clone)]
#[allow(clippy::upper_case_acronyms)]
struct MLP {
    gate_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    up_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    down_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    act_fn: Activation,
}

impl MLP {
    #[allow(clippy::too_many_arguments)]
    fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let intermediate_sz = cfg.intermediate_size;
        let gate_proj = linear_no_bias(
            hidden_sz,
            intermediate_sz,
            mapper.set_device(layer_idx, vb.pp("gate_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("gate_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let up_proj = linear_no_bias(
            hidden_sz,
            intermediate_sz,
            mapper.set_device(layer_idx, vb.pp("up_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("up_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let down_proj = linear_no_bias(
            intermediate_sz,
            hidden_sz,
            mapper.set_device(layer_idx, vb.pp("down_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("down_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        Ok(Self {
            gate_proj,
            up_proj,
            down_proj,
            act_fn: cfg.hidden_act,
        })
    }

    fn forward(
        &self,
        xs: &Tensor,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
    ) -> Result<Tensor> {
        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.gate_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let lhs = self
            .gate_proj
            .lora_forward(
                &xs,
                scalings.clone(),
                global_scaling_weight,
                is_scaling_pass,
            )?
            .apply(&self.act_fn)?;
        let rhs = self.up_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let mut res = self.down_proj.lora_forward(
            &(lhs * rhs)?,
            scalings,
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.gate_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

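/// Multi-head self-attention with rotary embeddings, grouped-query KV heads,
/// an optional sliding-window KV cache, and LoRA-aware q/k/v/o projections.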
struct Attention {
    q_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    k_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    v_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    o_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    rotary_emb: Arc<RotaryEmbedding>,
    sliding_window: Option<usize>,
    sdpa_params: SdpaParams,
}

impl Attention {
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<RotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        let head_dim = cfg.head_dim();
        let q_proj = linear_no_bias(
            hidden_sz,
            num_heads * head_dim,
            mapper.set_device(layer_idx, vb.pp("q_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("q_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let k_proj = linear_no_bias(
            hidden_sz,
            num_kv_heads * head_dim,
            mapper.set_device(layer_idx, vb.pp("k_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("k_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let v_proj = linear_no_bias(
            hidden_sz,
            num_kv_heads * head_dim,
            mapper.set_device(layer_idx, vb.pp("v_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("v_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let o_proj = linear_no_bias(
            num_heads * head_dim,
            hidden_sz,
            mapper.set_device(layer_idx, vb.pp("o_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("o_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            num_heads,
            num_kv_heads,
            head_dim,
            rotary_emb,
            sliding_window: cfg.sliding_window,
            sdpa_params: SdpaParams {
                n_kv_groups: num_heads / num_kv_heads,
                use_flash_attn: cfg.use_flash_attn,
                softcap: None,
                softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                sliding_window: cfg.sliding_window,
            },
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut Option<(Tensor, Tensor)>,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let original_dtype = xs.dtype();
        let mut xs = xs.clone();
        if let Some(t) = self.q_proj.quantized_act_type() {
            xs = xs.to_dtype(t)?;
        }
        let mut q = self.q_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let mut k = self.k_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let mut v = self.v_proj.lora_forward(
            &xs,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.q_proj.quantized_act_type().is_some() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

        let (q, k, v) = if q_len != 1 {
            let q = q
                .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.num_heads, q_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.num_kv_heads, q_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, k) = self.rotary_emb.forward(&q, &k, seqlen_offsets)?;

        let (k, v, attn_mask) = Cache::update_kv_cache_sliding_window(
            kv_cache,
            k,
            v,
            attention_mask,
            self.sliding_window,
            false,
        )?;

        let mut attn_output = Sdpa.run_attention(
            &q,
            &k,
            &v,
            attn_mask.as_ref(),
            Some(flash_params),
            &self.sdpa_params,
        )?;

        if let Some(t) = self.q_proj.quantized_act_type() {
            attn_output = attn_output.to_dtype(t)?;
        }
        let mut res = self.o_proj.lora_forward(
            &attn_output.transpose(1, 2)?.reshape((b_sz, q_len, ()))?,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.q_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }
}

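/// Pre-norm decoder layer: RMSNorm -> self-attention -> residual add, then
/// RMSNorm -> MLP -> residual add.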
struct DecoderLayer {
    self_attn: Attention,
    mlp: MLP,
    input_layernorm: RmsNorm,
    post_attention_layernorm: RmsNorm,
}

impl DecoderLayer {
    #[allow(clippy::too_many_arguments)]
    fn new(
        rotary_emb: Arc<RotaryEmbedding>,
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let self_attn = Attention::new(
            rotary_emb,
            cfg,
            vb.pp("self_attn"),
            lora_config,
            count,
            ord,
            mapper,
            layer_idx,
            loading_isq,
            preload_adapters,
        )?;
        let mlp = MLP::new(
            cfg,
            vb.pp("mlp"),
            lora_config,
            count,
            ord,
            mapper,
            layer_idx,
            loading_isq,
            preload_adapters,
        )?;
        let input_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let post_attention_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        Ok(Self {
            self_attn,
            mlp,
            input_layernorm,
            post_attention_layernorm,
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offsets: &[usize],
        kv_cache: &mut Option<(Tensor, Tensor)>,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = self.input_layernorm.forward(xs)?;
        let xs = self.self_attn.forward(
            &xs,
            attention_mask,
            seqlen_offsets,
            kv_cache,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
            flash_params,
        )?;
        let xs = (xs + residual)?;
        let residual = &xs;
        let xs = self.mlp.forward(
            &xs.apply(&self.post_attention_layernorm)?,
            scalings,
            global_scaling_weight,
            is_scaling_pass,
        )?;
        residual + xs
    }
}

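/// Mistral decoder stack with X-LoRA (or plain LoRA) adapters on the attention
/// and MLP projections, plus an optional [`XLoraClassifier`] that produces
/// per-adapter scalings.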
pub struct XLoraModel {
    embed_tokens: candle_nn::Embedding,
    layers: Vec<DecoderLayer>,
    norm: RmsNorm,
    lm_head: Arc<dyn LinearLayerLike + Send + Sync>,
    sliding_window: Option<usize>,
    dtype: DType,
    device: Device,
    cache: EitherCache,
    max_seq_len: usize,
    xlora_classifier: Option<XLoraClassifier>,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    cfg: ModelConfigMetadata,
}

impl XLoraModel {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        if let Some(ref quant_cfg) = &cfg.quantization_config {
            tracing::info!(
                "Using {} quantization: {}.",
                quant_cfg.name(),
                quant_cfg.get_bits_name(&vb)
            );
        }
        let mapper = normal_loading_metadata.mapper;

        let vb_m = vb.pp("model");
        let embed_tokens = layers::embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb_m.pp("embed_tokens"), false),
            &cfg.quantization_config,
        )?;
        let head_dim = cfg.head_dim();
        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        let vb_l = vb_m.pp("layers");
        let mut ropes = HashMap::new();
        for layer_idx in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(RotaryEmbedding::new(
                    cfg.rope_theta as f32,
                    head_dim,
                    cfg.max_position_embeddings,
                    device,
                    is_gptx,
                    vb_m.dtype(),
                )?),
            );
        }
        let mut count = 0;
        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        ) {
            let device = mapper
                .device_for(layer_idx, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            let layer = DecoderLayer::new(
                rotary_emb.clone(),
                cfg,
                vb_l.pp(layer_idx),
                lora_config,
                &mut count,
                &xlora_ordering,
                &*mapper,
                layer_idx,
                normal_loading_metadata.loading_isq,
                preload_adapters,
            )?;
            layers.push(layer)
        }
        if xlora_config.is_none() && preload_adapters.is_none() {
            // No X-LoRA classifier and no preloaded adapters: this is a plain LoRA
            // model, so merge the adapter weights into the base layers.
            info!("Merging LoRA adapters.");
            for layer in layers.iter_mut().tqdm() {
                Arc::get_mut(&mut layer.self_attn.k_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.self_attn.o_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.self_attn.q_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.self_attn.v_proj)
                    .unwrap()
                    .merge_weights()?;

                Arc::get_mut(&mut layer.mlp.down_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.mlp.gate_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.mlp.up_proj)
                    .unwrap()
                    .merge_weights()?;
            }
        }
        let norm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb_m.pp("norm"), false),
        )?;
        let lm_head = linear_no_bias(
            cfg.hidden_size,
            cfg.vocab_size,
            mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
            mapper.set_nm_device(vb.pp("lm_head"), false),
            lora_config,
            &mut count,
            &xlora_ordering,
            preload_adapters,
        )?;
        if xlora_config.is_some() && lm_head.is_lora() {
            // X-LoRA does not support an adapter on `lm_head`; this is why the forward passes
            // below can call `lm_head.lora_forward` with dummy values (..., None, 1.0, None).
            candle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA.");
        }
        Ok(Self {
            embed_tokens,
            layers,
            norm,
            lm_head,
            sliding_window: cfg.sliding_window,
            device: normal_loading_metadata.real_device,
            dtype: vb.dtype(),
            cache: EitherCache::Full(Cache::new(cfg.num_hidden_layers, true)),
            max_seq_len: cfg.max_position_embeddings,
            xlora_classifier: xlora_config.map(|xlora_config| {
                XLoraClassifier::new(xlora_config, count, lora_config.len(), vb, false).unwrap()
            }),
            mapper,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_kv_heads: cfg.num_key_value_heads,
                num_attn_heads: cfg.num_attention_heads,
                sliding_window: cfg.sliding_window,
                k_head_dim: cfg.head_dim(),
                v_head_dim: cfg.head_dim(),
            },
        })
    }

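    /// Runs the decoder stack. When `is_full_pass` is set, the dedicated X-LoRA KV cache is
    /// used (cleared first if `no_kv_cache`); otherwise the regular cache is used.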
    #[allow(clippy::too_many_arguments)]
    fn inner_forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Option<Tensor>,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut cache = if is_full_pass {
            if no_kv_cache {
                let mut new_cache = Vec::new();
                for _ in 0..self.cache.full().xlora_lock().len() {
                    new_cache.push(None);
                }

                self.cache.full().xlora_lock().clone_from(&new_cache);
            }
            self.cache.full().xlora_lock()
        } else {
            self.cache.full().lock()
        };
        let mut xs = self.embed_tokens.forward(input_ids)?;
        let attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
            input_ids,
            &*cache,
            self.sliding_window,
            xs.dtype(),
            self.cfg.num_attn_heads,
        )?;
        for (i, layer) in self.layers.iter().enumerate() {
            xs = self.mapper.map(xs, i)?;
            xs = layer.forward(
                &xs,
                attention_mask
                    .as_ref()
                    .map(|m| m.to_device(xs.device()).unwrap())
                    .as_ref(),
                seqlen_offsets,
                &mut cache[i],
                scalings.clone(),
                self.xlora_classifier
                    .as_ref()
                    .map(|classifier| classifier.get_global_scaling_weight())
                    .unwrap_or(1.0),
                is_scaling_pass,
                flash_params,
            )?
        }
        let xs = xs.to_device(&self.device)?;
        xs.apply(&self.norm)
    }

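    /// X-LoRA two-pass forward: compute adapter scalings via `get_scalings` (the scaling
    /// pass), then run the model again with those scalings. Without a classifier, a single
    /// plain-LoRA pass is performed.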
    #[allow(clippy::too_many_arguments)]
    pub fn forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        if self.xlora_classifier.is_some() {
            let scalings = self.get_scalings(
                input_ids,
                input_ids_full,
                seqlen_offsets,
                seqlen_offsets_full,
                no_kv_cache,
                non_granular_state,
                &vec![usize::MAX; context_lens.len()],
                flash_params,
                flash_params_full,
            )?;

            if no_kv_cache {
                let mut res = self
                    .inner_forward(
                        input_ids_full,
                        seqlen_offsets_full,
                        Some(scalings),
                        true,
                        no_kv_cache,
                        None,
                        flash_params_full,
                    )?
                    .contiguous()?;
                if let Some(t) = self.lm_head.quantized_act_type() {
                    res = res.to_dtype(t)?;
                }
                extract_logits(
                    &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                    context_lens,
                )
            } else {
                // is_full_pass=true is ok because no_kv_cache=false
                let mut res = self
                    .inner_forward(
                        input_ids,
                        seqlen_offsets,
                        Some(scalings),
                        true,
                        no_kv_cache,
                        None,
                        flash_params,
                    )?
                    .contiguous()?;
                if let Some(t) = self.lm_head.quantized_act_type() {
                    res = res.to_dtype(t)?;
                }
                extract_logits(
                    &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                    context_lens,
                )
            }
        } else {
            let mut res = self
                .inner_forward(
                    input_ids,
                    seqlen_offsets,
                    None,
                    false,
                    no_kv_cache,
                    None,
                    flash_params,
                )?
                .contiguous()?;
            if let Some(t) = self.lm_head.quantized_act_type() {
                res = res.to_dtype(t)?;
            }
            extract_logits(
                &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                context_lens,
            )
        }
    }
}

impl IsqModel for XLoraModel {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((Arc::get_mut(&mut self.lm_head).unwrap().quant_inner(), None));
        for (i, layer) in self.layers.iter_mut().enumerate() {
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.q_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.k_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.v_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.self_attn.o_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.mlp.gate_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.mlp.up_proj).unwrap().quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.mlp.down_proj)
                    .unwrap()
                    .quant_inner(),
                Some(i),
            ));
        }
        (tensors, &*self.mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        panic!("Cannot generate UQFF for an adapter model.")
    }
}

impl NormalModel for XLoraModel {
    fn forward(
        &self,
        _input_ids: &Tensor,
        _seqlen_offsets: &[usize],
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        _flash_params: &FlashParams,
    ) -> Result<Tensor> {
        unreachable!()
    }
    fn xlora_forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<crate::xlora_models::NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        self.forward(
            input_ids,
            input_ids_full,
            seqlen_offsets,
            seqlen_offsets_full,
            no_kv_cache,
            non_granular_state,
            context_lens,
            flash_params,
            flash_params_full,
        )
    }
    fn cache(&self) -> &EitherCache {
        &self.cache
    }
    fn cache_mut(&mut self) -> &mut EitherCache {
        &mut self.cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn is_xlora(&self) -> bool {
        true
    }
    fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
}

impl ScalingsMaker for XLoraModel {
    fn dtype(&self) -> DType {
        self.dtype
    }
    fn get_cache(&self) -> &EitherCache {
        &self.cache
    }
    fn get_classifier(&self) -> &XLoraClassifier {
        self.xlora_classifier.as_ref().unwrap()
    }
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Tensor,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        _context_lens: &[usize],
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.inner_forward(
            input_ids,
            seqlen_offsets,
            Some(scalings),
            is_full_pass,
            no_kv_cache,
            is_scaling_pass,
            flash_params,
        )
    }
}

impl AnyMoeBaseModelMixin for XLoraModel {}