mistralrs_core/xlora_models/llama.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use crate::{
    amoe::AnyMoeBaseModelMixin,
    attention::SdpaParams,
    layers::{Llama3RotaryEmbedding, Sdpa},
    lora::{linear_no_bias as linear, LinearLayerLike, LoraConfig, Ordering},
    paged_attention::ModelConfigMetadata,
    pipeline::{
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel,
    },
    utils::progress::NiceProgressBar,
};
use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{Embedding, Module};
use mistralrs_quant::{QuantMethod, ShardedVarBuilder};
use std::{collections::HashMap, sync::Arc};
use tqdm::Iter;
use tracing::info;

use crate::{
    device_map::DeviceMapper,
    layers::{embedding, CausalMasker, RmsNorm},
    models::llama::Config,
    pipeline::{self, extract_logits, LayerCaches, NormalLoadingMetadata, NormalModel},
};

use super::{classifier::XLoraClassifier, NonGranularState, ScalingsMaker, XLoraConfig};

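/// Multi-head self-attention with LoRA/X-LoRA adapters on the q/k/v/o projections.
/// Each projection is a `LinearLayerLike`, so the X-LoRA `scalings` tensor can weight
/// the adapter contributions per layer.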
struct CausalSelfAttention {
    q_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    k_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    v_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    o_proj: Arc<dyn LinearLayerLike + Send + Sync>,
    num_attention_heads: usize,
    num_key_value_heads: usize,
    head_dim: usize,
    rotary_emb: Arc<Llama3RotaryEmbedding>,
    max_seq_len: usize,
    sdpa_params: SdpaParams,
}

impl CausalSelfAttention {
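    /// Attention forward pass. `scalings` carries the X-LoRA gate values,
    /// `global_scaling_weight` scales all adapters uniformly, and `is_scaling_pass`
    /// is forwarded to the adapter layers during the classifier's scaling pass.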
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        x: &Tensor,
        mask: &Option<Tensor>,
        seqlen_offsets: &[usize],
        block_idx: usize,
        kv_cache: &mut LayerCaches,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let (b_sz, seq_len, hidden_size) = x.dims3()?;

        let original_dtype = x.dtype();
        let mut x = x.clone();
        if let Some(t) = self.q_proj.quantized_act_type() {
            x = x.to_dtype(t)?;
        }
        let mut q = self.q_proj.lora_forward(
            &x,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let mut k = self.k_proj.lora_forward(
            &x,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        let mut v = self.v_proj.lora_forward(
            &x,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.q_proj.quantized_act_type().is_some() {
            q = q.to_dtype(original_dtype)?;
            k = k.to_dtype(original_dtype)?;
            v = v.to_dtype(original_dtype)?;
        }

        let (q, k, v) = if seq_len != 1 {
            let q = q
                .reshape((b_sz, seq_len, self.num_attention_heads, self.head_dim))?
                .transpose(1, 2)?;
            let k = k
                .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
                .transpose(1, 2)?;
            (q, k, v)
        } else {
            let q = q.reshape((b_sz, self.num_attention_heads, seq_len, self.head_dim))?;
            let k = k.reshape((b_sz, self.num_key_value_heads, seq_len, self.head_dim))?;
            let v = v.reshape((b_sz, self.num_key_value_heads, seq_len, self.head_dim))?;
            (q, k, v)
        };

        let (q, k) = self.rotary_emb.forward(&q, &k, seqlen_offsets)?;

        let (k, v) =
            crate::pipeline::Cache::update_kv_cache(&mut kv_cache[block_idx], k, v, false)?;

        let y = Sdpa.run_attention(
            &q,
            &k,
            &v,
            mask.clone().as_ref(),
            Some(flash_params),
            &self.sdpa_params,
        )?;

        // `y` is already transposed and reshaped to (b_sz, seq_len, hidden_size) here,
        // so it can be fed to the output projection directly.
        let mut y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, hidden_size])?;
        if let Some(t) = self.q_proj.quantized_act_type() {
            y = y.to_dtype(t)?;
        }
        let mut res = self.o_proj.lora_forward(
            &y,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.q_proj.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }

    #[allow(clippy::too_many_arguments)]
    fn load(
        vb: ShardedVarBuilder,
        cfg: &Config,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        rope: Arc<Llama3RotaryEmbedding>,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let size_in = cfg.hidden_size;
        let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
        let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
        let q_proj = linear(
            size_in,
            size_q,
            mapper.set_device(layer_idx, vb.pp("q_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("q_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let k_proj = linear(
            size_in,
            size_kv,
            mapper.set_device(layer_idx, vb.pp("k_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("k_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let v_proj = linear(
            size_in,
            size_kv,
            mapper.set_device(layer_idx, vb.pp("v_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("v_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let o_proj = linear(
            size_q,
            size_in,
            mapper.set_device(layer_idx, vb.pp("o_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("o_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            num_attention_heads: cfg.num_attention_heads,
            num_key_value_heads: cfg.num_key_value_heads,
            head_dim: cfg.hidden_size / cfg.num_attention_heads,
            rotary_emb: rope,
            max_seq_len: cfg.max_position_embeddings,
            sdpa_params: SdpaParams {
                n_kv_groups: cfg.num_attention_heads / cfg.num_key_value_heads,
                use_flash_attn: cfg.use_flash_attn,
                softcap: None,
                softmax_scale: 1.0 / ((cfg.hidden_size / cfg.num_attention_heads) as f32).sqrt(),
                sliding_window: None,
            },
        })
    }
}

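/// SwiGLU feed-forward block (`gate_proj`/`up_proj`/`down_proj`), with LoRA/X-LoRA
/// adapters on each projection.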
#[derive(Clone)]
struct Mlp {
    c_fc1: Arc<dyn LinearLayerLike + Send + Sync>,
    c_fc2: Arc<dyn LinearLayerLike + Send + Sync>,
    c_proj: Arc<dyn LinearLayerLike + Send + Sync>,
}

impl Mlp {
    fn forward(
        &self,
        x: &Tensor,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
    ) -> Result<Tensor> {
        let original_dtype = x.dtype();
        let mut x = x.clone();
        if let Some(t) = self.c_fc1.quantized_act_type() {
            x = x.to_dtype(t)?;
        }
        let x = (candle_nn::ops::silu(&self.c_fc1.lora_forward(
            &x,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?)? * self.c_fc2.lora_forward(
            &x,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?)?;
        let mut res = self.c_proj.lora_forward(
            &x,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
        )?;
        if self.c_fc1.quantized_act_type().is_some() {
            res = res.to_dtype(original_dtype)?;
        }
        Ok(res)
    }

    #[allow(clippy::too_many_arguments)]
    fn load(
        vb: ShardedVarBuilder,
        cfg: &Config,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let h_size = cfg.hidden_size;
        let i_size = cfg.intermediate_size;
        let c_fc1 = linear(
            h_size,
            i_size,
            mapper.set_device(layer_idx, vb.pp("gate_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("gate_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let c_fc2 = linear(
            h_size,
            i_size,
            mapper.set_device(layer_idx, vb.pp("up_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("up_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        let c_proj = linear(
            i_size,
            h_size,
            mapper.set_device(layer_idx, vb.pp("down_proj"), loading_isq),
            mapper.set_device(layer_idx, vb.pp("down_proj"), false),
            lora_config,
            count,
            ord,
            preload_adapters,
        )?;
        Ok(Self {
            c_fc1,
            c_fc2,
            c_proj,
        })
    }
}

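/// A single decoder layer: RMSNorm -> attention -> residual add, then
/// RMSNorm -> MLP -> residual add.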
struct Block {
    rms_1: RmsNorm,
    attn: CausalSelfAttention,
    rms_2: RmsNorm,
    mlp: Mlp,
}

impl Block {
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        x: &Tensor,
        mask: &Option<Tensor>,
        seqlen_offsets: &[usize],
        block_idx: usize,
        kv_cache: &mut LayerCaches,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let residual = x;
        let x = self.rms_1.forward(x)?;
        let x = (self.attn.forward(
            &x,
            mask,
            seqlen_offsets,
            block_idx,
            kv_cache,
            scalings.clone(),
            global_scaling_weight,
            is_scaling_pass,
            flash_params,
        )? + residual)?;
        let residual = &x;
        let x = (self.mlp.forward(
            &self.rms_2.forward(&x)?,
            scalings,
            global_scaling_weight,
            is_scaling_pass,
        )? + residual)?;
        Ok(x)
    }

    #[allow(clippy::too_many_arguments)]
    fn load(
        vb: ShardedVarBuilder,
        cfg: &Config,
        lora_config: &[((String, String), LoraConfig)],
        count: &mut usize,
        ord: &Ordering,
        mapper: &dyn DeviceMapper,
        layer_idx: usize,
        loading_isq: bool,
        rope: Arc<Llama3RotaryEmbedding>,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        let attn = CausalSelfAttention::load(
            vb.pp("self_attn"),
            cfg,
            lora_config,
            count,
            ord,
            mapper,
            layer_idx,
            loading_isq,
            rope,
            preload_adapters,
        )?;
        let mlp = Mlp::load(
            vb.pp("mlp"),
            cfg,
            lora_config,
            count,
            ord,
            mapper,
            layer_idx,
            loading_isq,
            preload_adapters,
        )?;
        let rms_1 = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("input_layernorm"), false),
        )?;
        let rms_2 = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_device(layer_idx, vb.pp("post_attention_layernorm"), false),
        )?;
        Ok(Self {
            rms_1,
            attn,
            rms_2,
            mlp,
        })
    }
}

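/// Llama model with X-LoRA (or plain LoRA) adapters: token embedding, decoder blocks,
/// final RMSNorm, LM head, and an optional `XLoraClassifier` that produces the adapter
/// scalings.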
pub struct XLoraLlama {
    wte: Embedding,
    blocks: Vec<Block>,
    ln_f: RmsNorm,
    lm_head: Arc<dyn LinearLayerLike + Send + Sync>,
    kv_cache: pipeline::EitherCache,
    device: Device,
    xlora_classifier: Option<XLoraClassifier>,
    dtype: DType,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
    cfg: ModelConfigMetadata,
}

impl XLoraLlama {
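    /// Runs the embedding, decoder stack, and final norm. Both the classifier's
    /// scaling pass and the final (scaled) pass go through this method.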
    #[allow(clippy::too_many_arguments)]
    fn inner_forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Option<Tensor>,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let mut x = self.wte.forward(input_ids)?;
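        // When `is_full_pass` is set, the dedicated X-LoRA cache is used (and cleared
        // first if the KV cache is disabled); otherwise the primary cache is used.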
        let mut cache = if is_full_pass {
            if no_kv_cache {
                let mut new_cache = Vec::new();
                for _ in 0..self.kv_cache.full().xlora_lock().len() {
                    new_cache.push(None);
                }

                self.kv_cache.full().xlora_lock().clone_from(&new_cache);
            }
            self.kv_cache.full().xlora_lock()
        } else {
            self.kv_cache.full().lock()
        };
        let mask = CausalMasker.make_causal_mask_matrix(
            input_ids,
            &*cache,
            x.dtype(),
            self.cfg.num_attn_heads,
        )?;
        for (block_idx, block) in self.blocks.iter().enumerate() {
            x = self.mapper.map(x, block_idx)?;
            x = block.forward(
                &x,
                &mask.clone().map(|m| m.to_device(x.device()).unwrap()),
                seqlen_offsets,
                block_idx,
                &mut cache,
                scalings.clone(),
                self.xlora_classifier
                    .as_ref()
                    .map(|classifier| classifier.get_global_scaling_weight())
                    .unwrap_or(1.0),
                is_scaling_pass,
                flash_params,
            )?;
        }
        let x = x.to_device(&self.device)?;
        self.ln_f.forward(&x)
    }

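    /// X-LoRA forward: when a classifier is present, first obtain the adapter scalings
    /// via `get_scalings`, then run the full pass with them and project through the LM
    /// head. Without a classifier, a single plain (LoRA) pass is run instead.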
    #[allow(clippy::too_many_arguments)]
    pub fn forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        if self.xlora_classifier.is_some() {
            let scalings = self.get_scalings(
                input_ids,
                input_ids_full,
                seqlen_offsets,
                seqlen_offsets_full,
                no_kv_cache,
                non_granular_state,
                &vec![usize::MAX; context_lens.len()],
                flash_params,
                flash_params_full,
            )?;

            if no_kv_cache {
                let mut res = self
                    .inner_forward(
                        input_ids_full,
                        seqlen_offsets_full,
                        Some(scalings),
                        true,
                        no_kv_cache,
                        None,
                        flash_params_full,
                    )?
                    .contiguous()?;
                if let Some(t) = self.lm_head.quantized_act_type() {
                    res = res.to_dtype(t)?;
                }
                extract_logits(
                    &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                    context_lens,
                )
            } else {
                // is_full_pass=true is ok because no_kv_cache=false
                let mut res = self
                    .inner_forward(
                        input_ids,
                        seqlen_offsets,
                        Some(scalings),
                        true,
                        no_kv_cache,
                        None,
                        flash_params,
                    )?
                    .contiguous()?;
                if let Some(t) = self.lm_head.quantized_act_type() {
                    res = res.to_dtype(t)?;
                }
                extract_logits(
                    &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                    context_lens,
                )
            }
        } else {
            let mut res = self
                .inner_forward(
                    input_ids,
                    seqlen_offsets,
                    None,
                    false,
                    no_kv_cache,
                    None,
                    flash_params,
                )?
                .contiguous()?;
            if let Some(t) = self.lm_head.quantized_act_type() {
                res = res.to_dtype(t)?;
            }
            extract_logits(
                &self.lm_head.lora_forward(&res, None, 1.0, None)?,
                context_lens,
            )
        }
    }

    #[allow(clippy::too_many_arguments)]
    pub fn new(
        cfg: &Config,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        is_gptx: bool,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Self> {
        if let Some(ref quant_cfg) = &cfg.quantization_config {
            tracing::info!(
                "Using {} quantization: {}.",
                quant_cfg.quant_method.to_string(),
                quant_cfg.get_bits_name(&vb)
            );
        }
        let mapper = normal_loading_metadata.mapper;
        let dtype = vb.dtype();
        let mut count = 0;

        let wte = embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            mapper.set_nm_device(vb.pp("model.embed_tokens"), false),
        )?;
        let lm_head = linear(
            cfg.hidden_size,
            cfg.vocab_size,
            mapper.set_nm_device(vb.pp("lm_head"), normal_loading_metadata.loading_isq),
            mapper.set_nm_device(vb.pp("lm_head"), false),
            lora_config,
            &mut count,
            &xlora_ordering,
            preload_adapters,
        )?;
        if xlora_config.is_some() && lm_head.is_lora() {
            // An adapter on `lm_head` is rejected here, which is why `forward` can pass
            // the dummy values (None, 1.0, None) to `lm_head.lora_forward`.
            candle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA.");
        }
        let ln_f = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            mapper.set_nm_device(vb.pp("model.norm"), false),
        )?;
        let mut ropes = HashMap::new();
        for i in 0..cfg.num_hidden_layers {
            let device = mapper
                .device_for(i, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            ropes.insert(
                device.location(),
                Arc::new(Llama3RotaryEmbedding::new_llama3(
                    vb.dtype(),
                    cfg,
                    device,
                    is_gptx,
                )?),
            );
        }
        let mut blocks: Vec<_> = NiceProgressBar::<_, 'b'>(
            0..cfg.num_hidden_layers,
            "Loading repeating layers",
            &normal_loading_metadata.multi_progress,
        )
        .into_iter()
        .map(|i| {
            let device = mapper
                .device_for(i, false)
                .unwrap_or(&normal_loading_metadata.real_device);
            let rotary_emb = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();
            Block::load(
                vb.pp(format!("model.layers.{i}")),
                cfg,
                lora_config,
                &mut count,
                &xlora_ordering,
                &*mapper,
                i,
                normal_loading_metadata.loading_isq,
                rotary_emb,
                preload_adapters,
            )
            .expect("Failed to load block.")
        })
        .collect();
        if xlora_config.is_none() && preload_adapters.is_none() {
            // Without an X-LoRA classifier or preloaded adapters this is a plain LoRA
            // model, so merge the adapter weights into the base weights.
            info!("Merging LoRA adapters.");
            for layer in blocks.iter_mut().tqdm() {
                Arc::get_mut(&mut layer.attn.k_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.attn.o_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.attn.q_proj)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.attn.v_proj)
                    .unwrap()
                    .merge_weights()?;

                Arc::get_mut(&mut layer.mlp.c_fc1)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.mlp.c_fc2)
                    .unwrap()
                    .merge_weights()?;
                Arc::get_mut(&mut layer.mlp.c_proj)
                    .unwrap()
                    .merge_weights()?;
            }
        }

        Ok(Self {
            wte,
            blocks,
            ln_f,
            lm_head,
            kv_cache: EitherCache::Full(pipeline::Cache::new(cfg.num_hidden_layers, true)),
            device: normal_loading_metadata.real_device,
            xlora_classifier: xlora_config.map(|xlora_config| {
                XLoraClassifier::new(xlora_config, count, lora_config.len(), vb, false).unwrap()
            }),
            dtype,
            mapper,
            cfg: ModelConfigMetadata {
                max_seq_len: cfg.max_position_embeddings,
                num_layers: cfg.num_hidden_layers,
                hidden_size: cfg.hidden_size,
                num_kv_heads: cfg.num_key_value_heads,
                num_attn_heads: cfg.num_attention_heads,
                sliding_window: None,
                k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
                v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            },
        })
    }
}

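// In-situ quantization (ISQ) support: expose every quantizable linear (the LM head
// plus each per-layer projection) together with its layer index.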
impl IsqModel for XLoraLlama {
    fn get_layers(
        &mut self,
    ) -> (
        Vec<(&mut Arc<dyn QuantMethod>, Option<usize>)>,
        &dyn DeviceMapper,
    ) {
        let mut tensors = Vec::new();
        tensors.push((Arc::get_mut(&mut self.lm_head).unwrap().quant_inner(), None));
        for (i, layer) in self.blocks.iter_mut().enumerate() {
            tensors.push((
                Arc::get_mut(&mut layer.attn.q_proj).unwrap().quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.attn.k_proj).unwrap().quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.attn.v_proj).unwrap().quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.attn.o_proj).unwrap().quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.mlp.c_fc1).unwrap().quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.mlp.c_fc2).unwrap().quant_inner(),
                Some(i),
            ));
            tensors.push((
                Arc::get_mut(&mut layer.mlp.c_proj).unwrap().quant_inner(),
                Some(i),
            ));
        }
        (tensors, &*self.mapper)
    }

    fn residual_tensors(&self) -> Vec<(String, Tensor)> {
        panic!("Cannot generate UQFF for an adapter model.")
    }
}

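// X-LoRA models are driven through `xlora_forward`; the plain `forward` entry point
// is never called for them.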
impl NormalModel for XLoraLlama {
    fn forward(
        &self,
        _input_ids: &Tensor,
        _seqlen_offsets: &[usize],
        _context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        _metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        _flash_params: &FlashParams,
    ) -> Result<Tensor> {
        unreachable!()
    }
    fn xlora_forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<crate::xlora_models::NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        _position_ids: Vec<usize>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> Result<Tensor> {
        self.forward(
            input_ids,
            input_ids_full,
            seqlen_offsets,
            seqlen_offsets_full,
            no_kv_cache,
            non_granular_state,
            context_lens,
            flash_params,
            flash_params_full,
        )
    }
    fn cache(&self) -> &super::EitherCache {
        &self.kv_cache
    }
    fn cache_mut(&mut self) -> &mut super::EitherCache {
        &mut self.kv_cache
    }
    fn device(&self) -> &Device {
        &self.device
    }
    fn is_xlora(&self) -> bool {
        true
    }
    fn max_seq_len(&self) -> usize {
        self.blocks[0].attn.max_seq_len
    }
    fn activate_adapters(&mut self, adapter_names: Vec<String>) -> Result<usize> {
        if self.xlora_classifier.is_some() {
            candle_core::bail!("Adapter activation is not supported for X-LoRA models as the adapter set must remain the same.");
        }
        let mut sum = 0;
        for layer in self.blocks.iter_mut() {
            sum += Arc::get_mut(&mut layer.attn.k_proj)
                .unwrap()
                .activate(&adapter_names)?;
            sum += Arc::get_mut(&mut layer.attn.o_proj)
                .unwrap()
                .activate(&adapter_names)?;
            sum += Arc::get_mut(&mut layer.attn.q_proj)
                .unwrap()
                .activate(&adapter_names)?;
            sum += Arc::get_mut(&mut layer.attn.v_proj)
                .unwrap()
                .activate(&adapter_names)?;

            sum += Arc::get_mut(&mut layer.mlp.c_fc1)
                .unwrap()
                .activate(&adapter_names)?;
            sum += Arc::get_mut(&mut layer.mlp.c_fc2)
                .unwrap()
                .activate(&adapter_names)?;
            sum += Arc::get_mut(&mut layer.mlp.c_proj)
                .unwrap()
                .activate(&adapter_names)?;
        }
        Ok(sum)
    }
    fn config(&self) -> &ModelConfigMetadata {
        &self.cfg
    }
}

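// Hooks this model into the shared X-LoRA scaling machinery; the classifier's scaling
// pass calls back into `inner_forward` through this impl.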
impl ScalingsMaker for XLoraLlama {
    fn dtype(&self) -> DType {
        self.dtype
    }
    fn get_cache(&self) -> &pipeline::EitherCache {
        &self.kv_cache
    }
    fn get_classifier(&self) -> &XLoraClassifier {
        self.xlora_classifier.as_ref().unwrap()
    }
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        scalings: Tensor,
        is_full_pass: bool,
        no_kv_cache: bool,
        is_scaling_pass: Option<f64>,
        _context_lens: &[usize],
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        self.inner_forward(
            input_ids,
            seqlen_offsets,
            Some(scalings),
            is_full_pass,
            no_kv_cache,
            is_scaling_pass,
            flash_params,
        )
    }
}

impl AnyMoeBaseModelMixin for XLoraLlama {}