mistralrs_core/xlora_models/
quantized_llama.rs

1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use std::collections::HashMap;
4use std::sync::Arc;
5
6use crate::attention::SdpaParams;
7use crate::gguf::Content;
8use crate::lora::{
9    get_lora_cfg, AdapterSwapper, LinearLayerLike, LoraConfig, Merge, Ordering, QLoraLinear,
10};
11use crate::pipeline::text_models_inputs_processor::FlashParams;
12use crate::utils::progress::NiceProgressBar;
13use candle_core::quantized::ggml_file;
14use candle_core::quantized::QMatMul;
15use candle_core::{DType, Device, Result, Tensor};
16use candle_nn::{Embedding, Module};
17use indicatif::MultiProgress;
18use mistralrs_quant::{MatMul, ShardedVarBuilder};
19use tqdm::Iter;
20use tracing::info;
21
22use crate::device_map::DeviceMapper;
23use crate::layers::{CausalMasker, QRmsNorm, RotaryEmbedding, Sdpa};
24use crate::pipeline::{extract_logits, Cache, EitherCache};
25
26use super::classifier::XLoraClassifier;
27use super::{verify_sanity_adapters, NonGranularState, ScalingsMaker, XLoraConfig};
28use crate::models::quantized_llama::PropsGGUF;
29use crate::utils::gguf_metadata::ContentMetadata;
30use crate::utils::model_config as ModelConfig;
31
// Fallback context length: the GGML header does not encode one (see `max_seq_len`
// in `from_ggml` below); the GGUF path reads it from metadata instead.
const MAX_SEQ_LEN: u32 = 4096;
// Layer-name suffixes for which LoRA adapters may be attached; checked against
// the adapter `Ordering` by `verify_sanity_adapters` in `from_gguf`.
const SUPPORTED_LAYERS: [&str; 8] = [
    "self_attn.q_proj",
    "self_attn.k_proj",
    "self_attn.v_proj",
    "self_attn.o_proj",
    "mlp.up_proj",
    "mlp.down_proj",
    "mlp.gate_proj",
    "lm_head",
];
43
/// SwiGLU feed-forward block over quantized, LoRA-wrapped linears:
/// `w2(silu(w1(x)) * w3(x))` (see `Mlp::forward`).
#[derive(Debug)]
struct Mlp {
    /// Gate projection (loaded from `mlp.gate_proj` / `ffn_gate`).
    feed_forward_w1: QLoraLinear,
    /// Down projection (loaded from `mlp.down_proj` / `ffn_down`).
    feed_forward_w2: QLoraLinear,
    /// Up projection (loaded from `mlp.up_proj` / `ffn_up`).
    feed_forward_w3: QLoraLinear,
}
50
51impl Mlp {
52    fn forward(
53        &self,
54        xs: &Tensor,
55        scalings: Option<Tensor>,
56        global_scaling_weight: f64,
57        is_scaling_pass: Option<f64>,
58    ) -> Result<Tensor> {
59        let w1 = self.feed_forward_w1.lora_forward(
60            xs,
61            scalings.clone(),
62            global_scaling_weight,
63            is_scaling_pass,
64        )?;
65        let w3 = self.feed_forward_w3.lora_forward(
66            xs,
67            scalings.clone(),
68            global_scaling_weight,
69            is_scaling_pass,
70        )?;
71        self.feed_forward_w2.lora_forward(
72            &(candle_nn::ops::silu(&w1)? * w3)?,
73            scalings.clone(),
74            global_scaling_weight,
75            is_scaling_pass,
76        )
77    }
78}
79
/// Feed-forward variant: a dense MLP, or a mixture-of-experts block with
/// softmax routing over expert gate logits (see `MlpOrMoe::forward`).
#[derive(Debug)]
enum MlpOrMoe {
    Mlp(Mlp),
    MoE {
        /// Number of experts each token is routed to (top-k).
        n_expert_used: usize,
        /// Router producing per-expert logits for each token.
        feed_forward_gate_inp: QMatMul,
        experts: Vec<Mlp>,
    },
}
89
impl MlpOrMoe {
    /// Dispatch the feed-forward pass. For the MoE variant this performs top-k
    /// expert routing on CPU-extracted router weights, then evaluates each
    /// selected expert only on the rows routed to it.
    fn forward(
        &self,
        xs: &Tensor,
        scalings: Option<Tensor>,
        global_scaling_weight: f64,
        is_scaling_pass: Option<f64>,
    ) -> Result<Tensor> {
        match self {
            Self::MoE {
                feed_forward_gate_inp,
                experts,
                n_expert_used,
            } => {
                let (b_size, seq_len, hidden_dim) = xs.dims3()?;
                // Flatten (batch, seq) into one row axis so routing is per token.
                let xs = xs.reshape(((), hidden_dim))?;
                let router_logits = MatMul.qmatmul(&xs, feed_forward_gate_inp)?;
                let routing_weights = candle_nn::ops::softmax_last_dim(&router_logits)?;

                // In order to extract topk, we extract the data from the tensor and manipulate it
                // directly. Maybe we will want to use some custom ops instead at some point.
                let routing_weights = routing_weights.to_dtype(DType::F32)?.to_vec2::<f32>()?;

                // routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
                // top_x contains the row indexes to evaluate for each expert.
                let mut top_x = vec![vec![]; experts.len()];
                let mut selected_rws = vec![vec![]; experts.len()];
                for (row_idx, rw) in routing_weights.iter().enumerate() {
                    // Sort expert indices by descending routing weight; the first
                    // `n_expert_used` entries are this token's selected experts.
                    let mut dst = (0..rw.len() as u32).collect::<Vec<u32>>();
                    dst.sort_by(|&i, &j| rw[j as usize].total_cmp(&rw[i as usize]));
                    // First pass: record row indices and accumulate the top-k mass
                    // so the second pass can renormalize over the selected experts.
                    let mut sum_routing_weights = 0f32;
                    for &expert_idx in dst.iter().take(*n_expert_used) {
                        let expert_idx = expert_idx as usize;
                        let routing_weight = rw[expert_idx];
                        sum_routing_weights += routing_weight;
                        top_x[expert_idx].push(row_idx as u32);
                    }
                    for &expert_idx in dst.iter().take(*n_expert_used) {
                        let expert_idx = expert_idx as usize;
                        let routing_weight = rw[expert_idx];
                        selected_rws[expert_idx].push(routing_weight / sum_routing_weights)
                    }
                }

                // routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
                // expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)

                let mut ys = xs.zeros_like()?;
                for (expert_idx, expert_layer) in experts.iter().enumerate() {
                    let top_x = &top_x[expert_idx];
                    // Skip experts no token was routed to.
                    if top_x.is_empty() {
                        continue;
                    }
                    let top_x = Tensor::new(top_x.as_slice(), xs.device())?;
                    let selected_rws =
                        Tensor::new(selected_rws[expert_idx].as_slice(), xs.device())?
                            .reshape(((), 1))?;
                    // Index the correct hidden states and compute the expert hidden state for
                    // the current expert. We need to make sure to multiply the output hidden
                    // states by `routing_weights` on the corresponding tokens (top-1 and top-2)
                    let current_state = xs.index_select(&top_x, 0)?.reshape(((), hidden_dim))?;
                    // current_hidden_states = expert_layer(current_state, routing_weights[top_x_list, idx_list, None])
                    let current_hidden_states = expert_layer.forward(
                        &current_state,
                        scalings.clone(),
                        global_scaling_weight,
                        is_scaling_pass,
                    )?;
                    let current_hidden_states =
                        current_hidden_states.broadcast_mul(&selected_rws)?;
                    // Scatter-add the weighted expert outputs back to their rows.
                    ys = ys.index_add(&top_x, &current_hidden_states, 0)?;
                }

                let ys = ys.reshape((b_size, seq_len, hidden_dim))?;
                Ok(ys)
            }
            Self::Mlp(mlp) => {
                mlp.forward(xs, scalings.clone(), global_scaling_weight, is_scaling_pass)
            }
        }
    }
}
172
/// One decoder layer: quantized LoRA-wrapped attention projections, RMSNorms,
/// a feed-forward (dense or MoE) block, and per-layer RoPE/SDPA configuration.
struct LayerWeights {
    attention_wq: QLoraLinear,
    attention_wk: QLoraLinear,
    attention_wv: QLoraLinear,
    attention_wo: QLoraLinear,
    attention_norm: QRmsNorm,
    mlp_or_moe: MlpOrMoe,
    ffn_norm: QRmsNorm,
    n_head: usize,
    n_kv_head: usize,
    head_dim: usize,
    /// Shared per-device rotary embedding (see the RoPE cache in `from_gguf`).
    rotary: Arc<RotaryEmbedding>,
    sdpa_params: SdpaParams,
    /// Compute dtype: q/k/v are cast to this before attention in `forward_attn`.
    dtype: DType,
}
188
189impl LayerWeights {
190    #[allow(clippy::too_many_arguments)]
191    fn forward_attn(
192        &self,
193        x: &Tensor,
194        mask: &Option<Tensor>,
195        start_offsets: &[usize],
196        kv_cache: &mut Option<(Tensor, Tensor)>,
197        scalings: Option<Tensor>,
198        global_scaling_weight: f64,
199        is_scaling_pass: Option<f64>,
200        flash_params: &FlashParams,
201    ) -> Result<Tensor> {
202        let (b_sz, seq_len, n_embd) = x.dims3()?;
203        let q = self
204            .attention_wq
205            .lora_forward(x, scalings.clone(), global_scaling_weight, is_scaling_pass)?
206            .to_dtype(self.dtype)?;
207        let k = self
208            .attention_wk
209            .lora_forward(x, scalings.clone(), global_scaling_weight, is_scaling_pass)?
210            .to_dtype(self.dtype)?;
211        let v = self
212            .attention_wv
213            .lora_forward(x, scalings.clone(), global_scaling_weight, is_scaling_pass)?
214            .to_dtype(self.dtype)?;
215
216        let (q, k, v) = if seq_len != 1 {
217            let q = q
218                .reshape((b_sz, seq_len, self.n_head, self.head_dim))?
219                .transpose(1, 2)?;
220            let k = k
221                .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
222                .transpose(1, 2)?;
223            let v = v
224                .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
225                .transpose(1, 2)?;
226            (q, k, v)
227        } else {
228            let q = q.reshape((b_sz, self.n_head, seq_len, self.head_dim))?;
229            let k = k.reshape((b_sz, self.n_kv_head, seq_len, self.head_dim))?;
230            let v = v.reshape((b_sz, self.n_kv_head, seq_len, self.head_dim))?;
231            (q, k, v)
232        };
233
234        let (q, k) = self.rotary.forward(&q, &k, start_offsets)?;
235
236        let (k, v) = Cache::update_kv_cache(kv_cache, k, v, false)?;
237
238        let y = Sdpa.run_attention(
239            &q,
240            &k,
241            &v,
242            mask.as_ref(),
243            Some(flash_params),
244            &self.sdpa_params,
245        )?;
246
247        let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
248        let y = self.attention_wo.lora_forward(
249            &y.to_dtype(x.dtype())?,
250            scalings.clone(),
251            global_scaling_weight,
252            is_scaling_pass,
253        )?;
254        Ok(y)
255    }
256}
257
/// Quantized Llama model with optional X-LoRA support, loadable from GGML
/// (`FromAdapterGGML`) or GGUF (`FromAdapterGGUF`) checkpoints.
pub struct ModelWeights {
    tok_embeddings: Embedding,
    layers: Vec<LayerWeights>,
    /// Final RMSNorm before the output head.
    norm: QRmsNorm,
    /// `lm_head` projection; must not carry adapters when X-LoRA is active.
    output: QLoraLinear,
    pub device: Device,
    pub cache: EitherCache,
    /// `Some` only for X-LoRA models; plain LoRA adapters are merged at load.
    xlora_classifier: Option<XLoraClassifier>,
    pub max_seq_len: usize,
    /// `None` for the GGML path, `Some` for the device-mapped GGUF path.
    mapper: Option<Box<dyn DeviceMapper + Send + Sync>>,
    dtype: DType,
}
270
271impl ModelConfig::FromAdapterGGML for ModelWeights {
272    fn from_ggml(
273        mut ct: ggml_file::Content,
274        gqa: usize,
275        lora_config: &[((String, String), LoraConfig)],
276        vb: &ShardedVarBuilder,
277        ordering: &Ordering,
278        xlora_config: Option<XLoraConfig>,
279        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
280        dtype: DType,
281    ) -> Result<Self> {
282        let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize;
283        let rotary = RotaryEmbedding::new_partial(
284            10000.,
285            ct.hparams.n_rot as usize,
286            MAX_SEQ_LEN as usize,
287            &ct.device,
288            false,
289            dtype,
290        )?;
291        let tok_embeddings = ct.remove("tok_embeddings.weight")?;
292        let tok_embeddings = tok_embeddings.dequantize(&ct.device)?;
293        let norm = QRmsNorm::new(ct.remove("norm.weight")?, 1e-5)?;
294        let output = ct.remove("output.weight")?;
295        let mut layers = Vec::with_capacity(ct.hparams.n_layer as usize);
296        let mut count = 0;
297        for layer_idx in 0..ct.hparams.n_layer {
298            let prefix = format!("layers.{layer_idx}");
299            let attention_wq = ct.remove(&format!("{prefix}.attention.wq.weight"))?;
300            let attention_wk = ct.remove(&format!("{prefix}.attention.wk.weight"))?;
301            let attention_wv = ct.remove(&format!("{prefix}.attention.wv.weight"))?;
302            let attention_wo = ct.remove(&format!("{prefix}.attention.wo.weight"))?;
303            let mlp_or_moe = {
304                let feed_forward_w1 = ct.remove(&format!("{prefix}.feed_forward.w1.weight"))?;
305                let feed_forward_w2 = ct.remove(&format!("{prefix}.feed_forward.w2.weight"))?;
306                let feed_forward_w3 = ct.remove(&format!("{prefix}.feed_forward.w3.weight"))?;
307                let cfg_w1 = get_lora_cfg(&feed_forward_w1);
308                let cfg_w2 = get_lora_cfg(&feed_forward_w2);
309                let cfg_w3 = get_lora_cfg(&feed_forward_w3);
310                MlpOrMoe::Mlp(Mlp {
311                    feed_forward_w1: QLoraLinear::new(
312                        QMatMul::from_qtensor(feed_forward_w1)?,
313                        &cfg_w1,
314                        lora_config,
315                        vb,
316                        ordering,
317                        format!("model.layers.{layer_idx}.mlp.gate_proj"),
318                        &mut count,
319                        preload_adapters,
320                    )?,
321                    feed_forward_w2: QLoraLinear::new(
322                        QMatMul::from_qtensor(feed_forward_w2)?,
323                        &cfg_w2,
324                        lora_config,
325                        vb,
326                        ordering,
327                        format!("model.layers.{layer_idx}.mlp.down_proj"),
328                        &mut count,
329                        preload_adapters,
330                    )?,
331                    feed_forward_w3: QLoraLinear::new(
332                        QMatMul::from_qtensor(feed_forward_w3)?,
333                        &cfg_w3,
334                        lora_config,
335                        vb,
336                        ordering,
337                        format!("model.layers.{layer_idx}.mlp.up_proj"),
338                        &mut count,
339                        preload_adapters,
340                    )?,
341                })
342            };
343            let attention_norm = ct.remove(&format!("{prefix}.attention_norm.weight"))?;
344            let ffn_norm = ct.remove(&format!("{prefix}.ffn_norm.weight"))?;
345            let cfgq = get_lora_cfg(&attention_wq);
346            let cfgk = get_lora_cfg(&attention_wk);
347            let cfgv = get_lora_cfg(&attention_wv);
348            let cfgo = get_lora_cfg(&attention_wo);
349            let n_kv_head = ct.hparams.n_head as usize / gqa;
350            layers.push(LayerWeights {
351                attention_wq: QLoraLinear::new(
352                    QMatMul::from_qtensor(attention_wq)?,
353                    &cfgq,
354                    lora_config,
355                    vb,
356                    ordering,
357                    format!("model.layers.{layer_idx}.self_attn.q_proj"),
358                    &mut count,
359                    preload_adapters,
360                )?,
361                attention_wk: QLoraLinear::new(
362                    QMatMul::from_qtensor(attention_wk)?,
363                    &cfgk,
364                    lora_config,
365                    vb,
366                    ordering,
367                    format!("model.layers.{layer_idx}.self_attn.k_proj"),
368                    &mut count,
369                    preload_adapters,
370                )?,
371                attention_wv: QLoraLinear::new(
372                    QMatMul::from_qtensor(attention_wv)?,
373                    &cfgv,
374                    lora_config,
375                    vb,
376                    ordering,
377                    format!("model.layers.{layer_idx}.self_attn.v_proj"),
378                    &mut count,
379                    preload_adapters,
380                )?,
381                attention_wo: QLoraLinear::new(
382                    QMatMul::from_qtensor(attention_wo)?,
383                    &cfgo,
384                    lora_config,
385                    vb,
386                    ordering,
387                    format!("model.layers.{layer_idx}.self_attn.o_proj"),
388                    &mut count,
389                    preload_adapters,
390                )?,
391                attention_norm: QRmsNorm::new(attention_norm, 1e-5)?,
392                mlp_or_moe,
393                ffn_norm: QRmsNorm::new(ffn_norm, 1e-5)?,
394                n_head: ct.hparams.n_head as usize,
395                n_kv_head: ct.hparams.n_head as usize / gqa,
396                head_dim: (ct.hparams.n_embd / ct.hparams.n_head) as usize,
397                rotary: rotary.clone().into(),
398                sdpa_params: SdpaParams {
399                    n_kv_groups: ct.hparams.n_head as usize / n_kv_head,
400                    use_flash_attn: false,
401                    softcap: None,
402                    softmax_scale: 1.0 / (head_dim as f32).sqrt(),
403                    sliding_window: None,
404                },
405                dtype,
406            })
407        }
408        if xlora_config.is_none() && preload_adapters.is_none() {
409            // We are now a LoRA model so we must merge the weights
410            info!("Merging LoRA adapters.");
411            for layer in layers.iter_mut().tqdm() {
412                layer.attention_wk.merge_weights()?;
413                layer.attention_wo.merge_weights()?;
414                layer.attention_wq.merge_weights()?;
415                layer.attention_wv.merge_weights()?;
416                match &mut layer.mlp_or_moe {
417                    MlpOrMoe::Mlp(ref mut m) => {
418                        m.feed_forward_w1.merge_weights()?;
419                        m.feed_forward_w2.merge_weights()?;
420                        m.feed_forward_w3.merge_weights()?;
421                    }
422                    MlpOrMoe::MoE {
423                        n_expert_used: _,
424                        feed_forward_gate_inp: _,
425                        experts,
426                    } => {
427                        for expert in experts {
428                            expert.feed_forward_w1.merge_weights()?;
429                            expert.feed_forward_w2.merge_weights()?;
430                            expert.feed_forward_w3.merge_weights()?;
431                        }
432                    }
433                }
434            }
435        }
436        let output_cfg = get_lora_cfg(&output);
437        let output = QLoraLinear::new(
438            QMatMul::from_qtensor(output)?,
439            &output_cfg,
440            lora_config,
441            vb,
442            ordering,
443            "lm_head".to_string(),
444            &mut count,
445            preload_adapters,
446        )?;
447        if xlora_config.is_some() && output.is_lora() {
448            // This is why we can pass dummy values (..., None, 1.0, None)?
449            candle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA.");
450        }
451        Ok(Self {
452            tok_embeddings: Embedding::new(tok_embeddings, ct.hparams.n_embd as usize),
453            layers,
454            norm,
455            output,
456            device: ct.device.clone(),
457            cache: EitherCache::Full(Cache::new(ct.hparams.n_layer as usize, true)),
458            xlora_classifier: xlora_config.map(|xlora_config| {
459                XLoraClassifier::new(xlora_config, count, lora_config.len(), vb.clone(), true)
460                    .unwrap()
461            }),
462            max_seq_len: MAX_SEQ_LEN as usize, // Cannot determine from ggml.
463            mapper: None,
464            dtype,
465        })
466    }
467}
468
impl ModelConfig::FromAdapterGGUF for ModelWeights {
    /// Load a quantized Llama model (with LoRA/X-LoRA adapters) from a GGUF
    /// file, distributing layers across devices via `mapper`. Adapter weights
    /// are merged into the base when neither `xlora_config` nor
    /// `preload_adapters` is present.
    #[allow(clippy::too_many_arguments)]
    fn from_gguf<R: std::io::Seek + std::io::Read>(
        mut ct: Content<'_, R>,
        device: &Device,
        lora_config: &[((String, String), LoraConfig)],
        vb: &ShardedVarBuilder,
        ordering: &Ordering,
        xlora_config: Option<XLoraConfig>,
        mapper: Box<dyn DeviceMapper + Send + Sync>,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
        dtype: DType,
    ) -> Result<Self> {
        // Reject adapter orderings that target layers outside SUPPORTED_LAYERS.
        verify_sanity_adapters(ordering, &SUPPORTED_LAYERS)?;

        // Parameter extraction from metadata.
        let metadata = ContentMetadata {
            path_prefix: "llama",
            metadata: ct.get_metadata(),
        };
        let PropsGGUF {
            n_expert,
            n_expert_used,
            head_count,
            head_count_kv,
            block_count,
            embedding_length,
            rope_dim,
            rms_norm_eps,
            max_seq_len,
            rope_freq_base,
            key_length,
            value_length,
        } = PropsGGUF::try_from(metadata).or_else(|err| candle_core::bail!("{err}"))?;

        let head_dim = key_length;
        if key_length != value_length {
            candle_core::bail!(
                "Expected key_length == value_length, got {key_length} != {value_length}"
            );
        }

        let qtok_embeddings = ct.tensor("token_embd.weight", device)?;
        let tok_embeddings = qtok_embeddings.dequantize(device)?;
        let norm = QRmsNorm::new(ct.tensor("output_norm.weight", device)?, rms_norm_eps)?;
        // Models without a dedicated output head tie it to the token embeddings.
        let output = if !ct.has_tensor("output.weight") {
            ct.tensor("token_embd.weight", device)?
        } else {
            ct.tensor("output.weight", device)?
        };
        let mut layers = Vec::with_capacity(block_count);
        // Running count of adapter targets, consumed by the X-LoRA classifier.
        let mut count = 0;

        // Pre-build one RoPE per device location so mapped layers share it.
        // NOTE(review): this constructs a RotaryEmbedding for every layer and
        // overwrites same-device entries; an occupancy check before insertion
        // would avoid the redundant work — confirm before changing.
        let mut ropes = HashMap::new();
        for layer_idx in 0..block_count {
            let device = mapper.device_for(layer_idx, false).unwrap_or(device);
            ropes.insert(
                device.location(),
                Arc::new(RotaryEmbedding::new(
                    rope_freq_base,
                    rope_dim,
                    max_seq_len,
                    device,
                    false,
                    dtype,
                )?),
            );
        }

        for layer_idx in NiceProgressBar::<_, 'b'>(
            0..block_count,
            "Loading repeating layers",
            &MultiProgress::new(),
        ) {
            let prefix = format!("blk.{layer_idx}");
            let device = mapper.device_for(layer_idx, false).unwrap_or(device);
            let rotary = ropes
                .get(&device.location())
                .expect("No RoPE for device location!")
                .clone();

            let attention_wq = ct.tensor(&format!("{prefix}.attn_q.weight"), device)?;
            let attention_wk = ct.tensor(&format!("{prefix}.attn_k.weight"), device)?;
            let attention_wv = ct.tensor(&format!("{prefix}.attn_v.weight"), device)?;
            let attention_wo = ct.tensor(&format!("{prefix}.attn_output.weight"), device)?;
            // Dense MLP unless the metadata declares multiple experts (MoE).
            let mlp_or_moe = if n_expert <= 1 {
                let feed_forward_w1 = ct.tensor(&format!("{prefix}.ffn_gate.weight"), device)?;
                let feed_forward_w2 = ct.tensor(&format!("{prefix}.ffn_down.weight"), device)?;
                let feed_forward_w3 = ct.tensor(&format!("{prefix}.ffn_up.weight"), device)?;
                let cfg_w1 = get_lora_cfg(&feed_forward_w1);
                let cfg_w2 = get_lora_cfg(&feed_forward_w2);
                let cfg_w3 = get_lora_cfg(&feed_forward_w3);
                MlpOrMoe::Mlp(Mlp {
                    feed_forward_w1: QLoraLinear::new(
                        QMatMul::from_qtensor(feed_forward_w1)?,
                        &cfg_w1,
                        lora_config,
                        vb,
                        ordering,
                        format!("model.layers.{layer_idx}.mlp.gate_proj"),
                        &mut count,
                        preload_adapters,
                    )?,
                    feed_forward_w2: QLoraLinear::new(
                        QMatMul::from_qtensor(feed_forward_w2)?,
                        &cfg_w2,
                        lora_config,
                        vb,
                        ordering,
                        format!("model.layers.{layer_idx}.mlp.down_proj"),
                        &mut count,
                        preload_adapters,
                    )?,
                    feed_forward_w3: QLoraLinear::new(
                        QMatMul::from_qtensor(feed_forward_w3)?,
                        &cfg_w3,
                        lora_config,
                        vb,
                        ordering,
                        format!("model.layers.{layer_idx}.mlp.up_proj"),
                        &mut count,
                        preload_adapters,
                    )?,
                })
            } else {
                let feed_forward_gate_inp =
                    ct.tensor(&format!("{prefix}.ffn_gate_inp.weight"), device)?;
                let mut experts = Vec::with_capacity(n_expert);
                for i in 0..n_expert {
                    let feed_forward_w1 =
                        ct.tensor(&format!("{prefix}.ffn_gate.{i}.weight"), device)?;
                    let feed_forward_w2 =
                        ct.tensor(&format!("{prefix}.ffn_down.{i}.weight"), device)?;
                    let feed_forward_w3 =
                        ct.tensor(&format!("{prefix}.ffn_up.{i}.weight"), device)?;
                    let cfg_w1 = get_lora_cfg(&feed_forward_w1);
                    let cfg_w2 = get_lora_cfg(&feed_forward_w2);
                    let cfg_w3 = get_lora_cfg(&feed_forward_w3);
                    experts.push(Mlp {
                        feed_forward_w1: QLoraLinear::new(
                            QMatMul::from_qtensor(feed_forward_w1)?,
                            &cfg_w1,
                            lora_config,
                            vb,
                            ordering,
                            format!("model.layers.{layer_idx}.mlp.gate_proj.{i}"),
                            &mut count,
                            preload_adapters,
                        )?,
                        feed_forward_w2: QLoraLinear::new(
                            QMatMul::from_qtensor(feed_forward_w2)?,
                            &cfg_w2,
                            lora_config,
                            vb,
                            ordering,
                            format!("model.layers.{layer_idx}.mlp.down_proj.{i}"),
                            &mut count,
                            preload_adapters,
                        )?,
                        feed_forward_w3: QLoraLinear::new(
                            QMatMul::from_qtensor(feed_forward_w3)?,
                            &cfg_w3,
                            lora_config,
                            vb,
                            ordering,
                            format!("model.layers.{layer_idx}.mlp.up_proj.{i}"),
                            &mut count,
                            preload_adapters,
                        )?,
                    })
                }
                MlpOrMoe::MoE {
                    n_expert_used,
                    feed_forward_gate_inp: QMatMul::from_qtensor(feed_forward_gate_inp)?,
                    experts,
                }
            };
            let attention_norm = ct.tensor(&format!("{prefix}.attn_norm.weight"), device)?;
            let ffn_norm = ct.tensor(&format!("{prefix}.ffn_norm.weight"), device)?;
            let cfgq = get_lora_cfg(&attention_wq);
            let cfgk = get_lora_cfg(&attention_wk);
            let cfgv = get_lora_cfg(&attention_wv);
            let cfgo = get_lora_cfg(&attention_wo);
            layers.push(LayerWeights {
                attention_wq: QLoraLinear::new(
                    QMatMul::from_qtensor(attention_wq)?,
                    &cfgq,
                    lora_config,
                    vb,
                    ordering,
                    format!("model.layers.{layer_idx}.self_attn.q_proj"),
                    &mut count,
                    preload_adapters,
                )?,
                attention_wk: QLoraLinear::new(
                    QMatMul::from_qtensor(attention_wk)?,
                    &cfgk,
                    lora_config,
                    vb,
                    ordering,
                    format!("model.layers.{layer_idx}.self_attn.k_proj"),
                    &mut count,
                    preload_adapters,
                )?,
                attention_wv: QLoraLinear::new(
                    QMatMul::from_qtensor(attention_wv)?,
                    &cfgv,
                    lora_config,
                    vb,
                    ordering,
                    format!("model.layers.{layer_idx}.self_attn.v_proj"),
                    &mut count,
                    preload_adapters,
                )?,
                attention_wo: QLoraLinear::new(
                    QMatMul::from_qtensor(attention_wo)?,
                    &cfgo,
                    lora_config,
                    vb,
                    ordering,
                    format!("model.layers.{layer_idx}.self_attn.o_proj"),
                    &mut count,
                    preload_adapters,
                )?,
                attention_norm: QRmsNorm::new(attention_norm, rms_norm_eps)?,
                mlp_or_moe,
                ffn_norm: QRmsNorm::new(ffn_norm, rms_norm_eps)?,
                n_head: head_count,
                n_kv_head: head_count_kv,
                head_dim: embedding_length / head_count,
                rotary: rotary.clone(),
                sdpa_params: SdpaParams {
                    n_kv_groups: head_count / head_count_kv,
                    use_flash_attn: false,
                    softcap: None,
                    softmax_scale: 1.0 / (head_dim as f32).sqrt(),
                    sliding_window: None,
                },
                dtype,
            })
        }
        if xlora_config.is_none() && preload_adapters.is_none() {
            // We are now a LoRA model so we must merge the weights
            info!("Merging LoRA adapters.");
            for layer in layers.iter_mut().tqdm() {
                layer.attention_wk.merge_weights()?;
                layer.attention_wo.merge_weights()?;
                layer.attention_wq.merge_weights()?;
                layer.attention_wv.merge_weights()?;
                match &mut layer.mlp_or_moe {
                    MlpOrMoe::Mlp(ref mut m) => {
                        m.feed_forward_w1.merge_weights()?;
                        m.feed_forward_w2.merge_weights()?;
                        m.feed_forward_w3.merge_weights()?;
                    }
                    MlpOrMoe::MoE {
                        n_expert_used: _,
                        feed_forward_gate_inp: _,
                        experts,
                    } => {
                        for expert in experts {
                            expert.feed_forward_w1.merge_weights()?;
                            expert.feed_forward_w2.merge_weights()?;
                            expert.feed_forward_w3.merge_weights()?;
                        }
                    }
                }
            }
        }
        let output_cfg = get_lora_cfg(&output);
        let output = QLoraLinear::new(
            QMatMul::from_qtensor(output)?,
            &output_cfg,
            lora_config,
            vb,
            ordering,
            "lm_head".to_string(),
            &mut count,
            preload_adapters,
        )?;
        // This is why we can pass dummy values (..., None, 1.0, None)?
        if xlora_config.is_some() && output.is_lora() {
            candle_core::bail!("Got an adapter `lm_head` layer, this is unsupported with X-LoRA.");
        }
        Ok(Self {
            tok_embeddings: Embedding::new(tok_embeddings, embedding_length),
            layers,
            norm,
            output,
            device: device.clone(),
            cache: EitherCache::Full(Cache::new(block_count, true)),
            xlora_classifier: xlora_config.map(|xlora_config| {
                XLoraClassifier::new(xlora_config, count, lora_config.len(), vb.clone(), true)
                    .unwrap()
            }),
            max_seq_len,
            mapper: Some(mapper),
            dtype,
        })
    }
}
770
771impl ModelWeights {
772    pub fn activate_adapters(&mut self, adapter_names: Vec<String>) -> Result<usize> {
773        if self.xlora_classifier.is_some() {
774            candle_core::bail!("Adapter activation is not supported for X-LoRA models as the adapter set must remain the same.");
775        }
776        let mut sum = 0;
777        for layer in self.layers.iter_mut() {
778            sum += layer.attention_wk.activate(&adapter_names)?;
779            sum += layer.attention_wo.activate(&adapter_names)?;
780            sum += layer.attention_wq.activate(&adapter_names)?;
781            sum += layer.attention_wv.activate(&adapter_names)?;
782            match &mut layer.mlp_or_moe {
783                MlpOrMoe::Mlp(ref mut m) => {
784                    sum += m.feed_forward_w1.activate(&adapter_names)?;
785                    sum += m.feed_forward_w2.activate(&adapter_names)?;
786                    sum += m.feed_forward_w3.activate(&adapter_names)?;
787                }
788                MlpOrMoe::MoE {
789                    n_expert_used: _,
790                    feed_forward_gate_inp: _,
791                    experts,
792                } => {
793                    for expert in experts {
794                        sum += expert.feed_forward_w1.activate(&adapter_names)?;
795                        sum += expert.feed_forward_w2.activate(&adapter_names)?;
796                        sum += expert.feed_forward_w3.activate(&adapter_names)?;
797                    }
798                }
799            }
800        }
801        Ok(sum)
802    }
803
804    #[allow(clippy::too_many_arguments)]
805    fn inner_forward(
806        &self,
807        x: &Tensor,
808        start_offsets: &[usize],
809        scalings: Option<Tensor>,
810        is_full_pass: bool,
811        no_kv_cache: bool,
812        is_scaling_pass: Option<f64>,
813        flash_params: &FlashParams,
814    ) -> Result<Tensor> {
815        let mut layer_in = self.tok_embeddings.forward(x)?;
816        let mut cache = if is_full_pass {
817            if no_kv_cache {
818                let mut new_cache = Vec::new();
819                for _ in 0..self.cache.full().xlora_lock().len() {
820                    new_cache.push(None);
821                }
822
823                self.cache.full().xlora_lock().clone_from(&new_cache);
824            }
825            self.cache.full().xlora_lock()
826        } else {
827            self.cache.full().lock()
828        };
829        let mask =
830            CausalMasker.make_causal_mask_matrix(x, &*cache, self.dtype, self.layers[0].n_head)?;
831        for (i, layer) in self.layers.iter().enumerate() {
832            if let Some(ref mapper) = self.mapper {
833                layer_in = mapper.map(layer_in, i)?;
834            }
835            let x = layer_in;
836            let residual = &x;
837            let x = layer.attention_norm.forward(&x)?;
838            let attn = layer.forward_attn(
839                &x,
840                &mask.as_ref().map(|m| m.to_device(x.device()).unwrap()),
841                start_offsets,
842                &mut cache[i],
843                scalings.clone(),
844                self.xlora_classifier
845                    .as_ref()
846                    .map(|classifier| classifier.get_global_scaling_weight())
847                    .unwrap_or(1.0),
848                is_scaling_pass,
849                flash_params,
850            )?;
851            let x = (attn + residual)?;
852
853            // MLP
854            let residual = &x;
855            let x = layer.ffn_norm.forward(&x)?;
856            let x = layer.mlp_or_moe.forward(
857                &x,
858                scalings.clone(),
859                self.xlora_classifier
860                    .as_ref()
861                    .map(|classifier| classifier.get_global_scaling_weight())
862                    .unwrap_or(1.0),
863                is_scaling_pass,
864            )?;
865            let x = (x + residual)?;
866            layer_in = x;
867        }
868        let layer_in = layer_in.to_device(&self.device)?;
869        self.norm.forward(&layer_in)
870    }
871
872    #[allow(clippy::too_many_arguments)]
873    pub fn forward(
874        &self,
875        input_ids: &Tensor,
876        input_ids_full: &Tensor,
877        seqlen_offsets: &[usize],
878        seqlen_offsets_full: &[usize],
879        no_kv_cache: bool,
880        non_granular_state: &Option<NonGranularState>,
881        context_lens: Vec<(usize, usize)>,
882        flash_params: &FlashParams,
883        flash_params_full: &FlashParams,
884    ) -> Result<Tensor> {
885        if self.xlora_classifier.is_some() {
886            let scalings = self.get_scalings(
887                input_ids,
888                input_ids_full,
889                seqlen_offsets,
890                seqlen_offsets_full,
891                no_kv_cache,
892                non_granular_state,
893                &vec![usize::MAX; context_lens.len()],
894                flash_params,
895                flash_params_full,
896            )?;
897
898            if no_kv_cache {
899                extract_logits(
900                    &self.output.lora_forward(
901                        &self
902                            .inner_forward(
903                                input_ids_full,
904                                seqlen_offsets_full,
905                                Some(scalings),
906                                true,
907                                no_kv_cache,
908                                None,
909                                flash_params_full,
910                            )?
911                            .contiguous()?,
912                        None,
913                        1.0,
914                        None,
915                    )?,
916                    context_lens,
917                )
918            } else {
919                // is_full_pass=true is ok because no_kv_cache=false
920                extract_logits(
921                    &self.output.lora_forward(
922                        &self
923                            .inner_forward(
924                                input_ids,
925                                seqlen_offsets,
926                                Some(scalings),
927                                true,
928                                no_kv_cache,
929                                None,
930                                flash_params,
931                            )?
932                            .contiguous()?,
933                        None,
934                        1.0,
935                        None,
936                    )?,
937                    context_lens,
938                )
939            }
940        } else {
941            extract_logits(
942                &self.output.lora_forward(
943                    &self
944                        .inner_forward(
945                            input_ids,
946                            seqlen_offsets,
947                            None,
948                            false,
949                            no_kv_cache,
950                            None,
951                            flash_params,
952                        )?
953                        .contiguous()?,
954                    None,
955                    1.0,
956                    None,
957                )?,
958                context_lens,
959            )
960        }
961    }
962}
963
964impl ScalingsMaker for ModelWeights {
965    fn dtype(&self) -> DType {
966        DType::F32 // for dummy scalings
967    }
968    fn get_cache(&self) -> &EitherCache {
969        &self.cache
970    }
971    fn get_classifier(&self) -> &XLoraClassifier {
972        self.xlora_classifier.as_ref().unwrap()
973    }
974    fn forward(
975        &self,
976        input_ids: &Tensor,
977        seqlen_offsets: &[usize],
978        scalings: Tensor,
979        is_full_pass: bool,
980        no_kv_cache: bool,
981        is_scaling_pass: Option<f64>,
982        _context_lens: &[usize],
983        flash_params: &FlashParams,
984    ) -> Result<Tensor> {
985        self.inner_forward(
986            input_ids,
987            seqlen_offsets,
988            Some(scalings),
989            is_full_pass,
990            no_kv_cache,
991            is_scaling_pass,
992            flash_params,
993        )
994    }
995}