mistralrs_core/xlora_models/llama.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use crate::{
    amoe::AnyMoeBaseModelMixin,
    attention::SdpaParams,
    layers::{Llama3RotaryEmbedding, Sdpa},
    lora::{linear_no_bias as linear, LinearLayerLike, LoraConfig, Ordering},
    paged_attention::ModelConfigMetadata,
    pipeline::{
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel,
    },
    utils::progress::NiceProgressBar,
};
use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{Embedding, Module};
use mistralrs_quant::{QuantMethod, ShardedVarBuilder};
use std::{collections::HashMap, sync::Arc};
use tqdm::Iter;
use tracing::info;

use crate::{
    device_map::DeviceMapper,
    layers::{embedding, CausalMasker, RmsNorm},
    models::llama::Config,
    pipeline::{self, extract_logits, LayerCaches, NormalLoadingMetadata, NormalModel},
};

use super::{classifier::XLoraClassifier, NonGranularState, ScalingsMaker, XLoraConfig};

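/// Causal self-attention with LoRA-aware Q/K/V/O projections. Each projection is a
/// `LinearLayerLike`, so X-LoRA scalings and the global scaling weight flow through
/// `lora_forward`.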
struct CausalSelfAttention {
    q_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    k_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    v_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    o_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    num_attention_heads: usize,
    num_key_value_heads: usize,
    head_dim: usize,
    rotary_emb: Arc<Llama3RotaryEmbedding>,
    max_seq_len: usize,
    sdpa_params: SdpaParams,
}

impl CausalSelfAttention {
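    /// Projects Q/K/V (optionally casting to the quantized activation dtype), applies RoPE,
    /// updates this layer's KV cache, runs SDPA, and applies the output projection.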
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        x: &Tensor,
        mask: &Option<Tensor>,
        seqlen_offsets: &[usize],
        block_idx: usize,
        kv_cache: &mut LayerCaches,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, seq_len, hidden_size) = x.dims3()?;

        let original_dtype = x.dtype();
        let mut x = x.clone();
        if let Some(t) = self.q_proj.quantized_act_type() {
            x = x.to_dtype(t)?;
        }
        let mut q = self.q_proj.lora_forward(
            &x,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let mut k = self.k_proj.lora_forward(
            &x,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let mut v = self.v_proj.lora_forward(
            &x,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.q_proj.quantized_act_type().is_some() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

        let (q, k, v) = if seq_len != 1 {
            let q = q
                .reshape((b_sz, seq_len, self.num_attention_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.num_attention_heads, seq_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.num_key_value_heads, seq_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.num_key_value_heads, seq_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, k) = self.rotary_emb.forward(&q, &k, seqlen_offsets)?;

        let (k, v) = crate::pipeline::Cache::update_kv_cache(&mut kv_cache[block_idx], k, v)?;

        let y = Sdpa.run_attention(
            &q,
            &k,
            &v,
            mask.clone().as_ref(),
            Some(flash_params),
            &self.sdpa_params,
        )?;

        let mut y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, hidden_size])?;
        if let Some(t) = self.q_proj.quantized_act_type() {
            y = y.to_dtype(t)?;
        }
        // `y` already has shape (b_sz, seq_len, hidden_size), so feed it directly to the
        // output projection; transposing and reshaping it again would scramble the layout.
        let mut res = self.o_proj.lora_forward(
            &y,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.q_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }

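    /// Loads the q/k/v/o projections for layer `layer_idx` from `vb`, attaching any LoRA
    /// adapters described by `lora_config` and sharing the per-device RoPE cache in `rope`.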
    #[allow(clippy::too_many_arguments)]
    fn load(
        vb: ShardedVarBuilder,
        cfg: &Config,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        rope: Arc<Llama3RotaryEmbedding>,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let size_in = cfg.hidden_size;
        let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
        let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
        let q_proj = linear(
            size_in,
            size_q,
            mapper.set_device(layer_idx, vb.pp("q_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("q_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let k_proj = linear(
            size_in,
            size_kv,
            mapper.set_device(layer_idx, vb.pp("k_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("k_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let v_proj = linear(
            size_in,
            size_kv,
            mapper.set_device(layer_idx, vb.pp("v_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("v_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let o_proj = linear(
            size_q,
            size_in,
            mapper.set_device(layer_idx, vb.pp("o_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("o_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            num_attention_heads: cfg.num_attention_heads,
            num_key_value_heads: cfg.num_key_value_heads,
            head_dim: cfg.hidden_size / cfg.num_attention_heads,
            rotary_emb: rope,
            max_seq_len: cfg.max_position_embeddings,
            sdpa_params: SdpaParams {
                n_kv_groups: cfg.num_attention_heads / cfg.num_key_value_heads,
                softcap: None,
                softmax_scale: 1.0 / ((cfg.hidden_size / cfg.num_attention_heads) as f32).sqrt(),
                sliding_window: None,
            },
        })
    }
}

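/// SwiGLU feed-forward block: `silu(gate_proj(x)) * up_proj(x)` followed by `down_proj`,
/// with all three projections LoRA-aware.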
#[derive(Clone)]
struct Mlp {
    c_fc1: Arc<dyn LinearLayerLike + Send + Sync>,
    c_fc2: Arc<dyn LinearLayerLike + Send + Sync>,
    c_proj: Arc<dyn LinearLayerLike + Send + Sync>,
}

impl Mlp {
    fn forward(
        &self,
        x: &Tensor,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
    ) -> Result<Tensor> {
        let original_dtype = x.dtype();
        let mut x = x.clone();
        if let Some(t) = self.c_fc1.quantized_act_type() {
            x = x.to_dtype(t)?;
        }
        let x = (candle_nn::ops::silu(&self.c_fc1.lora_forward(
            &x,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?)? * self.c_fc2.lora_forward(
            &x,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?)?;
        let mut res = self.c_proj.lora_forward(
            &x,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.c_fc1.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }

    #[allow(clippy::too_many_arguments)]
    fn load(
        vb: ShardedVarBuilder,
        cfg: &Config,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let h_size = cfg.hidden_size;
        let i_size = cfg.intermediate_size;
        let c_fc1 = linear(
            h_size,
            i_size,
            mapper.set_device(layer_idx, vb.pp("gate_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("gate_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let c_fc2 = linear(
            h_size,
            i_size,
            mapper.set_device(layer_idx, vb.pp("up_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("up_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let c_proj = linear(
            i_size,
            h_size,
            mapper.set_device(layer_idx, vb.pp("down_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("down_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        Ok(Self {
            c_fc1,
            c_fc2,
            c_proj,
        })
    }
}

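/// A single decoder layer: pre-norm attention and pre-norm MLP, each wrapped in a
/// residual connection.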
struct Block {
    rms_1: RmsNorm,
    attn: CausalSelfAttention,
    rms_2: RmsNorm,
    mlp: Mlp,
}

impl Block {
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        x: &Tensor,
        mask: &Option<Tensor>,
        seqlen_offsets: &[usize],
        block_idx: usize,
        kv_cache: &mut LayerCaches,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = x;
        let x = self.rms_1.forward(x)?;
        let x = (self.attn.forward(
            &x,
            mask,
            seqlen_offsets,
            block_idx,
            kv_cache,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
            flash_params,
        )? + residual)?;
        let residual = &x;
        let x = (self.mlp.forward(
            &self.rms_2.forward(&x)?,
            scalings,
            global_scaling_weight,
            is_scaling_pass,
        )? + residual)?;
        Ok(x)
    }

    #[allow(clippy::too_many_arguments)]
    fn load(
        vb: ShardedVarBuilder,
        cfg: &Config,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        rope: Arc<Llama3RotaryEmbedding>,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let attn = CausalSelfAttention::load(
            vb.pp("self_attn"),
            cfg,
            lora_config,
            count,
            ord,
            mapper,
            layer_idx,
            loading_isq,
            rope,
            preload_adapters,
        )?;
        let mlp = Mlp::load(
            vb.pp("mlp"),
            cfg,
            lora_config,
            count,
            ord,
            mapper,
            layer_idx,
            loading_isq,
            preload_adapters,
        )?;
        let rms_1 = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let rms_2 = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        Ok(Self {
            rms_1,
            attn,
            rms_2,
            mlp,
        })
    }
}

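/// Llama with X-LoRA support: token embeddings, the decoder stack, the final norm and
/// LM head, plus an optional `XLoraClassifier` that produces per-token adapter scalings.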
pub struct XLoraLlama {
    wte: Embedding,
    blocks: Vec<Block>,
    ln_f: RmsNorm,
    lm_head: Arc<dyn LinearLayerLike + Send + Sync>,
    kv_cache: pipeline::EitherCache,
    device: Device,
    xlora_classifier: Option<XLoraClassifier>,
    dtype: DType,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    cfg: ModelConfigMetadata,
}

impl XLoraLlama {
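    /// Runs the decoder stack with the given adapter scalings. When `is_full_pass` is set,
    /// the X-LoRA cache is used (and cleared first if `no_kv_cache`); otherwise the regular
    /// cache is used.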
    #[allow(clippy::too_many_arguments)]
    fn inner_forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Option<Tensor>,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut x = self.wte.forward(input_ids)?;
        let mut cache = if is_full_pass {
            if no_kv_cache {
                let mut new_cache = Vec::new();
                for _ in 0..self.kv_cache.full().xlora_lock().len() {
                    new_cache.push(None);
                }

                self.kv_cache.full().xlora_lock().clone_from(&new_cache);
            }
            self.kv_cache.full().xlora_lock()
        } else {
            self.kv_cache.full().lock()
        };
        let mask = CausalMasker.make_causal_mask_matrix(
            input_ids,
            &*cache,
            x.dtype(),
            self.cfg.num_attn_heads,
        )?;
        for (block_idx, block) in self.blocks.iter().enumerate() {
            x = self.mapper.map(x, block_idx)?;
            x = block.forward(
                &x,
                &mask.clone().map(|m| m.to_device(x.device()).unwrap()),
                seqlen_offsets,
                block_idx,
                &mut cache,
                scalings.clone(),
                self.xlora_classifier
                    .as_ref()
                    .map(|classifier| classifier.get_global_scaling_weight())
                    .unwrap_or(1.0),
                is_scaling_pass,
                flash_params,
            )?;
        }
        let x = x.to_device(&self.device)?;
        self.ln_f.forward(&x)
    }

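    /// X-LoRA forward: if a classifier is present, first obtain scalings via `get_scalings`
    /// (the scaling pass), then rerun the model with those scalings; without a classifier this
    /// is a plain LoRA forward. Logits are extracted for the given context lengths.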
    #[allow(clippy::too_many_arguments)]
    pub fn forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        if self.xlora_classifier.is_some() {
            let scalings = self.get_scalings(
                input_ids,
                input_ids_full,
                seqlen_offsets,
                seqlen_offsets_full,
                no_kv_cache,
                non_granular_state,
                &vec![usize::MAX; context_lens.len()],
                flash_params,
                flash_params_full,
            )?;

            if no_kv_cache {
                let mut res = self
                    .inner_forward(
                        input_ids_full,
                        seqlen_offsets_full,
                        Some(scalings),
                        true,
                        no_kv_cache,
                        None,
                        flash_params_full,
                    )?
                    .contiguous()?;
                if let Some(t) = self.lm_head.quantized_act_type() {
                    res = res.to_dtype(t)?;
                }
                extract_logits(
                    &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                    context_lens,
                )
            } else {
                // is_full_pass=true is ok because no_kv_cache=false
                let mut res = self
                    .inner_forward(
                        input_ids,
                        seqlen_offsets,
                        Some(scalings),
                        true,
                        no_kv_cache,
                        None,
                        flash_params,
                    )?
                    .contiguous()?;
                if let Some(t) = self.lm_head.quantized_act_type() {
                    res = res.to_dtype(t)?;
                }
                extract_logits(
                    &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                    context_lens,
                )
            }
        } else {
            let mut res = self
                .inner_forward(
                    input_ids,
                    seqlen_offsets,
                    None,
                    false,
                    no_kv_cache,
                    None,
                    flash_params,
                )?
                .contiguous()?;
            if let Some(t) = self.lm_head.quantized_act_type() {
                res = res.to_dtype(t)?;
            }
            extract_logits(
                &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                context_lens,
            )
        }
    }

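    /// Builds the model from `vb`, wiring LoRA adapters into every projection. A LoRA adapter
    /// on `lm_head` is rejected when X-LoRA is in use, and when neither an X-LoRA config nor
    /// preloaded adapters are given the adapter weights are merged into the base weights.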
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        if let Some(ref quant_cfg) = &cfg.quantization_config {
            tracing::info!(
                "Using {} quantization: {}.",
                quant_cfg.name(),
                quant_cfg.get_bits_name(&vb)
            );
        }
        let mapper = normal_loading_metadata.mapper;
        let dtype = vb.dtype();
        let mut count = 0;

        let wte = embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb.pp("model.embed_tokens"), false),
            &cfg.quantization_config,
        )?;
        let lm_head = linear(
            cfg.hidden_size,
            cfg.vocab_size,
            mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
            mapper.set_nm_device(vb.pp("lm_head"), false),
            lora_config,
            &mut count,
            &xlora_ordering,
            preload_adapters,
        )?;
        if xlora_config.is_some() && lm_head.is_lora() {
            // This is why we can pass dummy values (..., None, 1.0, None)
            candle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA.");
        }
        let ln_f = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb.pp("model.norm"), false),
        )?;
        let mut ropes = HashMap::new();
        for i in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(i, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(Llama3RotaryEmbedding::new_llama3(
                    vb.dtype(),
                    cfg,
                    device,
                    is_gptx,
                )?),
            );
        }
        let mut blocks: Vec<_> = NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        )
        .into_iter()
        .map(|i| {
            let device = mapper
                .device_for(i, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            Block::load(
                vb.pp(format!("model.layers.{i}")),
                cfg,
                lora_config,
                &mut count,
                &xlora_ordering,
                &*mapper,
                i,
                normal_loading_metadata.loading_isq,
                rotary_emb,
                preload_adapters,
            )
            .expect("Failed to load block.")
        })
        .collect();
        if xlora_config.is_none() && preload_adapters.is_none() {
            // We are now a LoRA model so we must merge the weights
            info!("Merging LoRA adapters.");
            for layer in blocks.iter_mut().tqdm() {
                Arc::get_mut(&mut layer.attn.k_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.attn.o_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.attn.q_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.attn.v_proj)
                    .unwrap()
                    .merge_weights()?;

                Arc::get_mut(&mut layer.mlp.c_fc1)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.mlp.c_fc2)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.mlp.c_proj)
                    .unwrap()
                    .merge_weights()?;
            }
        }

        Ok(Self {
            wte,
            blocks,
            ln_f,
            lm_head,
            kv_cache: EitherCache::Full(pipeline::Cache::new(cfg.num_hidden_layers, true)),
            device: normal_loading_metadata.real_device,
            xlora_classifier: xlora_config.map(|xlora_config| {
                XLoraClassifier::new(xlora_config, count, lora_config.len(), vb, false).unwrap()
            }),
            dtype,
            mapper,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_kv_heads: cfg.num_key_value_heads,
                num_attn_heads: cfg.num_attention_heads,
                sliding_window: None,
                k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
                v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            },
        })
    }
}

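/// Exposes the LM head and every per-layer projection for in-situ quantization (ISQ).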
impl IsqModel for XLoraLlama {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((Arc::get_mut(&mut self.lm_head).unwrap().quant_inner(), None));
        for (i, layer) in self.blocks.iter_mut().enumerate() {
            tensors.push((
                Arc::get_mut(&mut layer.attn.q_proj).unwrap().quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.attn.k_proj).unwrap().quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.attn.v_proj).unwrap().quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.attn.o_proj).unwrap().quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.mlp.c_fc1).unwrap().quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.mlp.c_fc2).unwrap().quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.mlp.c_proj).unwrap().quant_inner(),
                Some(i),
            ));
        }
        (tensors, &*self.mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        panic!("Cannot generate UQFF for an adapter model.")
    }
}

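/// `NormalModel` plumbing: the non-X-LoRA `forward` is unreachable for this model; requests go
/// through `xlora_forward`, which delegates to `XLoraLlama::forward`.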
impl NormalModel for XLoraLlama {
    fn forward(
        &self,
        _input_ids: &Tensor,
        _seqlen_offsets: &[usize],
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        _flash_params: &FlashParams,
    ) -> Result<Tensor> {
        unreachable!()
    }
    fn xlora_forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<crate::xlora_models::NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        self.forward(
            input_ids,
            input_ids_full,
            seqlen_offsets,
            seqlen_offsets_full,
            no_kv_cache,
            non_granular_state,
            context_lens,
            flash_params,
            flash_params_full,
        )
    }
    fn cache(&self) -> &super::EitherCache {
        &self.kv_cache
    }
    fn cache_mut(&mut self) -> &mut super::EitherCache {
        &mut self.kv_cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn is_xlora(&self) -> bool {
        true
    }
    fn max_seq_len(&self) -> usize {
        self.blocks[0].attn.max_seq_len
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
}

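/// Lets the X-LoRA classifier drive scaling passes through `inner_forward`.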
impl ScalingsMaker for XLoraLlama {
    fn dtype(&self) -> DType {
        self.dtype
    }
    fn get_cache(&self) -> &pipeline::EitherCache {
        &self.kv_cache
    }
    fn get_classifier(&self) -> &XLoraClassifier {
        self.xlora_classifier.as_ref().unwrap()
    }
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Tensor,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        _context_lens: &[usize],
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.inner_forward(
            input_ids,
            seqlen_offsets,
            Some(scalings),
            is_full_pass,
            no_kv_cache,
            is_scaling_pass,
            flash_params,
        )
    }
}

impl AnyMoeBaseModelMixin for XLoraLlama {}