mistralrs_core/xlora_models/llama.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use crate::{
    amoe::AnyMoeBaseModelMixin,
    attention::SdpaParams,
    layers::{Llama3RotaryEmbedding, Sdpa},
    lora::{linear_no_bias as linear, LinearLayerLike, LoraConfig, Ordering},
    paged_attention::ModelConfigMetadata,
    pipeline::{
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel,
    },
    utils::progress::NiceProgressBar,
};
use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{Embedding, Module};
use mistralrs_quant::{QuantMethod, ShardedVarBuilder};
use std::{collections::HashMap, sync::Arc};
use tqdm::Iter;
use tracing::info;

use crate::{
    device_map::DeviceMapper,
    layers::{embedding, CausalMasker, RmsNorm},
    models::llama::Config,
    pipeline::{self, extract_logits, LayerCaches, NormalLoadingMetadata, NormalModel},
};

use super::{classifier::XLoraClassifier, NonGranularState, ScalingsMaker, XLoraConfig};

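/// Self-attention for one decoder layer: LoRA-aware q/k/v/o projections,
/// rotary position embeddings, and SDPA over the per-layer KV cache.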
struct CausalSelfAttention {
    q_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    k_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    v_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    o_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    num_attention_heads: usize,
    num_key_value_heads: usize,
    head_dim: usize,
    rotary_emb: Arc<Llama3RotaryEmbedding>,
    max_seq_len: usize,
    sdpa_params: SdpaParams,
}

impl CausalSelfAttention {
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        x: &Tensor,
        mask: &Option<Tensor>,
        seqlen_offsets: &[usize],
        block_idx: usize,
        kv_cache: &mut LayerCaches,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, seq_len, hidden_size) = x.dims3()?;

        let original_dtype = x.dtype();
        let mut x = x.clone();
        if let Some(t) = self.q_proj.quantized_act_type() {
            x = x.to_dtype(t)?;
        }
        let mut q = self.q_proj.lora_forward(
            &x,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let mut k = self.k_proj.lora_forward(
            &x,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let mut v = self.v_proj.lora_forward(
            &x,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.q_proj.quantized_act_type().is_some() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

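        // Reshape the flat projections to (batch, heads, seq_len, head_dim); for the
        // single-token decode step (seq_len == 1) the transpose is unnecessary.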
        let (q, k, v) = if seq_len != 1 {
            let q = q
                .reshape((b_sz, seq_len, self.num_attention_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.num_attention_heads, seq_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.num_key_value_heads, seq_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.num_key_value_heads, seq_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, k) = self.rotary_emb.forward(&q, &k, seqlen_offsets)?;

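        // Append the new keys/values to this layer's KV cache, then attend over the
        // full (cached + new) sequence.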
        let (k, v) =
            crate::pipeline::Cache::update_kv_cache(&mut kv_cache[block_idx], k, v, false)?;

        let y = Sdpa.run_attention(
            &q,
            &k,
            &v,
            mask.clone().as_ref(),
            Some(flash_params),
            &self.sdpa_params,
        )?;

        // `y` is already (b_sz, seq_len, hidden_size) after this transpose/reshape,
        // so it is fed to the output projection directly.
        let mut y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, hidden_size])?;
        if let Some(t) = self.q_proj.quantized_act_type() {
            y = y.to_dtype(t)?;
        }
        let mut res = self.o_proj.lora_forward(
            &y,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.q_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }

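    /// Loads the q/k/v/o projections for one attention layer, attaching LoRA adapters
    /// (including any preloaded adapters) and placing weights via the device mapper.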
    #[allow(clippy::too_many_arguments)]
    fn load(
        vb: ShardedVarBuilder,
        cfg: &Config,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        rope: Arc<Llama3RotaryEmbedding>,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let size_in = cfg.hidden_size;
        let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
        let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
        let q_proj = linear(
            size_in,
            size_q,
            mapper.set_device(layer_idx, vb.pp("q_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("q_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let k_proj = linear(
            size_in,
            size_kv,
            mapper.set_device(layer_idx, vb.pp("k_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("k_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let v_proj = linear(
            size_in,
            size_kv,
            mapper.set_device(layer_idx, vb.pp("v_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("v_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let o_proj = linear(
            size_q,
            size_in,
            mapper.set_device(layer_idx, vb.pp("o_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("o_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            num_attention_heads: cfg.num_attention_heads,
            num_key_value_heads: cfg.num_key_value_heads,
            head_dim: cfg.hidden_size / cfg.num_attention_heads,
            rotary_emb: rope,
            max_seq_len: cfg.max_position_embeddings,
            sdpa_params: SdpaParams {
                n_kv_groups: cfg.num_attention_heads / cfg.num_key_value_heads,
                softcap: None,
                softmax_scale: 1.0 / ((cfg.hidden_size / cfg.num_attention_heads) as f32).sqrt(),
                sliding_window: None,
            },
        })
    }
}

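/// Gated feed-forward network (SwiGLU): `c_fc1`/`c_fc2`/`c_proj` correspond to
/// `gate_proj`/`up_proj`/`down_proj`, each a LoRA-aware linear layer.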
#[derive(Clone)]
struct Mlp {
    c_fc1: Arc<dyn LinearLayerLike + Send + Sync>,
    c_fc2: Arc<dyn LinearLayerLike + Send + Sync>,
    c_proj: Arc<dyn LinearLayerLike + Send + Sync>,
}

impl Mlp {
    fn forward(
        &self,
        x: &Tensor,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
    ) -> Result<Tensor> {
        let original_dtype = x.dtype();
        let mut x = x.clone();
        if let Some(t) = self.c_fc1.quantized_act_type() {
            x = x.to_dtype(t)?;
        }
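        // SwiGLU: silu(gate_proj(x)) * up_proj(x), then project back down.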
        let x = (candle_nn::ops::silu(&self.c_fc1.lora_forward(
            &x,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?)? * self.c_fc2.lora_forward(
            &x,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?)?;
        let mut res = self.c_proj.lora_forward(
            &x,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.c_fc1.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }

    #[allow(clippy::too_many_arguments)]
    fn load(
        vb: ShardedVarBuilder,
        cfg: &Config,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let h_size = cfg.hidden_size;
        let i_size = cfg.intermediate_size;
        let c_fc1 = linear(
            h_size,
            i_size,
            mapper.set_device(layer_idx, vb.pp("gate_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("gate_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let c_fc2 = linear(
            h_size,
            i_size,
            mapper.set_device(layer_idx, vb.pp("up_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("up_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let c_proj = linear(
            i_size,
            h_size,
            mapper.set_device(layer_idx, vb.pp("down_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("down_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        Ok(Self {
            c_fc1,
            c_fc2,
            c_proj,
        })
    }
}

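/// A single transformer decoder layer: pre-norm attention and MLP, each wrapped
/// in a residual connection.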
struct Block {
    rms_1: RmsNorm,
    attn: CausalSelfAttention,
    rms_2: RmsNorm,
    mlp: Mlp,
}

impl Block {
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        x: &Tensor,
        mask: &Option<Tensor>,
        seqlen_offsets: &[usize],
        block_idx: usize,
        kv_cache: &mut LayerCaches,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = x;
        let x = self.rms_1.forward(x)?;
        let x = (self.attn.forward(
            &x,
            mask,
            seqlen_offsets,
            block_idx,
            kv_cache,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
            flash_params,
        )? + residual)?;
        let residual = &x;
        let x = (self.mlp.forward(
            &self.rms_2.forward(&x)?,
            scalings,
            global_scaling_weight,
            is_scaling_pass,
        )? + residual)?;
        Ok(x)
    }

    #[allow(clippy::too_many_arguments)]
    fn load(
        vb: ShardedVarBuilder,
        cfg: &Config,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        rope: Arc<Llama3RotaryEmbedding>,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let attn = CausalSelfAttention::load(
            vb.pp("self_attn"),
            cfg,
            lora_config,
            count,
            ord,
            mapper,
            layer_idx,
            loading_isq,
            rope,
            preload_adapters,
        )?;
        let mlp = Mlp::load(
            vb.pp("mlp"),
            cfg,
            lora_config,
            count,
            ord,
            mapper,
            layer_idx,
            loading_isq,
            preload_adapters,
        )?;
        let rms_1 = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let rms_2 = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        Ok(Self {
            rms_1,
            attn,
            rms_2,
            mlp,
        })
    }
}

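/// X-LoRA Llama model: token embeddings, a stack of decoder layers, the final
/// RMSNorm, and the LM head, plus an optional X-LoRA scaling classifier.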
pub struct XLoraLlama {
    wte: Embedding,
    blocks: Vec<Block>,
    ln_f: RmsNorm,
    lm_head: Arc<dyn LinearLayerLike + Send + Sync>,
    kv_cache: pipeline::EitherCache,
    device: Device,
    xlora_classifier: Option<XLoraClassifier>,
    dtype: DType,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    cfg: ModelConfigMetadata,
}

impl XLoraLlama {
    #[allow(clippy::too_many_arguments)]
    fn inner_forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Option<Tensor>,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut x = self.wte.forward(input_ids)?;
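        // Full passes use the dedicated X-LoRA KV cache (reset first when the KV
        // cache is disabled); all other passes use the regular per-layer cache.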
        let mut cache = if is_full_pass {
            if no_kv_cache {
                let mut new_cache = Vec::new();
                for _ in 0..self.kv_cache.full().xlora_lock().len() {
                    new_cache.push(None);
                }

                self.kv_cache.full().xlora_lock().clone_from(&new_cache);
            }
            self.kv_cache.full().xlora_lock()
        } else {
            self.kv_cache.full().lock()
        };
        let mask = CausalMasker.make_causal_mask_matrix(
            input_ids,
            &*cache,
            x.dtype(),
            self.cfg.num_attn_heads,
        )?;
        for (block_idx, block) in self.blocks.iter().enumerate() {
            x = self.mapper.map(x, block_idx)?;
            x = block.forward(
                &x,
                &mask.clone().map(|m| m.to_device(x.device()).unwrap()),
                seqlen_offsets,
                block_idx,
                &mut cache,
                scalings.clone(),
                self.xlora_classifier
                    .as_ref()
                    .map(|classifier| classifier.get_global_scaling_weight())
                    .unwrap_or(1.0),
                is_scaling_pass,
                flash_params,
            )?;
        }
        let x = x.to_device(&self.device)?;
        self.ln_f.forward(&x)
    }

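    /// X-LoRA forward: when a classifier is present, first compute adapter scalings
    /// via `get_scalings`, then run the weighted forward pass; otherwise run a plain
    /// LoRA forward pass. Logits are extracted according to `context_lens`.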
    #[allow(clippy::too_many_arguments)]
    pub fn forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        if self.xlora_classifier.is_some() {
            let scalings = self.get_scalings(
                input_ids,
                input_ids_full,
                seqlen_offsets,
                seqlen_offsets_full,
                no_kv_cache,
                non_granular_state,
                &vec![usize::MAX; context_lens.len()],
                flash_params,
                flash_params_full,
            )?;

            if no_kv_cache {
                let mut res = self
                    .inner_forward(
                        input_ids_full,
                        seqlen_offsets_full,
                        Some(scalings),
                        true,
                        no_kv_cache,
                        None,
                        flash_params_full,
                    )?
                    .contiguous()?;
                if let Some(t) = self.lm_head.quantized_act_type() {
                    res = res.to_dtype(t)?;
                }
                extract_logits(
                    &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                    context_lens,
                )
            } else {
                // is_full_pass=true is ok because no_kv_cache=false
                let mut res = self
                    .inner_forward(
                        input_ids,
                        seqlen_offsets,
                        Some(scalings),
                        true,
                        no_kv_cache,
                        None,
                        flash_params,
                    )?
                    .contiguous()?;
                if let Some(t) = self.lm_head.quantized_act_type() {
                    res = res.to_dtype(t)?;
                }
                extract_logits(
                    &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                    context_lens,
                )
            }
        } else {
            let mut res = self
                .inner_forward(
                    input_ids,
                    seqlen_offsets,
                    None,
                    false,
                    no_kv_cache,
                    None,
                    flash_params,
                )?
                .contiguous()?;
            if let Some(t) = self.lm_head.quantized_act_type() {
                res = res.to_dtype(t)?;
            }
            extract_logits(
                &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                context_lens,
            )
        }
    }

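    /// Builds the model from a `ShardedVarBuilder`, attaching LoRA adapters to every
    /// linear layer. If neither an X-LoRA config nor preloaded adapters are given,
    /// the adapter weights are merged into the base weights after loading.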
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        if let Some(ref quant_cfg) = &cfg.quantization_config {
            tracing::info!(
                "Using {} quantization: {}.",
                quant_cfg.name(),
                quant_cfg.get_bits_name(&vb)
            );
        }
        let mapper = normal_loading_metadata.mapper;
        let dtype = vb.dtype();
        let mut count = 0;

        let wte = embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb.pp("model.embed_tokens"), false),
            &cfg.quantization_config,
        )?;
        let lm_head = linear(
            cfg.hidden_size,
            cfg.vocab_size,
            mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
            mapper.set_nm_device(vb.pp("lm_head"), false),
            lora_config,
            &mut count,
            &xlora_ordering,
            preload_adapters,
        )?;
        if xlora_config.is_some() && lm_head.is_lora() {
            // This is why `forward` can pass dummy values (..., None, 1.0, None) to `lm_head.lora_forward`.
            candle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA.");
        }
        let ln_f = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb.pp("model.norm"), false),
        )?;
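        // Create one RoPE instance per distinct device so that layers mapped to
        // different devices each use embeddings resident on their own device.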
        let mut ropes = HashMap::new();
        for i in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(i, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(Llama3RotaryEmbedding::new_llama3(
                    vb.dtype(),
                    cfg,
                    device,
                    is_gptx,
                )?),
            );
        }
        let mut blocks: Vec<_> = NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        )
        .into_iter()
        .map(|i| {
            let device = mapper
                .device_for(i, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            Block::load(
                vb.pp(format!("model.layers.{i}")),
                cfg,
                lora_config,
                &mut count,
                &xlora_ordering,
                &*mapper,
                i,
                normal_loading_metadata.loading_isq,
                rotary_emb,
                preload_adapters,
            )
            .expect("Failed to load block.")
        })
        .collect();
        if xlora_config.is_none() && preload_adapters.is_none() {
            // We are now a LoRA model so we must merge the weights
            info!("Merging LoRA adapters.");
            for layer in blocks.iter_mut().tqdm() {
                Arc::get_mut(&mut layer.attn.k_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.attn.o_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.attn.q_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.attn.v_proj)
                    .unwrap()
                    .merge_weights()?;

                Arc::get_mut(&mut layer.mlp.c_fc1)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.mlp.c_fc2)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.mlp.c_proj)
                    .unwrap()
                    .merge_weights()?;
            }
        }

        Ok(Self {
            wte,
            blocks,
            ln_f,
            lm_head,
            kv_cache: EitherCache::Full(pipeline::Cache::new(cfg.num_hidden_layers, true)),
            device: normal_loading_metadata.real_device,
            xlora_classifier: xlora_config.map(|xlora_config| {
                XLoraClassifier::new(xlora_config, count, lora_config.len(), vb, false).unwrap()
            }),
            dtype,
            mapper,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_kv_heads: cfg.num_key_value_heads,
                num_attn_heads: cfg.num_attention_heads,
                sliding_window: None,
                k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
                v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            },
        })
    }
}

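// In-situ quantization (ISQ) support: expose the LM head and every per-layer
// attention/MLP projection so their inner weights can be quantized in place.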
impl IsqModel for XLoraLlama {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((Arc::get_mut(&mut self.lm_head).unwrap().quant_inner(), None));
        for (i, layer) in self.blocks.iter_mut().enumerate() {
            tensors.push((
                Arc::get_mut(&mut layer.attn.q_proj).unwrap().quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.attn.k_proj).unwrap().quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.attn.v_proj).unwrap().quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.attn.o_proj).unwrap().quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.mlp.c_fc1).unwrap().quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.mlp.c_fc2).unwrap().quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.mlp.c_proj).unwrap().quant_inner(),
                Some(i),
            ));
        }
        (tensors, &*self.mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        panic!("Cannot generate UQFF for an adapter model.")
    }
}

impl NormalModel for XLoraLlama {
    fn forward(
        &self,
        _input_ids: &Tensor,
        _seqlen_offsets: &[usize],
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        _flash_params: &FlashParams,
    ) -> Result<Tensor> {
        unreachable!()
    }
    fn xlora_forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<crate::xlora_models::NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        self.forward(
            input_ids,
            input_ids_full,
            seqlen_offsets,
            seqlen_offsets_full,
            no_kv_cache,
            non_granular_state,
            context_lens,
            flash_params,
            flash_params_full,
        )
    }
    fn cache(&self) -> &super::EitherCache {
        &self.kv_cache
    }
    fn cache_mut(&mut self) -> &mut super::EitherCache {
        &mut self.kv_cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn is_xlora(&self) -> bool {
        true
    }
    fn max_seq_len(&self) -> usize {
        self.blocks[0].attn.max_seq_len
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
}

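// Hooks used by the X-LoRA scaling pass: the classifier obtains hidden states by
// calling back into `inner_forward` through this trait.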
impl ScalingsMaker for XLoraLlama {
    fn dtype(&self) -> DType {
        self.dtype
    }
    fn get_cache(&self) -> &pipeline::EitherCache {
        &self.kv_cache
    }
    fn get_classifier(&self) -> &XLoraClassifier {
        self.xlora_classifier.as_ref().unwrap()
    }
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Tensor,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        _context_lens: &[usize],
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.inner_forward(
            input_ids,
            seqlen_offsets,
            Some(scalings),
            is_full_pass,
            no_kv_cache,
            is_scaling_pass,
            flash_params,
        )
    }
}

impl AnyMoeBaseModelMixin for XLoraLlama {}